From 9f971164524c12dd64bdb1a2dbdd2c4d3dc4a1f7 Mon Sep 17 00:00:00 2001 From: Slava Trofimov <26082149+pmbstyle@users.noreply.github.com> Date: Mon, 18 May 2026 12:03:44 -0400 Subject: [PATCH 1/2] harden worker file-write routing --- src/octopal/cli/branding.py | 5 +- .../providers/litellm_provider.py | 52 +++++-- src/octopal/runtime/workers/agent_worker.py | 131 +++++++++++++++++- src/octopal/runtime/workers/allowed_paths.py | 29 +++- src/octopal/runtime/workers/launcher.py | 8 +- src/octopal/tools/workers/management.py | 74 ++++++++++ src/octopal/utils.py | 2 + tests/test_agent_worker_contracts.py | 29 ++++ tests/test_heartbeat.py | 1 + ...test_litellm_provider_payload_hardening.py | 21 +++ tests/test_worker_launcher_isolation.py | 4 +- tests/test_worker_router.py | 74 ++++++++++ 12 files changed, 402 insertions(+), 28 deletions(-) create mode 100644 tests/test_agent_worker_contracts.py diff --git a/src/octopal/cli/branding.py b/src/octopal/cli/branding.py index f59bbbeb..72e1e70f 100644 --- a/src/octopal/cli/branding.py +++ b/src/octopal/cli/branding.py @@ -35,7 +35,10 @@ def print_banner() -> None: """).strip() output_encoding = (sys.stdout.encoding or "utf-8").lower() - banner_text.encode(output_encoding, errors="strict") + try: + banner_text.encode(output_encoding, errors="strict") + except UnicodeEncodeError: + banner_text = "OCTOPAL" tagline = Text("Your trusted AI pal", style=f"italic {OCTO_SILVER}") subline = Text("SECURE MULTI-AGENT EXECUTION RUNTIME", style=OCTO_WHITE) diff --git a/src/octopal/infrastructure/providers/litellm_provider.py b/src/octopal/infrastructure/providers/litellm_provider.py index 20011eb5..377e6ba7 100644 --- a/src/octopal/infrastructure/providers/litellm_provider.py +++ b/src/octopal/infrastructure/providers/litellm_provider.py @@ -350,7 +350,7 @@ async def complete_with_tools( messages: list[Message | dict], *, tools: list[dict], - tool_choice: str = "auto", + tool_choice: object = "auto", **kwargs: object, ) -> dict: """Complete a chat request with tool/function calling.""" @@ -537,7 +537,7 @@ async def _complete_with_tools_adaptive_response_format( *, messages: list[dict[str, Any]], tools: list[dict[str, Any]], - tool_choice: str, + tool_choice: object, request_kwargs: dict[str, object], ) -> tuple[Any, str]: requested_response_format = request_kwargs.get("response_format") @@ -606,12 +606,16 @@ async def _acompletion_with_resilience(self, **kwargs: object) -> Any: async def _acompletion_guarded(self, **kwargs: object) -> Any: async with self._semaphore: + timeout_seconds = _coerce_timeout_seconds(kwargs.get("timeout")) try: - response = await acompletion( - model=self._model, - api_base=self._api_base, - api_key=self._api_key, - **kwargs, + response = await _await_with_runtime_timeout( + acompletion( + model=self._model, + api_base=self._api_base, + api_key=self._api_key, + **kwargs, + ), + timeout_seconds=timeout_seconds, ) except Exception as exc: if not _is_closed_client_error(exc): @@ -619,17 +623,23 @@ async def _acompletion_guarded(self, **kwargs: object) -> Any: logger.warning( "LiteLLM client was closed mid-request; retrying once with a fresh completion call" ) - response = await acompletion( - model=self._model, - api_base=self._api_base, - api_key=self._api_key, - **kwargs, + response = await _await_with_runtime_timeout( + acompletion( + model=self._model, + api_base=self._api_base, + api_key=self._api_key, + **kwargs, + ), + timeout_seconds=timeout_seconds, ) # LiteLLM can occasionally return a nested awaitable object on # provider-error paths (seen on Python 3.14). Unwrap it to avoid # "coroutine ... was never awaited" warnings and leaked coroutines. while inspect.isawaitable(response): - response = await response + response = await _await_with_runtime_timeout( + response, + timeout_seconds=timeout_seconds, + ) return response @@ -644,6 +654,22 @@ def _serialize_message(message: Message | dict) -> dict: return serialized +def _coerce_timeout_seconds(value: object) -> float | None: + try: + timeout = float(value) if value is not None else None + except (TypeError, ValueError): + return None + if timeout is None or timeout <= 0: + return None + return timeout + + +async def _await_with_runtime_timeout(awaitable: Any, *, timeout_seconds: float | None) -> Any: + if timeout_seconds is None: + return await awaitable + return await asyncio.wait_for(awaitable, timeout=timeout_seconds) + + def _normalize_plain_messages(messages: list[dict[str, Any]]) -> list[dict[str, str]]: normalized: list[dict[str, str]] = [] for message in messages: diff --git a/src/octopal/runtime/workers/agent_worker.py b/src/octopal/runtime/workers/agent_worker.py index e5f50603..70b7665c 100644 --- a/src/octopal/runtime/workers/agent_worker.py +++ b/src/octopal/runtime/workers/agent_worker.py @@ -15,6 +15,7 @@ import json import os import random +import re import time import traceback from pathlib import Path @@ -131,6 +132,50 @@ "start_child_worker", "start_workers_parallel", } +_WRITE_TASK_TOKENS = { + "append", + "create", + "created", + "creates", + "draft", + "edit", + "edits", + "save", + "saved", + "update", + "updates", + "write", + "writes", + "writing", +} +_FILE_TASK_TOKENS = { + "artifact", + "config", + "csv", + "doc", + "document", + "draft", + "file", + "files", + "json", + "markdown", + "md", + "note", + "notes", + "path", + "report", + "text", + "toml", + "workspace", + "yaml", + "yml", +} +_FILE_PATH_HINT_RE = re.compile( + r"(?:^|[\s`'\"])[\w./\\-]+\." + r"(?:cfg|conf|csv|html|ini|json|log|md|py|toml|txt|ya?ml)" + r"(?:$|[\s`'\",.:;])", + re.IGNORECASE, +) def _parse_positive_int_env(name: str, default: int) -> int: @@ -155,6 +200,31 @@ def _parse_nonnegative_int_env(name: str, default: int) -> int: return value if value >= 0 else default +def _tokenize_task(text: str) -> set[str]: + return set(re.findall(r"[a-z0-9_]+", (text or "").lower())) + + +def _task_requires_workspace_write(task: str) -> bool: + tokens = _tokenize_task(task) + return bool(tokens & _WRITE_TASK_TOKENS) and ( + bool(tokens & _FILE_TASK_TOKENS) or bool(_FILE_PATH_HINT_RE.search(task or "")) + ) + + +def _fs_write_completion_missing(task: str, available_tools: list[str], tools_used: list[str]) -> bool: + normalized_available = {str(tool).strip().lower() for tool in available_tools} + normalized_used = {str(tool).strip().lower() for tool in tools_used} + return ( + "fs_write" in normalized_available + and "fs_write" not in normalized_used + and _task_requires_workspace_write(task) + ) + + +def _force_tool_choice(tool_name: str) -> dict[str, dict[str, str] | str]: + return {"type": "function", "function": {"name": tool_name}} + + def _extract_tool_progress_key(tool_name: str | None, tool_result: Any) -> str | None: normalized_tool = str(tool_name or "").strip() structured = _decode_structured_tool_result(tool_result) @@ -710,6 +780,7 @@ async def mcp_proxy_handler(args: dict, ctx: dict, s_id=s_id, t_name=t_name): {tool_descriptions} Use available tools through normal tool calls. Do not emit ad-hoc JSON tool_use blocks. +If the task asks you to create, write, save, update, or edit a workspace file and fs_write is available, you must call fs_write before returning a result. Do not claim a file was written until the fs_write tool returns successfully. {coordination_prompt} @@ -783,8 +854,19 @@ async def mcp_proxy_handler(args: dict, ctx: dict, s_id=s_id, t_name=t_name): while thinking_steps < effective_max_steps: llm_start = time.perf_counter() + force_fs_write = _fs_write_completion_missing( + spec.task, + list(spec.available_tools or []), + tools_used, + ) try: - response = await _call_llm(provider, messages, filtered_tools) + response = await _call_llm( + provider, + messages, + filtered_tools, + tool_choice=_force_tool_choice("fs_write") if force_fs_write else "auto", + response_format_enabled=not force_fs_write, + ) except Exception as exc: telemetry["llm_latency_ms_total"] += int((time.perf_counter() - llm_start) * 1000) error_text = str(exc) @@ -1082,6 +1164,20 @@ async def mcp_proxy_handler(args: dict, ctx: dict, s_id=s_id, t_name=t_name): # Try to parse structured JSON result, including fenced JSON blocks. result_block = _extract_result_block(content) if result_block is not None: + if _fs_write_completion_missing(spec.task, list(spec.available_tools or []), tools_used): + messages.append({"role": "assistant", "content": content}) + messages.append( + { + "role": "user", + "content": ( + "The task requires an actual fs_write tool call before completion. " + "Call fs_write with the requested path and content now, then return " + "the structured result only after fs_write succeeds." + ), + } + ) + thinking_steps += 1 + continue cycle_steps = thinking_steps + 1 return WorkerResult( status=( @@ -1100,6 +1196,20 @@ async def mcp_proxy_handler(args: dict, ctx: dict, s_id=s_id, t_name=t_name): # If model produced plain text with no tool call, treat it as completion. if content: + if _fs_write_completion_missing(spec.task, list(spec.available_tools or []), tools_used): + messages.append({"role": "assistant", "content": content}) + messages.append( + { + "role": "user", + "content": ( + "The task requires an actual fs_write tool call before completion. " + "Call fs_write with the requested path and content now, then return " + "the final answer only after fs_write succeeds." + ), + } + ) + thinking_steps += 1 + continue cycle_steps = thinking_steps + 1 return WorkerResult( summary=content, @@ -1168,6 +1278,9 @@ async def _call_llm( provider: LiteLLMProvider, messages: list[dict], tools: list, + *, + tool_choice: object = "auto", + response_format_enabled: bool = True, ) -> dict: """Call LLM with tools using the centralized provider.""" # Build OpenAI-style tools format @@ -1183,17 +1296,21 @@ async def _call_llm( for t in tools ] - response_format = { - "type": "json_schema", - "json_schema": {"name": "worker_result", "schema": _RESULT_SCHEMA}, - } + response_format = None + if response_format_enabled: + response_format = { + "type": "json_schema", + "json_schema": {"name": "worker_result", "schema": _RESULT_SCHEMA}, + } # Provider handles adaptive response_format downgrade when a route does not # support schema-constrained outputs. + request_kwargs: dict[str, Any] = {"tool_choice": tool_choice} + if response_format is not None: + request_kwargs["response_format"] = response_format response = await provider.complete_with_tools( messages=messages, tools=openai_tools if openai_tools else [], - tool_choice="auto", - response_format=response_format, + **request_kwargs, ) # Return in expected format: {"content": "...", "tool_calls": [...]} diff --git a/src/octopal/runtime/workers/allowed_paths.py b/src/octopal/runtime/workers/allowed_paths.py index 86306e5e..8518caba 100644 --- a/src/octopal/runtime/workers/allowed_paths.py +++ b/src/octopal/runtime/workers/allowed_paths.py @@ -44,6 +44,30 @@ def _workspace_relative_path( return rel_path.as_posix() +def _existing_parent_workspace_path( + raw_path: object, + *, + workspace_dir: Path | None = None, +) -> str | None: + raw = str(raw_path or "").strip().strip("`'\".,;:)") + if not raw: + return None + rel = _workspace_relative_path(raw, workspace_dir=workspace_dir) + if not rel: + return None + rel_path = Path(rel) + if not rel_path.suffix: + return None + + workspace = _workspace_path(workspace_dir) + parent = rel_path.parent + while str(parent) not in {"", "."}: + if (workspace / parent).is_dir(): + return parent.as_posix() + parent = parent.parent + return None + + def normalize_allowed_paths( value: object, *, @@ -77,11 +101,14 @@ def infer_allowed_paths_from_task( seen: set[str] = set() inferred: list[str] = [] for match in _PATH_TOKEN_RE.finditer(task or ""): + raw_path = match.group("path") rel = _workspace_relative_path( - match.group("path"), + raw_path, workspace_dir=workspace_dir, require_exists=True, ) + if not rel: + rel = _existing_parent_workspace_path(raw_path, workspace_dir=workspace_dir) if not rel or rel in seen: continue seen.add(rel) diff --git a/src/octopal/runtime/workers/launcher.py b/src/octopal/runtime/workers/launcher.py index 81516084..631c3809 100644 --- a/src/octopal/runtime/workers/launcher.py +++ b/src/octopal/runtime/workers/launcher.py @@ -80,7 +80,7 @@ async def launch( host_worker_dir = Path(cwd).resolve() container_worker_dir = f"{container_ws}/workers/{worker_id}" cmd_args.extend(["-v", f"{host_worker_dir}:{container_worker_dir}"]) - container_env = _filter_container_env(env, worker_workspace=container_worker_dir) + container_env = _filter_container_env(env, container_workspace=container_ws) for key, value in container_env.items(): cmd_args.extend(["-e", f"{key}={value}"]) cmd_args.extend(["-e", f"HOME={container_worker_dir}"]) @@ -149,7 +149,7 @@ async def launch( def _filter_container_env( - env: dict[str, str], *, worker_workspace: str | None = None + env: dict[str, str], *, container_workspace: str | None = None ) -> dict[str, str]: # Container env must be explicit; keep only a safe subset. allowed = { @@ -168,8 +168,8 @@ def _filter_container_env( "FIRECRAWL_API_KEY", } filtered = {key: value for key, value in env.items() if key in allowed} - if worker_workspace: - filtered["OCTOPAL_WORKSPACE_DIR"] = worker_workspace + if container_workspace: + filtered["OCTOPAL_WORKSPACE_DIR"] = container_workspace return filtered diff --git a/src/octopal/tools/workers/management.py b/src/octopal/tools/workers/management.py index 5d3697a5..7baf4eb3 100644 --- a/src/octopal/tools/workers/management.py +++ b/src/octopal/tools/workers/management.py @@ -60,6 +60,51 @@ "mcp_exec", "analyze_image", } +_FILE_WRITE_TASK_TOKENS = { + "append", + "create", + "created", + "creates", + "draft", + "edit", + "edits", + "save", + "saved", + "update", + "updates", + "write", + "writes", + "writing", +} +_FILE_ARTIFACT_TASK_TOKENS = { + "artifact", + "config", + "csv", + "doc", + "document", + "draft", + "file", + "files", + "json", + "markdown", + "md", + "note", + "notes", + "path", + "report", + "text", + "toml", + "workspace", + "yaml", + "yml", +} +_FILESYSTEM_WRITE_TOOL_TOKENS = {"fs_write", "write_file"} +_FILE_PATH_HINT_RE = re.compile( + r"(?:^|[\s`'\"])[\w./\\-]+\." + r"(?:cfg|conf|csv|html|ini|json|log|md|py|toml|txt|ya?ml)" + r"(?:$|[\s`'\",.:;])", + re.IGNORECASE, +) def get_worker_tools() -> list[ToolSpec]: @@ -1948,6 +1993,12 @@ def _template_supports_image_analysis(template: object) -> bool: return any(token in descriptor for token in ("image", "vision", "visual")) +def _template_supports_workspace_write(template: object) -> bool: + available_tools = {str(t).lower() for t in getattr(template, "available_tools", [])} + permissions = set(_normalize_worker_permissions(getattr(template, "required_permissions", []))) + return bool(_FILESYSTEM_WRITE_TOOL_TOKENS & available_tools) or "filesystem_write" in permissions + + def _validate_template_requirements( template: object, *, @@ -1980,6 +2031,12 @@ def _validate_template_requirements( "analysis capability. Use a vision-capable model/tool directly or a worker template with " "image, vision, mcp_call, or mcp_exec support." ) + if _task_requests_workspace_write(task, _tokenize(task)) and not _template_supports_workspace_write(template): + return ( + f"{error_prefix}: worker '{getattr(template, 'id', '')}' does not advertise workspace write " + "capability for this file/artifact task. Use worker_id='auto' or a worker template with " + "fs_write/filesystem_write support." + ) return None @@ -2015,6 +2072,12 @@ def _tokenize(text: str) -> set[str]: return set(_TOKEN_RE.findall((text or "").lower())) +def _task_requests_workspace_write(task: str, task_tokens: set[str]) -> bool: + has_write_verb = bool(task_tokens & _FILE_WRITE_TASK_TOKENS) + has_artifact_hint = bool(task_tokens & _FILE_ARTIFACT_TASK_TOKENS) + return has_write_verb and (has_artifact_hint or bool(_FILE_PATH_HINT_RE.search(task or ""))) + + def _select_worker_template( *, templates: list[object], @@ -2030,6 +2093,7 @@ def _select_worker_template( task_tokens = _tokenize(task) if not task_tokens: task_tokens = {"task"} + task_requests_workspace_write = _task_requests_workspace_write(task, task_tokens) best: dict[str, object] | None = None for template in templates: @@ -2060,6 +2124,9 @@ def _select_worker_template( available_tools = [str(t).lower() for t in getattr(template, "available_tools", [])] permissions = _normalize_worker_permissions(getattr(template, "required_permissions", [])) + has_filesystem_write = ( + bool(_FILESYSTEM_WRITE_TOOL_TOKENS & set(available_tools)) or "filesystem_write" in permissions + ) if required_tools: matched_tools = sum(1 for t in required_tools if t in available_tools) @@ -2072,6 +2139,13 @@ def _select_worker_template( score += matched_perms * 4.0 reasons.append(f"required_permissions={matched_perms}/{len(required_permissions)}") + if task_requests_workspace_write: + if has_filesystem_write: + score += 7.0 + reasons.append("filesystem_write_bonus") + else: + score -= 5.0 + reasons.append("filesystem_write_missing_penalty") if "web" in task_tokens and any("web" in t for t in available_tools): score += 3.0 reasons.append("web_tool_bonus") diff --git a/src/octopal/utils.py b/src/octopal/utils.py index 6a66948d..fc965315 100644 --- a/src/octopal/utils.py +++ b/src/octopal/utils.py @@ -312,6 +312,8 @@ def looks_like_textual_tool_invocation(text: str) -> bool: trimmed = re.sub(r"[\s\W_]+$", "", trimmed, flags=re.UNICODE).strip() if not trimmed: return False + if trimmed.isupper() and "_" in trimmed: + return False return bool( _TEXTUAL_TOOL_NAME_RE.fullmatch(trimmed) diff --git a/tests/test_agent_worker_contracts.py b/tests/test_agent_worker_contracts.py new file mode 100644 index 00000000..e408f3f2 --- /dev/null +++ b/tests/test_agent_worker_contracts.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from octopal.runtime.workers.agent_worker import ( + _force_tool_choice, + _fs_write_completion_missing, + _task_requires_workspace_write, +) + + +def test_workspace_write_task_detection_requires_write_intent_and_file_hint() -> None: + assert _task_requires_workspace_write( + "Create a short markdown report at experiments/qa/marker-worker-report.md" + ) + assert not _task_requires_workspace_write("Summarize the latest provider news") + + +def test_fs_write_completion_missing_requires_available_but_unused_tool() -> None: + task = "Write the report to experiments/qa/marker-worker-report.md" + + assert _fs_write_completion_missing(task, ["fs_read", "fs_write"], []) + assert not _fs_write_completion_missing(task, ["fs_read", "fs_write"], ["fs_write"]) + assert not _fs_write_completion_missing(task, ["web_search"], []) + + +def test_force_tool_choice_uses_openai_function_shape() -> None: + assert _force_tool_choice("fs_write") == { + "type": "function", + "function": {"name": "fs_write"}, + } diff --git a/tests/test_heartbeat.py b/tests/test_heartbeat.py index 91c399b2..cfbc4f3b 100644 --- a/tests/test_heartbeat.py +++ b/tests/test_heartbeat.py @@ -415,6 +415,7 @@ def test_detect_textual_tool_invocation(): assert not looks_like_textual_tool_invocation("SCHEDULED_TASK_DONE") assert not looks_like_textual_tool_invocation("SCHEDULED_TASK_BLOCKED") assert not looks_like_textual_tool_invocation("SCHEDULER_IDLE") + assert not looks_like_textual_tool_invocation("A2A_OK") assert not looks_like_textual_tool_invocation("Result ready. NO_USER_RESPONSE") assert not looks_like_textual_tool_invocation("Проверяю расписание:") assert not looks_like_textual_tool_invocation("Checking schedule... check_schedule") diff --git a/tests/test_litellm_provider_payload_hardening.py b/tests/test_litellm_provider_payload_hardening.py index a21d8150..23e2ef97 100644 --- a/tests/test_litellm_provider_payload_hardening.py +++ b/tests/test_litellm_provider_payload_hardening.py @@ -126,6 +126,27 @@ def test_serialize_message_coerces_null_tool_call_content() -> None: assert serialized["tool_calls"][0]["function"]["name"] == "dummy_tool" +def test_complete_enforces_runtime_timeout(monkeypatch) -> None: + async def _fake_acompletion(**kwargs): + await asyncio.sleep(1) + return _response("too-late") + + settings = _settings() + settings.litellm_timeout = 0.01 + monkeypatch.setattr( + "octopal.infrastructure.providers.litellm_provider.acompletion", + _fake_acompletion, + ) + provider = LiteLLMProvider(settings) + + try: + asyncio.run(provider.complete([{"role": "user", "content": "hello"}])) + except RuntimeError as exc: + assert "LiteLLM completion failed" in str(exc) + else: + raise AssertionError("completion should time out") + + def test_complete_retries_once_when_client_was_closed(monkeypatch) -> None: calls = {"n": 0} diff --git a/tests/test_worker_launcher_isolation.py b/tests/test_worker_launcher_isolation.py index 9fde080f..6f07c0f1 100644 --- a/tests/test_worker_launcher_isolation.py +++ b/tests/test_worker_launcher_isolation.py @@ -49,7 +49,7 @@ async def _fake_exec(*args, **kwargs): assert f"{worker_dir}:/workspace/workers/worker-1" in args assert f"{workspace / 'skills'}:/workspace/workers/worker-1/skills" in args assert "-e" in args - assert "OCTOPAL_WORKSPACE_DIR=/workspace/workers/worker-1" in args + assert "OCTOPAL_WORKSPACE_DIR=/workspace" in args assert "HOME=/workspace/workers/worker-1" in args assert "PYTHONPATH=src" in args assert "BRAVE_API_KEY=brave-test-key" in args @@ -96,7 +96,7 @@ async def _fake_exec(*args, **kwargs): assert f"{workspace / 'skills'}:/workspace/workers/worker-1/skills" in args assert f"{shared_dir}:/workspace/src" in args assert f"{shared_dir}:/workspace/workers/worker-1/src" in args - assert "OCTOPAL_WORKSPACE_DIR=/workspace/workers/worker-1" in args + assert "OCTOPAL_WORKSPACE_DIR=/workspace" in args assert "HOME=/workspace/workers/worker-1" in args assert f"{workspace}:/workspace" not in args diff --git a/tests/test_worker_router.py b/tests/test_worker_router.py index 993ebd57..289ed414 100644 --- a/tests/test_worker_router.py +++ b/tests/test_worker_router.py @@ -71,6 +71,32 @@ def test_select_worker_template_rejects_missing_required_tools() -> None: assert selected is None +def test_select_worker_template_prefers_filesystem_worker_for_file_write_task() -> None: + templates = [ + _template( + "file_editor", + "File Editor", + "Safely edits text and config files in the workspace", + ["fs_read", "fs_write"], + ["filesystem_read", "filesystem_write"], + ), + _template( + "web_search_ranked", + "Web Search Ranked", + "Search the web and return a ranked list of relevant sources", + ["web_search"], + ["network"], + ), + ] + selected = _select_worker_template( + templates=templates, + task="Create a short markdown report at experiments/qa/marker-worker-report.md with risks and mitigations.", + ) + assert selected is not None + assert selected["template"].id == "file_editor" + assert "filesystem_write_bonus" in selected["reason"] + + def test_select_worker_template_requires_image_capability_for_image_tasks() -> None: templates = [ _template("moltbook_orchestrator", "Presence Manager", "Sequential task manager", ["fs_read"], ["filesystem_read"]), @@ -293,6 +319,42 @@ async def _scenario() -> str: assert "does not advertise image/vision analysis capability" in result +def test_start_worker_rejects_explicit_worker_without_workspace_write_capability() -> None: + templates = [ + _template("web_researcher", "Web Researcher", "Searches web", ["web_search"], ["network"]), + ] + + class _Store: + def list_worker_templates(self): + return templates + + def get_worker_template(self, worker_id: str): + for t in templates: + if t.id == worker_id: + return t + return None + + class _Octo: + def __init__(self) -> None: + self.store = _Store() + + async def _start_worker_async(self, **kwargs): + raise AssertionError("worker launch should have been rejected") + + async def _scenario() -> str: + return await _tool_start_worker( + { + "task": "Create a short markdown report at experiments/qa/marker-worker-report.md.", + "worker_id": "web_researcher", + }, + {"octo": _Octo(), "chat_id": 123}, + ) + + result = asyncio.run(_scenario()) + assert "does not advertise workspace write capability" in result + assert "fs_write/filesystem_write" in result + + def test_start_worker_infers_existing_workspace_paths(monkeypatch, tmp_path) -> None: image_path = tmp_path / "tmp" / "telegram_images" / "img_test.jpg" image_path.parent.mkdir(parents=True) @@ -302,3 +364,15 @@ def test_start_worker_infers_existing_workspace_paths(monkeypatch, tmp_path) -> inferred = _infer_allowed_paths_from_task("Inspect tmp/telegram_images/img_test.jpg") assert inferred == ["tmp/telegram_images/img_test.jpg"] + + +def test_start_worker_infers_existing_parent_for_new_workspace_file(monkeypatch, tmp_path) -> None: + report_dir = tmp_path / "experiments" / "qa" + report_dir.mkdir(parents=True) + monkeypatch.setenv("OCTOPAL_WORKSPACE_DIR", str(tmp_path)) + + inferred = _infer_allowed_paths_from_task( + "Create experiments/qa/new-agent-report.md with the requested summary" + ) + + assert inferred == ["experiments/qa"] From 77c38187838eccdf758a9f611c7f78499a51e5ba Mon Sep 17 00:00:00 2001 From: Slava Trofimov <26082149+pmbstyle@users.noreply.github.com> Date: Mon, 18 May 2026 12:11:39 -0400 Subject: [PATCH 2/2] fix agent loop test call fakes --- tests/test_agent_loop_improvements.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_agent_loop_improvements.py b/tests/test_agent_loop_improvements.py index 08ae114c..099f9a48 100644 --- a/tests/test_agent_loop_improvements.py +++ b/tests/test_agent_loop_improvements.py @@ -668,7 +668,7 @@ async def _noop_log(level: str, message: str) -> None: ] ) - async def _fake_call_llm(provider, messages, tools): + async def _fake_call_llm(provider, messages, tools, **_kwargs): return next(responses) async def _fake_execute_tool( @@ -712,7 +712,7 @@ async def _noop_log(level: str, message: str) -> None: ) monkeypatch.setattr("octopal.runtime.workers.agent_worker.get_tools", lambda: []) - async def _fake_call_llm(provider, messages, tools): + async def _fake_call_llm(provider, messages, tools, **_kwargs): return {"content": ""} monkeypatch.setattr("octopal.runtime.workers.agent_worker._call_llm", _fake_call_llm) @@ -769,7 +769,7 @@ async def _noop_log(level: str, message: str) -> None: ] ) - async def _fake_call_llm(provider, messages, tools): + async def _fake_call_llm(provider, messages, tools, **_kwargs): return next(responses) async def _fake_execute_tool( @@ -841,7 +841,7 @@ async def _noop_log(level: str, message: str) -> None: ] ) - async def _fake_call_llm(provider, messages, tools): + async def _fake_call_llm(provider, messages, tools, **_kwargs): tool_names = {tool.name for tool in tools} assert "request_instruction" in tool_names assert "answer_worker_instruction" not in tool_names @@ -969,7 +969,7 @@ async def _noop_log(level: str, message: str) -> None: ] ) - async def _fake_call_llm(provider, messages, tools): + async def _fake_call_llm(provider, messages, tools, **_kwargs): tool_names = {tool.name for tool in tools} assert "request_instruction" in tool_names assert "answer_worker_instruction" in tool_names @@ -1065,7 +1065,7 @@ async def _fake_log(level: str, message: str) -> None: call_state = {"llm_calls": 0} executed_tools: list[tuple[str | None, dict]] = [] - async def _fake_call_llm(provider, messages, tools): + async def _fake_call_llm(provider, messages, tools, **_kwargs): tool_names = {tool.name for tool in tools} assert "request_instruction" in tool_names assert "answer_worker_instruction" in tool_names @@ -1198,7 +1198,7 @@ async def _noop_log(level: str, message: str) -> None: call_state = {"llm_calls": 0} executed_tools: list[str | None] = [] - async def _fake_call_llm(provider, messages, tools): + async def _fake_call_llm(provider, messages, tools, **_kwargs): call_state["llm_calls"] += 1 if call_state["llm_calls"] == 1: return { @@ -1416,7 +1416,7 @@ async def _fake_log(level: str, message: str) -> None: ] ) - async def _fake_call_llm(provider, messages, tools): + async def _fake_call_llm(provider, messages, tools, **_kwargs): return next(responses) executed_tools: list[str] = []