Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- **`opencontractserver.tasks.*` graduated out of the mypy baseline** (issue #1482, refs #1447). All sixteen task modules (`agent_tasks`, `analyzer_tasks`, `badge_tasks`, `corpus_tasks`, `data_extract_tasks`, `doc_analysis_tasks`, `doc_tasks`, `embeddings_task`, `export_tasks`, `export_tasks_v2`, `extract_orchestrator_tasks`, `fork_tasks`, `import_tasks`, `import_tasks_v2`, `lookup_tasks`, `memory_tasks`) now type-check cleanly. Their `[mypy-…]` blocks are removed from `mypy.ini` and the corresponding 227 lines are pruned from `docs/typing/mypy_baseline.txt`. Additional tightening: `LabelLookupPythonType.{text_labels,doc_labels}` narrowed from `dict[str | int, …]` to `dict[str, …]` (the producer in `utils.etl.build_label_lookups` already emits string keys), and the `corpus_id` parameter on `build_label_lookups` was corrected to `int`. `utils.importing.{load_or_create_labels, import_doc_annotations}` accept `Mapping` instead of `dict` so callers can pass `OpenContractDocExport` / `AnnotationLabelPythonType` TypedDicts without `cast()`.

- **Typing: graduate `opencontractserver.conversations.models` from mypy baseline** (refs Issue #1447, closes #1478): removed the last remaining `[mypy-opencontractserver.conversations.models]` `ignore_errors` block from `mypy.ini` and pruned 29 baseline error lines from `docs/typing/mypy_baseline.txt`. Per-file fixes in `opencontractserver/conversations/models.py`:
- Replaced the `Optional["AbstractBaseUser"]` parameter on `ConversationQuerySet.visible_to_user` and `ChatMessageQuerySet.visible_to_user` with `"UserModel | AnonymousUser | None"` (forward-referenced via `TYPE_CHECKING` import of `opencontractserver.users.models.User as UserModel`). Cleared the 2× `AnonymousUser → AbstractBaseUser | None` assignment errors at the `if user is None: user = AnonymousUser()` reassignment, the 6× `[union-attr]` errors on `.is_superuser` / `.is_anonymous` / `.id`, and the `Incompatible type for lookup 'creator'` error on `Document.objects.filter(creator=user)` (after narrowing past the `is_anonymous` early-return, mypy resolves user to `User`).
- Aligned `SoftDeleteManager.visible_to_user`, `ConversationManager.visible_to_user`, and `ChatMessageManager.visible_to_user` signatures with `BaseVisibilityManager.visible_to_user` (added `lightweight: bool = False` and `with_doc_label_annotations: bool = False` kwargs forwarded to `super()` for parity), eliminating the 3× `Signature of "visible_to_user" incompatible with supertype` `[override]` errors without resorting to `# type: ignore[override]`.
Expand Down
227 changes: 0 additions & 227 deletions docs/typing/mypy_baseline.txt

Large diffs are not rendered by default.

53 changes: 4 additions & 49 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -109,55 +109,10 @@ ignore_errors = True
[mypy-opencontractserver.shared.decorators]
ignore_errors = True

# --- opencontractserver.tasks (16 files) ---

[mypy-opencontractserver.tasks.agent_tasks]
ignore_errors = True

[mypy-opencontractserver.tasks.analyzer_tasks]
ignore_errors = True

[mypy-opencontractserver.tasks.badge_tasks]
ignore_errors = True

[mypy-opencontractserver.tasks.corpus_tasks]
ignore_errors = True

[mypy-opencontractserver.tasks.data_extract_tasks]
ignore_errors = True

[mypy-opencontractserver.tasks.doc_analysis_tasks]
ignore_errors = True

[mypy-opencontractserver.tasks.doc_tasks]
ignore_errors = True

[mypy-opencontractserver.tasks.embeddings_task]
ignore_errors = True

[mypy-opencontractserver.tasks.export_tasks]
ignore_errors = True

[mypy-opencontractserver.tasks.export_tasks_v2]
ignore_errors = True

[mypy-opencontractserver.tasks.extract_orchestrator_tasks]
ignore_errors = True

[mypy-opencontractserver.tasks.fork_tasks]
ignore_errors = True

[mypy-opencontractserver.tasks.import_tasks]
ignore_errors = True

[mypy-opencontractserver.tasks.import_tasks_v2]
ignore_errors = True

[mypy-opencontractserver.tasks.lookup_tasks]
ignore_errors = True

[mypy-opencontractserver.tasks.memory_tasks]
ignore_errors = True
# --- opencontractserver.tasks ---
#
# Issue #1482 graduated all 16 modules out of the baseline. The package is
# now type-clean.

# --- opencontractserver.tests (206 files) ---

Expand Down
28 changes: 17 additions & 11 deletions opencontractserver/tasks/agent_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import asyncio
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast

from asgiref.sync import async_to_sync
from celery import shared_task
Expand Down Expand Up @@ -172,8 +172,8 @@ def generate_agent_response(

# 4. Generate response with streaming
accumulated_content = ""
sources_data = []
timeline_data = []
sources_data: list[dict[str, Any]] = []
timeline_data: list[dict[str, Any]] = []

async def run_agent():
nonlocal accumulated_content
Expand Down Expand Up @@ -202,7 +202,7 @@ async def run_agent():
# Stream the agent response
# Pass store_messages=False since we handle message persistence ourselves
# (we already created response_message above with parent_message set)
async for event in agent.stream(user_message, store_messages=False):
async for event in await agent.stream(user_message, store_messages=False):
if isinstance(event, ContentEvent):
# Token/content chunk
token = event.content
Expand Down Expand Up @@ -497,7 +497,7 @@ def run_agent_corpus_action(

# Update result record with failure before retrying
try:
result, _ = AgentActionResult.objects.get_or_create(
failure_record, _ = AgentActionResult.objects.get_or_create(
corpus_action_id=corpus_action_id,
document_id=document_id,
defaults={
Expand All @@ -506,10 +506,14 @@ def run_agent_corpus_action(
"started_at": timezone.now(),
},
)
result.status = AgentActionResult.Status.FAILED
result.error_message = str(exc)[:1000] # Truncate to prevent DB bloat
result.completed_at = timezone.now()
result.save(update_fields=["status", "error_message", "completed_at"])
failure_record.status = AgentActionResult.Status.FAILED
failure_record.error_message = str(exc)[
:1000
] # Truncate to prevent DB bloat
failure_record.completed_at = timezone.now()
failure_record.save(
update_fields=["status", "error_message", "completed_at"]
)
except Exception as e:
logger.error(f"[AgentCorpusAction] Failed to mark result as failed: {e}")

Expand Down Expand Up @@ -867,12 +871,14 @@ def get_or_create_and_claim():
# Create agent with pre-authorization (skip approval gate).
# restrict_tool_names limits the agent to ONLY the resolved tool
# set, preventing tool overload from the ~17 default tools.
# ``tools`` is a ``list[str]``; the API accepts any ToolType
# (str | CoreTool | callable) but ``list`` is invariant, so cast.
agent = await agents.for_document(
document=document,
corpus=action.corpus,
user_id=user_id,
system_prompt=system_prompt,
tools=tools,
tools=cast("list[Any]", tools),
streaming=False,
skip_approval_gate=True,
restrict_tool_names=tools,
Expand Down Expand Up @@ -1168,7 +1174,7 @@ def get_or_create_result():
corpus=action.corpus,
user_id=user_id,
system_prompt=system_prompt,
tools=tools,
tools=cast("list[Any]", tools),
streaming=False,
skip_approval_gate=True,
)
Expand Down
6 changes: 3 additions & 3 deletions opencontractserver/tasks/analyzer_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def start_analysis(


@celery_app.task()
def request_gremlin_manifest(gremlin_id: str | int) -> list[AnalyzerManifest]:
def request_gremlin_manifest(gremlin_id: int) -> list[AnalyzerManifest]:
logger.info("request_gremlin_manifest() - Start...")

gremlin = GremlinEngine.objects.get(id=gremlin_id)
Expand All @@ -67,7 +67,7 @@ def request_gremlin_manifest(gremlin_id: str | int) -> list[AnalyzerManifest]:
# )
logger.info("request_gremlin_manifest() - End.")

return analyzer_manifests
return analyzer_manifests or []


@celery_app.task()
Expand Down Expand Up @@ -96,7 +96,7 @@ def mark_analysis_complete(analysis_id: str | int, doc_ids: list[int | str]) ->

analysis = Analysis.objects.get(pk=analysis_id)
analysis.analysis_completed = timezone.now()
analysis.analyzed_documents.add(*doc_ids)
analysis.analyzed_documents.add(*[int(doc_id) for doc_id in doc_ids])
analysis.save()

# Mark any related CorpusActionExecutions as completed
Expand Down
5 changes: 1 addition & 4 deletions opencontractserver/tasks/badge_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@
import logging
from typing import Optional

from django.contrib.auth import get_user_model

from config import celery_app
from opencontractserver.badges.models import Badge, BadgeTypeChoices, UserBadge
from opencontractserver.conversations.models import ChatMessage
from opencontractserver.corpuses.models import Corpus
from opencontractserver.users.models import User

logger = logging.getLogger(__name__)

User = get_user_model()


class BadgeCriteriaType:
"""Constants for badge criteria types."""
Expand Down
54 changes: 30 additions & 24 deletions opencontractserver/tasks/corpus_tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from datetime import timedelta
from typing import Any

from celery import chord, group, shared_task
from django.db import transaction
Expand Down Expand Up @@ -77,7 +78,7 @@ def run_task_name_analyzer(
lambda: chord(
group(
[
task_func.s(
task_func.s( # type: ignore[attr-defined]
doc_id=doc_id,
analysis_id=analysis.id,
corpus_id=(
Expand All @@ -103,8 +104,11 @@ def process_analyzer(
analysis_input_data: dict | None = None,
) -> Analysis:

if analyzer is None:
raise ValueError("process_analyzer requires a non-null analyzer")

logger.info(
f"process_analyzer called - user_id: {user_id}, analyzer: {analyzer.id if analyzer else None}"
f"process_analyzer called - user_id: {user_id}, analyzer: {analyzer.id}"
)
logger.info(f"corpus_id: {corpus_id}, document_ids: {document_ids}")
logger.info(f"analysis_input_data: {analysis_input_data}")
Expand Down Expand Up @@ -149,7 +153,7 @@ def process_corpus_action(
document_ids: list[str | int],
user_id: str | int,
trigger: str | None = None,
):
) -> None:
"""
Process corpus actions for given documents with execution tracking.

Expand Down Expand Up @@ -181,6 +185,10 @@ def process_corpus_action(

summary = {"actions_processed": 0, "executions_queued": 0}

# Celery JSON serialisation may widen int → str.
int_document_ids: list[int] = [int(d) for d in document_ids]
int_user_id: int = int(user_id)

for action in actions:
# Create execution records for tracking
# Use trigger or default to add_document for backwards compatibility
Expand All @@ -191,9 +199,9 @@ def process_corpus_action(
with transaction.atomic():
executions = CorpusActionExecution.bulk_queue(
corpus_action=action,
document_ids=document_ids,
document_ids=int_document_ids,
trigger=execution_trigger,
user_id=user_id,
user_id=int_user_id,
)
execution_map = {ex.document_id: ex for ex in executions}

Expand All @@ -209,7 +217,7 @@ def process_corpus_action(
corpus=action.corpus,
name=f"Action {action.name} for {action.corpus.title}",
fieldset=action.fieldset,
creator_id=user_id,
creator_id=int_user_id,
corpus_action=action,
)
extract.started = timezone.now()
Expand All @@ -235,8 +243,8 @@ def process_corpus_action(

fieldset = action.fieldset

for document_id in document_ids:
execution = execution_map.get(document_id)
for document_id in int_document_ids:
doc_execution = execution_map.get(document_id)

with transaction.atomic():
row_results = DocumentAnalysisRow(
Expand All @@ -252,19 +260,19 @@ def process_corpus_action(
extract=extract,
column=column,
data_definition=column.output_type,
creator_id=user_id,
creator_id=int_user_id,
document_id=document_id,
)
set_permissions_for_obj_to_user(
user_id, cell, [PermissionTypes.CRUD]
int_user_id, cell, [PermissionTypes.CRUD]
)

# Add data cell to tracking
row_results.data.add(cell)

# Track affected datacell in execution record
if execution:
execution.add_affected_object(
if doc_execution:
doc_execution.add_affected_object(
"datacell", cell.id, column_name=column.name
)

Expand All @@ -277,11 +285,11 @@ def process_corpus_action(
continue

# Add the task to the group
tasks.append(task_func.si(cell.pk))
tasks.append(task_func.si(cell.pk)) # type: ignore[attr-defined]

# Save updated affected_objects for this execution
if execution:
execution.save(update_fields=["affected_objects"])
if doc_execution:
doc_execution.save(update_fields=["affected_objects"])

# Capture extract_id and execution_ids for the lambda closure
extract_id_for_closure = extract.id
Expand Down Expand Up @@ -332,13 +340,13 @@ def on_commit_callback():
)

# Pass execution_id to agent task for tracking
for document_id in document_ids:
execution = execution_map.get(document_id)
for document_id in int_document_ids:
doc_execution = execution_map.get(document_id)
run_agent_corpus_action.delay(
corpus_action_id=action.id,
document_id=document_id,
user_id=user_id,
execution_id=execution.id if execution else None,
user_id=int_user_id,
execution_id=doc_execution.id if doc_execution else None,
)

else:
Expand All @@ -352,8 +360,6 @@ def on_commit_callback():
f"{summary['executions_queued']} executions queued"
)

return summary


# --------------------------------------------------------------------------- #
# Engagement Metrics Tasks (Epic #565)
Expand Down Expand Up @@ -510,7 +516,7 @@ def update_all_corpus_engagement_metrics():
# Queue individual update tasks
for corpus_id in corpus_ids:
transaction.on_commit(
lambda cid=corpus_id: update_corpus_engagement_metrics.apply_async(
lambda cid=corpus_id: update_corpus_engagement_metrics.apply_async( # type: ignore[misc]
args=[cid]
)
)
Expand Down Expand Up @@ -750,7 +756,7 @@ def ensure_embeddings_for_corpus(
f"corpus={corpus_id}"
)

result = {
result: dict[str, Any] = {
"structural_set_id": structural_set_id,
"corpus_id": corpus_id,
"embedders_checked": [],
Expand Down Expand Up @@ -916,7 +922,7 @@ def reembed_corpus(
f"reembed_corpus() - corpus={corpus_id}, new_embedder={new_embedder_path}"
)

result = {
result: dict[str, Any] = {
"corpus_id": corpus_id,
"new_embedder_path": new_embedder_path,
"total_annotations": 0,
Expand Down
Loading
Loading