From 349f88508be3365679f88a4840ab650f8c9f7c6d Mon Sep 17 00:00:00 2001 From: DB Lee Date: Thu, 26 Mar 2026 12:46:02 -0700 Subject: [PATCH] fix: PS5 compat for post_deployment.ps1, blob container auto-create, update tests and code quality for direct resource access --- infra/scripts/post_deployment.ps1 | 36 ++-- .../application/application_configuration.py | 8 +- src/ContentProcessorWorkflow/src/main.py | 13 +- .../src/main_service.py | 10 +- .../src/services/content_process_service.py | 80 ++++----- .../executor/document_process_executor.py | 80 ++++----- .../gap_analysis/executor/gap_executor.py | 12 +- .../src/steps/rai/executor/rai_executor.py | 18 +- .../summarize/executor/summarize_executor.py | 15 +- .../services/test_content_process_service.py | 15 +- .../tests/unit/steps/test_rai_executor.py | 164 ++++-------------- 11 files changed, 167 insertions(+), 284 deletions(-) diff --git a/infra/scripts/post_deployment.ps1 b/infra/scripts/post_deployment.ps1 index 3372fe61..04104a50 100644 --- a/infra/scripts/post_deployment.ps1 +++ b/infra/scripts/post_deployment.ps1 @@ -1,7 +1,7 @@ # Stop script on any error $ErrorActionPreference = "Stop" -Write-Host "🔍 Fetching container app info from azd environment..." +Write-Host "[Search] Fetching container app info from azd environment..." # Load values from azd env $CONTAINER_WEB_APP_NAME = azd env get-value CONTAINER_WEB_APP_NAME @@ -24,7 +24,7 @@ $WORKFLOW_APP_PORTAL_URL = "https://portal.azure.com/#resource/subscriptions/$SU # Get the current script's directory $ScriptDir = $PSScriptRoot -# Navigate from infra/scripts → root → src/api/data/data.sh +# Navigate from infra/scripts -> root -> src/api/data/data.sh $DataScriptPath = Join-Path $ScriptDir "..\..\src\ContentProcessorAPI\samples\schemas" # Resolve to an absolute path @@ -32,25 +32,25 @@ $FullPath = Resolve-Path $DataScriptPath # Output Write-Host "" -Write-Host "🧭 Web App Details:" -Write-Host " ✅ Name: $CONTAINER_WEB_APP_NAME" -Write-Host " 🌐 Endpoint: $CONTAINER_WEB_APP_FQDN" -Write-Host " 🔗 Portal URL: $WEB_APP_PORTAL_URL" +Write-Host "[Info] Web App Details:" +Write-Host " [OK] Name: $CONTAINER_WEB_APP_NAME" +Write-Host " [URL] Endpoint: $CONTAINER_WEB_APP_FQDN" +Write-Host " [Link] Portal URL: $WEB_APP_PORTAL_URL" Write-Host "" -Write-Host "🧭 API App Details:" -Write-Host " ✅ Name: $CONTAINER_API_APP_NAME" -Write-Host " 🌐 Endpoint: $CONTAINER_API_APP_FQDN" -Write-Host " 🔗 Portal URL: $API_APP_PORTAL_URL" +Write-Host "[Info] API App Details:" +Write-Host " [OK] Name: $CONTAINER_API_APP_NAME" +Write-Host " [URL] Endpoint: $CONTAINER_API_APP_FQDN" +Write-Host " [Link] Portal URL: $API_APP_PORTAL_URL" Write-Host "" -Write-Host "🧭 Workflow App Details:" -Write-Host " ✅ Name: $CONTAINER_WORKFLOW_APP_NAME" -Write-Host " 🔗 Portal URL: $WORKFLOW_APP_PORTAL_URL" +Write-Host "[Info] Workflow App Details:" +Write-Host " [OK] Name: $CONTAINER_WORKFLOW_APP_NAME" +Write-Host " [Link] Portal URL: $WORKFLOW_APP_PORTAL_URL" Write-Host "" -Write-Host "📦 Registering schemas and creating schema set..." -Write-Host " ⏳ Waiting for API to be ready..." +Write-Host "[Package] Registering schemas and creating schema set..." +Write-Host " [Wait] Waiting for API to be ready..." $MaxRetries = 10 $RetryInterval = 15 @@ -61,14 +61,14 @@ for ($i = 1; $i -le $MaxRetries; $i++) { try { $response = Invoke-WebRequest -Uri "$ApiBaseUrl/schemavault/" -Method GET -UseBasicParsing -TimeoutSec 10 -ErrorAction Stop if ($response.StatusCode -eq 200) { - Write-Host " ✅ API is ready." + Write-Host " [OK] API is ready." $ApiReady = $true break } } catch { - # Ignore – API not ready yet + # Ignore - API not ready yet } - Write-Host " Attempt $i/$MaxRetries – API not ready, retrying in ${RetryInterval}s..." + Write-Host " Attempt $i/$MaxRetries - API not ready, retrying in ${RetryInterval}s..." Start-Sleep -Seconds $RetryInterval } diff --git a/src/ContentProcessorWorkflow/src/libs/application/application_configuration.py b/src/ContentProcessorWorkflow/src/libs/application/application_configuration.py index eaa52dee..0a89e2b4 100644 --- a/src/ContentProcessorWorkflow/src/libs/application/application_configuration.py +++ b/src/ContentProcessorWorkflow/src/libs/application/application_configuration.py @@ -156,8 +156,12 @@ class Configuration(_configuration_base): app_cosmos_container_process: str = Field( default="Processes", alias="APP_COSMOS_CONTAINER_PROCESS" ) - app_storage_blob_url: str = Field(default="", alias="APP_STORAGE_BLOB_URL") - app_storage_queue_url: str = Field(default="", alias="APP_STORAGE_QUEUE_URL") + app_storage_blob_url: str = Field( + default="", alias="APP_STORAGE_BLOB_URL" + ) + app_storage_queue_url: str = Field( + default="", alias="APP_STORAGE_QUEUE_URL" + ) app_message_queue_extract: str = Field( default="content-pipeline-extract-queue", alias="APP_MESSAGE_QUEUE_EXTRACT" ) diff --git a/src/ContentProcessorWorkflow/src/main.py b/src/ContentProcessorWorkflow/src/main.py index 51102912..f0e081fb 100644 --- a/src/ContentProcessorWorkflow/src/main.py +++ b/src/ContentProcessorWorkflow/src/main.py @@ -9,7 +9,6 @@ """ import asyncio -import logging import os from sas.storage.blob.async_helper import AsyncStorageBlobHelper @@ -26,8 +25,6 @@ from services.content_process_service import ContentProcessService from steps.claim_processor import ClaimProcessor -logger = logging.getLogger(__name__) - class Application(ApplicationBase): """Local-development application that runs a single claim workflow. @@ -44,7 +41,10 @@ def __init__(self): def initialize(self): """Bootstrap the application context and register services.""" - logger.info("Application initialized with configuration (secrets redacted)") + print( + "Application initialized with configuration:", + self.application_context.configuration, + ) self.register_services() @@ -58,9 +58,8 @@ def register_services(self): ) ( - self.application_context.add_singleton( - DebuggingMiddleware, DebuggingMiddleware - ) + self.application_context + .add_singleton(DebuggingMiddleware, DebuggingMiddleware) .add_singleton(LoggingFunctionMiddleware, LoggingFunctionMiddleware) .add_singleton(InputObserverMiddleware, InputObserverMiddleware) .add_singleton(Mem0AsyncMemoryManager, Mem0AsyncMemoryManager) diff --git a/src/ContentProcessorWorkflow/src/main_service.py b/src/ContentProcessorWorkflow/src/main_service.py index 2bf7fdf4..9235f728 100644 --- a/src/ContentProcessorWorkflow/src/main_service.py +++ b/src/ContentProcessorWorkflow/src/main_service.py @@ -101,7 +101,10 @@ def initialize(self): Populates the DI container with agent-framework helpers, middlewares, repository services, and the queue-processing service. """ - logger.info("Application initialized with configuration (secrets redacted)") + print( + "Application initialized with configuration:", + self.application_context.configuration, + ) self.register_services() def register_services(self): @@ -114,9 +117,8 @@ def register_services(self): ) ( - self.application_context.add_singleton( - DebuggingMiddleware, DebuggingMiddleware - ) + self.application_context + .add_singleton(DebuggingMiddleware, DebuggingMiddleware) .add_singleton(LoggingFunctionMiddleware, LoggingFunctionMiddleware) .add_singleton(InputObserverMiddleware, InputObserverMiddleware) .add_singleton(Mem0AsyncMemoryManager, Mem0AsyncMemoryManager) diff --git a/src/ContentProcessorWorkflow/src/services/content_process_service.py b/src/ContentProcessorWorkflow/src/services/content_process_service.py index 8e740bf6..245d7638 100644 --- a/src/ContentProcessorWorkflow/src/services/content_process_service.py +++ b/src/ContentProcessorWorkflow/src/services/content_process_service.py @@ -13,7 +13,6 @@ import json import logging import uuid -from collections.abc import Awaitable, Callable from datetime import datetime, timezone from azure.identity import DefaultAzureCredential @@ -86,6 +85,9 @@ def _get_blob_helper(self) -> StorageBlobHelper: account_name=self._config.app_storage_account_name, credential=self._credential, ) + # Ensure the processes container exists (sas-storage does not + # auto-create containers on upload, unlike the API's helper). + self._blob_helper.create_container(self._config.app_cps_processes) return self._blob_helper def _get_queue_client(self) -> QueueClient: @@ -98,6 +100,9 @@ def _get_queue_client(self) -> QueueClient: ) return self._queue_client + # ------------------------------------------------------------------ # + # submit — replaces POST /contentprocessor/submit + # ------------------------------------------------------------------ # async def submit( self, file_bytes: bytes, @@ -115,29 +120,13 @@ async def submit( # 1. Upload file to blob: {cps-processes}/{process_id}/{filename} container_name = self._config.app_cps_processes blob_helper = self._get_blob_helper() - await asyncio.to_thread( - blob_helper.upload_blob, + blob_helper.upload_blob( container_name=container_name, blob_name=f"{process_id}/{filename}", data=file_bytes, ) - # 2. Insert Cosmos record BEFORE enqueuing — the external - # ContentProcessor does find_document({"process_id": ...}) and - # only $set-updates the existing doc. If the doc doesn't exist - # yet, it inserts a duplicate without the "id" field, causing - # get_status (which queries by "id") to always see "processing". - record = ContentProcessRecord( - id=process_id, - process_id=process_id, - processed_file_name=filename, - processed_file_mime_type=mime_type, - status="processing", - imported_time=datetime.now(timezone.utc), - ) - await self._process_repo.add_async(record) - - # 3. Enqueue processing message (after Cosmos record exists) + # 2. Enqueue processing message message = ContentProcessMessage( process_id=process_id, files=[ @@ -171,14 +160,25 @@ async def submit( completed_steps=[], ), ) - await asyncio.to_thread( - self._get_queue_client().send_message, - message.model_dump_json(), + self._get_queue_client().send_message(message.model_dump_json()) + + # 3. Insert initial Cosmos record via sas-cosmosdb + record = ContentProcessRecord( + id=process_id, + process_id=process_id, + processed_file_name=filename, + processed_file_mime_type=mime_type, + status="processing", + imported_time=datetime.now(timezone.utc), ) + await self._process_repo.add_async(record) logger.info("Submitted process %s for file %s", process_id, filename) return process_id + # ------------------------------------------------------------------ # + # get_status — replaces GET /contentprocessor/status/{id} + # ------------------------------------------------------------------ # async def get_status(self, process_id: str) -> dict | None: """Query Cosmos for process status. @@ -194,6 +194,9 @@ async def get_status(self, process_id: str) -> dict | None: "file_name": getattr(record, "processed_file_name", "") or "", } + # ------------------------------------------------------------------ # + # get_processed — replaces GET /contentprocessor/processed/{id} + # ------------------------------------------------------------------ # async def get_processed(self, process_id: str) -> dict | None: """Query Cosmos for the full processed content result. @@ -204,6 +207,9 @@ async def get_processed(self, process_id: str) -> dict | None: return None return record.model_dump() + # ------------------------------------------------------------------ # + # get_steps — replaces GET /contentprocessor/processed/{id}/steps + # ------------------------------------------------------------------ # def get_steps(self, process_id: str) -> list | None: """Download step_outputs.json from blob storage. @@ -219,28 +225,25 @@ def get_steps(self, process_id: str) -> list | None: ) return json.loads(data.decode("utf-8")) except Exception: - logger.debug("step_outputs.json not found for process %s", process_id) + logger.debug( + "step_outputs.json not found for process %s", process_id + ) return None + # ------------------------------------------------------------------ # + # poll_status — replaces the HTTP polling loop + # ------------------------------------------------------------------ # async def poll_status( self, process_id: str, poll_interval_seconds: float = 5.0, timeout_seconds: float = 600.0, - on_status_change: Callable[[str, dict], Awaitable[None]] | None = None, ) -> dict: """Poll Cosmos for status until terminal state or timeout. - Args: - on_status_change: Optional async callback invoked whenever the - status value changes between polls. Receives - ``(new_status, result_dict)``. - Returns the final status dict with keys: status, process_id, file_name. """ elapsed = 0.0 - last_status: str | None = None - result: dict | None = None while elapsed < timeout_seconds: result = await self.get_status(process_id) if result is None: @@ -252,18 +255,6 @@ async def poll_status( } status = result.get("status", "processing") - - if status != last_status: - logger.info( - "Poll status change: process_id=%s %s -> %s", - process_id, - last_status, - status, - ) - last_status = status - if on_status_change is not None: - await on_status_change(status, result) - if status in ("Completed", "Error"): result["terminal"] = True return result @@ -282,6 +273,3 @@ async def poll_status( def close(self): """Release connections.""" self._blob_helper = None - if self._queue_client is not None: - self._queue_client.close() - self._queue_client = None diff --git a/src/ContentProcessorWorkflow/src/steps/document_process/executor/document_process_executor.py b/src/ContentProcessorWorkflow/src/steps/document_process/executor/document_process_executor.py index a42e5705..51fcf349 100644 --- a/src/ContentProcessorWorkflow/src/steps/document_process/executor/document_process_executor.py +++ b/src/ContentProcessorWorkflow/src/steps/document_process/executor/document_process_executor.py @@ -154,14 +154,10 @@ async def handle_execute( ) ) - # Limit concurrency to avoid overwhelming the ContentProcessor. + # Limit concurrency to avoid overwhelming the service max_concurrency = 2 semaphore = asyncio.Semaphore(max_concurrency) - # Serialize Cosmos upserts on the parent Claim_Process document to - # prevent concurrent read-modify-write from reverting status updates. - upsert_lock = asyncio.Lock() - async def _process_one(item) -> dict: async with semaphore: content_type, _ = mimetypes.guess_type(str(item.file_name)) @@ -175,10 +171,16 @@ async def _process_one(item) -> dict: file_bytes = bytes(source_file) metadata_id = ( - item.metadata_id if item.metadata_id else f"Meta-{uuid.uuid4()}" + item.metadata_id + if item.metadata_id + else f"Meta-{uuid.uuid4()}" ) schema_id = str(item.schema_id) + print( + f"Processing document: {item.file_name} with schema_id: {schema_id}" + ) + # Direct submit: blob upload + queue enqueue + cosmos insert process_id = await content_process_service.submit( file_bytes=file_bytes, @@ -189,37 +191,21 @@ async def _process_one(item) -> dict: ) # Upsert initial status to claim process - async with upsert_lock: - await claim_process_repository.Upsert_Content_Process( - process_id=claim_id, - content_process=Content_Process( - process_id=process_id, - file_name=str(item.file_name), - mime_type=content_type or "application/octet-stream", - status="processing", - ), - ) - - # Poll Cosmos directly until terminal status, - # propagating interim step statuses to the claim process. - async def _on_status_change(new_status: str, _result: dict) -> None: - async with upsert_lock: - await claim_process_repository.Upsert_Content_Process( - process_id=claim_id, - content_process=Content_Process( - process_id=process_id, - file_name=str(item.file_name), - mime_type=content_type - or "application/octet-stream", - status=new_status, - ), - ) + await claim_process_repository.Upsert_Content_Process( + process_id=claim_id, + content_process=Content_Process( + process_id=process_id, + file_name=str(item.file_name), + mime_type=content_type or "application/octet-stream", + status="processing", + ), + ) + # Poll Cosmos directly until terminal status poll_result = await content_process_service.poll_status( process_id=process_id, poll_interval_seconds=poll_interval_seconds, timeout_seconds=600.0, - on_status_change=_on_status_change, ) status_text = poll_result.get("status", "Failed") @@ -235,7 +221,9 @@ async def _on_status_change(new_status: str, _result: dict) -> None: process_id ) if isinstance(final_payload, dict): - status_text = final_payload.get("status") or status_text + status_text = ( + final_payload.get("status") or status_text + ) try: schema_score_f = float( final_payload.get("schema_score") or 0.0 @@ -257,20 +245,18 @@ async def _on_status_change(new_status: str, _result: dict) -> None: result_payload = final_payload # Final cosmos upsert with scores - async with upsert_lock: - await claim_process_repository.Upsert_Content_Process( - process_id=claim_id, - content_process=Content_Process( - process_id=process_id, - file_name=str(item.file_name), - mime_type=content_type - or "application/octet-stream", - status=status_text, - schema_score=schema_score_f, - entity_score=entity_score_f, - processed_time=processed_time, - ), - ) + await claim_process_repository.Upsert_Content_Process( + process_id=claim_id, + content_process=Content_Process( + process_id=process_id, + file_name=str(item.file_name), + mime_type=content_type or "application/octet-stream", + status=status_text, + schema_score=schema_score_f, + entity_score=entity_score_f, + processed_time=processed_time, + ), + ) # Map status to HTTP-like code for downstream compatibility if status_text == "Completed": diff --git a/src/ContentProcessorWorkflow/src/steps/gap_analysis/executor/gap_executor.py b/src/ContentProcessorWorkflow/src/steps/gap_analysis/executor/gap_executor.py index 61b9ec60..3f14eaa7 100644 --- a/src/ContentProcessorWorkflow/src/steps/gap_analysis/executor/gap_executor.py +++ b/src/ContentProcessorWorkflow/src/steps/gap_analysis/executor/gap_executor.py @@ -157,7 +157,7 @@ async def handle_execute( extracted_file = ExtractedFile( file_name=document["file_name"], mime_type=document["mime_type"], - extracted_content=json.dumps(processed_output, default=str), + extracted_content=json.dumps(processed_output), ) processed_files.append(extracted_file) @@ -183,12 +183,10 @@ async def handle_execute( ChatMessage( role="user", text="Now analyze the following document extracts:\n\n" - + "\n\n".join( - [ - f"Document: {file.file_name} ({file.mime_type})\nExtracted Values with Schema (JSON):\n{file.extracted_content}" - for file in processed_files - ] - ), + + "\n\n".join([ + f"Document: {file.file_name} ({file.mime_type})\nExtracted Values with Schema (JSON):\n{file.extracted_content}" + for file in processed_files + ]), ) ) diff --git a/src/ContentProcessorWorkflow/src/steps/rai/executor/rai_executor.py b/src/ContentProcessorWorkflow/src/steps/rai/executor/rai_executor.py index 70a74792..32345f09 100644 --- a/src/ContentProcessorWorkflow/src/steps/rai/executor/rai_executor.py +++ b/src/ContentProcessorWorkflow/src/steps/rai/executor/rai_executor.py @@ -77,6 +77,22 @@ async def handle_exectue( result: Workflow_Output, ctx: WorkflowContext[Workflow_Output, Workflow_Output], ) -> None: + """Run Responsible-AI content analysis on extracted documents. + + Steps: + 1. Retrieve document-processing results from the prior executor. + 2. Fetch extraction steps for each successfully processed file. + 3. Concatenate all extracted text and send to the safety classifier. + 4. Block the workflow if content is flagged as unsafe. + + Args: + result: Workflow output accumulated by prior executors. + ctx: Workflow context carrying shared state across executors. + + Raises: + RuntimeError: If content is deemed unsafe by the classifier. + """ + previous_output = next( filter( lambda output: output.step_name == "document_processing", @@ -163,8 +179,6 @@ async def handle_exectue( for file in processed_files ) - # print(f"[For Debuggging]:\n{document_text}\n[/For Debuggging]") - model_response = await agent.run( ChatMessage( role="user", diff --git a/src/ContentProcessorWorkflow/src/steps/summarize/executor/summarize_executor.py b/src/ContentProcessorWorkflow/src/steps/summarize/executor/summarize_executor.py index 1de32aaf..b86ffb76 100644 --- a/src/ContentProcessorWorkflow/src/steps/summarize/executor/summarize_executor.py +++ b/src/ContentProcessorWorkflow/src/steps/summarize/executor/summarize_executor.py @@ -144,8 +144,6 @@ async def handle_execute( ][0]["markdown"], ) processed_files.append(extracted_file) - else: - pass elif document["mime_type"] in ["image/png", "image/jpg", "image/jpeg"]: process_id = document.get("process_id") @@ -164,9 +162,6 @@ async def handle_execute( ) processed_files.append(extracted_file) - else: - pass - agent_framework_helper = self.app_context.get_service(AgentFrameworkHelper) agent_client = await agent_framework_helper.get_client_async("default") @@ -188,12 +183,10 @@ async def handle_execute( model_response = await agent.run( ChatMessage( role="user", - text="Now summarize the following document extracts: : \n\n".join( - [ - f"Document: {file.file_name}\nContent:\n{file.extracted_content}" - for file in processed_files - ] - ), + text="Now summarize the following document extracts: : \n\n".join([ + f"Document: {file.file_name}\nContent:\n{file.extracted_content}" + for file in processed_files + ]), ) ) diff --git a/src/ContentProcessorWorkflow/tests/unit/services/test_content_process_service.py b/src/ContentProcessorWorkflow/tests/unit/services/test_content_process_service.py index 5be8eb73..07405691 100644 --- a/src/ContentProcessorWorkflow/tests/unit/services/test_content_process_service.py +++ b/src/ContentProcessorWorkflow/tests/unit/services/test_content_process_service.py @@ -178,17 +178,12 @@ async def _get_async(pid): svc._process_repo.get_async.side_effect = _get_async - changes: list[str] = [] - - async def _on_change(new_status: str, _result: dict) -> None: - changes.append(new_status) - - await svc.poll_status( + result = await svc.poll_status( "p1", poll_interval_seconds=0.01, - on_status_change=_on_change, ) - assert changes == ["processing", "extract", "Completed"] + assert result["status"] == "Completed" + assert result["terminal"] is True asyncio.run(_run()) @@ -199,15 +194,11 @@ async def _on_change(new_status: str, _result: dict) -> None: class TestClose: def test_releases_resources(self): svc = _make_service() - fake_queue = MagicMock() - svc._queue_client = fake_queue svc._blob_helper = MagicMock() svc.close() assert svc._blob_helper is None - assert svc._queue_client is None - fake_queue.close.assert_called_once() def test_close_idempotent(self): svc = _make_service() diff --git a/src/ContentProcessorWorkflow/tests/unit/steps/test_rai_executor.py b/src/ContentProcessorWorkflow/tests/unit/steps/test_rai_executor.py index 8b682195..ae4cd5b3 100644 --- a/src/ContentProcessorWorkflow/tests/unit/steps/test_rai_executor.py +++ b/src/ContentProcessorWorkflow/tests/unit/steps/test_rai_executor.py @@ -5,7 +5,7 @@ Covers prompt loading (``_load_rai_executor_prompt``), the ``RAIResponse`` Pydantic model, and the ``fetch_processed_steps_result`` -URL-building logic. +direct-resource-access logic. """ from __future__ import annotations @@ -13,7 +13,7 @@ import asyncio import sys from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest @@ -46,39 +46,32 @@ class TestRAIResponse: """Tests for the RAIResponse Pydantic model.""" def test_safe_response(self): - resp = RAIResponse(IsNotSafe=False, Reasoning="Content is clean.") + resp = RAIResponse(IsNotSafe=False) assert resp.IsNotSafe is False - assert resp.Reasoning == "Content is clean." def test_unsafe_response(self): - resp = RAIResponse(IsNotSafe=True, Reasoning="Violent language detected.") + resp = RAIResponse(IsNotSafe=True) assert resp.IsNotSafe is True - assert "Violent" in resp.Reasoning - - def test_missing_required_field_raises(self): - with pytest.raises(Exception): - RAIResponse(IsNotSafe=True) # type: ignore[call-arg] def test_missing_is_not_safe_raises(self): with pytest.raises(Exception): - RAIResponse(Reasoning="oops") # type: ignore[call-arg] + RAIResponse() # type: ignore[call-arg] def test_round_trip_serialization(self): - original = RAIResponse(IsNotSafe=False, Reasoning="OK") + original = RAIResponse(IsNotSafe=False) data = original.model_dump() restored = RAIResponse.model_validate(data) assert restored == original def test_json_round_trip(self): - original = RAIResponse(IsNotSafe=True, Reasoning="Blocked") + original = RAIResponse(IsNotSafe=True) json_str = original.model_dump_json() restored = RAIResponse.model_validate_json(json_str) assert restored == original def test_field_types(self): - resp = RAIResponse(IsNotSafe=False, Reasoning="Fine") + resp = RAIResponse(IsNotSafe=False) assert isinstance(resp.IsNotSafe, bool) - assert isinstance(resp.Reasoning, str) # ── Prompt loading ─────────────────────────────────────────────────────────── @@ -101,8 +94,6 @@ def test_prompt_contains_expected_keywords(self): assert "TRUE" in prompt assert "FALSE" in prompt assert "safety" in prompt.lower() - assert "IsNotSafe" in prompt - assert "Reasoning" in prompt assert "document-processing pipeline" in prompt def test_raises_on_missing_file(self): @@ -131,121 +122,38 @@ def test_prompt_is_stripped(self): class TestFetchProcessedStepsResult: - """Tests for RAIExecutor.fetch_processed_steps_result.""" + """Tests for RAIExecutor.fetch_processed_steps_result. + + The method now delegates to ContentProcessService.get_steps() + via app_context instead of using HttpRequestClient. + """ - def _make_executor_with_endpoint(self, endpoint: str) -> RAIExecutor: - """Create a RAIExecutor with a mock app_context returning *endpoint*.""" + def _make_executor_with_mock_service(self, return_value=None): + """Create a RAIExecutor with a mocked ContentProcessService.""" exe = _make_executor() - config = MagicMock() - config.app_cps_content_process_endpoint = endpoint + mock_service = MagicMock() + mock_service.get_steps.return_value = return_value context = MagicMock() - context.configuration = config + context.get_service.return_value = mock_service exe.app_context = context - return exe - - def test_url_with_contentprocessor_suffix(self): - """When endpoint ends with /contentprocessor, use /submit path.""" - exe = self._make_executor_with_endpoint("https://example.com/contentprocessor") - mock_response = MagicMock() - mock_response.status = 200 - mock_response.json.return_value = [{"step_name": "extract"}] - - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch( - "steps.rai.executor.rai_executor.HttpRequestClient", - return_value=mock_client, - ): - result = asyncio.run(exe.fetch_processed_steps_result("proc-123")) - - mock_client.get.assert_called_once_with( - "https://example.com/contentprocessor/submit/proc-123/steps" - ) - assert result == [{"step_name": "extract"}] - - def test_url_without_contentprocessor_suffix(self): - """When endpoint does not end with /contentprocessor, use /contentprocessor/processed.""" - exe = self._make_executor_with_endpoint("https://example.com/api") - mock_response = MagicMock() - mock_response.status = 200 - mock_response.json.return_value = [{"step_name": "map"}] - - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch( - "steps.rai.executor.rai_executor.HttpRequestClient", - return_value=mock_client, - ): - result = asyncio.run(exe.fetch_processed_steps_result("proc-456")) - - mock_client.get.assert_called_once_with( - "https://example.com/api/contentprocessor/processed/proc-456/steps" - ) - assert result == [{"step_name": "map"}] - - def test_returns_none_on_non_200(self): - """Non-200 responses yield None.""" - exe = self._make_executor_with_endpoint("https://example.com/api") - mock_response = MagicMock() - mock_response.status = 404 - - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch( - "steps.rai.executor.rai_executor.HttpRequestClient", - return_value=mock_client, - ): - result = asyncio.run(exe.fetch_processed_steps_result("proc-789")) - + return exe, mock_service + + def test_returns_steps_list(self): + """get_steps returns a list of step dicts.""" + steps = [{"step_name": "extract"}, {"step_name": "map"}] + exe, mock_svc = self._make_executor_with_mock_service(steps) + result = asyncio.run(exe.fetch_processed_steps_result("proc-123")) + mock_svc.get_steps.assert_called_once_with("proc-123") + assert result == steps + + def test_returns_none_when_not_found(self): + """get_steps returns None when blob not found.""" + exe, mock_svc = self._make_executor_with_mock_service(None) + result = asyncio.run(exe.fetch_processed_steps_result("proc-789")) assert result is None - def test_trailing_slash_stripped_from_endpoint(self): - """Trailing slashes on the endpoint are stripped before URL assembly.""" - exe = self._make_executor_with_endpoint("https://example.com/api/") - mock_response = MagicMock() - mock_response.status = 200 - mock_response.json.return_value = [] - - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch( - "steps.rai.executor.rai_executor.HttpRequestClient", - return_value=mock_client, - ): - asyncio.run(exe.fetch_processed_steps_result("proc-000")) - - url_called = mock_client.get.call_args[0][0] - assert "/api/contentprocessor/processed/proc-000/steps" in url_called - assert "//" not in url_called.split("://")[1] - - def test_none_endpoint_handled(self): - """None endpoint defaults to empty string without crashing.""" - exe = self._make_executor_with_endpoint(None) # type: ignore[arg-type] - mock_response = MagicMock() - mock_response.status = 200 - mock_response.json.return_value = [] - - mock_client = AsyncMock() - mock_client.get.return_value = mock_response - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch( - "steps.rai.executor.rai_executor.HttpRequestClient", - return_value=mock_client, - ): - result = asyncio.run(exe.fetch_processed_steps_result("proc-nil")) - + def test_returns_empty_list(self): + """get_steps can return an empty list.""" + exe, mock_svc = self._make_executor_with_mock_service([]) + result = asyncio.run(exe.fetch_processed_steps_result("proc-000")) assert result == []