diff --git a/src/octopal/runtime/workers/runtime.py b/src/octopal/runtime/workers/runtime.py index 1be0899..55239f5 100644 --- a/src/octopal/runtime/workers/runtime.py +++ b/src/octopal/runtime/workers/runtime.py @@ -865,16 +865,36 @@ async def answer_instruction( worker_id: str, request_id: str, instruction: str, + answerer_worker_id: str | None = None, ) -> bool: key = (str(worker_id).strip(), str(request_id).strip()) + answerer_id = str(answerer_worker_id or "").strip() + if answerer_id: + worker = await asyncio.to_thread(self.store.get_worker, key[0]) + parent_worker_id = str(getattr(worker, "parent_worker_id", "") or "").strip() + if worker is None or parent_worker_id != answerer_id: + await self._append_audit( + "worker_instruction_answer_denied", + correlation_id=key[0], + data={ + "worker_id": key[0], + "request_id": key[1], + "answerer_worker_id": answerer_id, + "reason": "not_parent_worker", + }, + ) + return False future = self._instruction_waiters.get(key) if future is None or future.done(): return False await asyncio.to_thread(self.store.update_worker_status, key[0], "running") + audit_data = {"worker_id": key[0], "request_id": key[1]} + if answerer_id: + audit_data["answerer_worker_id"] = answerer_id await self._append_audit( "worker_instruction_answered", correlation_id=key[0], - data={"worker_id": key[0], "request_id": key[1]}, + data=audit_data, ) future.set_result(str(instruction or "").strip()) return True diff --git a/src/octopal/tools/workers/management.py b/src/octopal/tools/workers/management.py index aa01967..5d3697a 100644 --- a/src/octopal/tools/workers/management.py +++ b/src/octopal/tools/workers/management.py @@ -1192,6 +1192,19 @@ async def _tool_answer_worker_instruction(args: dict[str, object], ctx: dict[str worker = octo.store.get_worker(worker_id) if worker is None: return json.dumps({"status": "not_found", "worker_id": worker_id}, ensure_ascii=False) + answerer_worker_id = _answerer_worker_id(ctx) + if answerer_worker_id is not None and not _is_direct_child_worker( + worker, + parent_worker_id=answerer_worker_id, + ): + return json.dumps( + { + "status": "unauthorized", + "worker_id": worker_id, + "message": "Only a worker's direct parent can answer its instruction request.", + }, + ensure_ascii=False, + ) if not request_id: request = _extract_worker_instruction_request(worker.output) request_id = str(request.get("request_id") or "").strip() if request else "" @@ -1215,6 +1228,7 @@ async def _tool_answer_worker_instruction(args: dict[str, object], ctx: dict[str worker_id=worker_id, request_id=request_id, instruction=instruction, + answerer_worker_id=answerer_worker_id, ) return json.dumps( { @@ -1226,6 +1240,19 @@ async def _tool_answer_worker_instruction(args: dict[str, object], ctx: dict[str ) +def _answerer_worker_id(ctx: dict[str, object]) -> str | None: + caller_worker = ctx.get("worker") + spec = getattr(caller_worker, "spec", None) + if spec is None: + return None + answerer_id = str(getattr(spec, "run_id", "") or getattr(spec, "id", "") or "").strip() + return answerer_id or None + + +def _is_direct_child_worker(worker: object, *, parent_worker_id: str) -> bool: + return str(getattr(worker, "parent_worker_id", "") or "").strip() == parent_worker_id + + def _tool_get_worker_status(args: dict[str, object], ctx: dict[str, object]) -> str: octo: Octo = ctx["octo"] worker_id = str(args.get("worker_id", "")).strip() diff --git a/tests/test_worker_suspend_resume.py b/tests/test_worker_suspend_resume.py index 805ff6b..bfd3f7f 100644 --- a/tests/test_worker_suspend_resume.py +++ b/tests/test_worker_suspend_resume.py @@ -4,11 +4,13 @@ import json from datetime import UTC, datetime from pathlib import Path +from types import SimpleNamespace from octopal.infrastructure.config.settings import Settings from octopal.infrastructure.store.models import AuditEvent, WorkerRecord from octopal.runtime.workers.contracts import WorkerResult, WorkerSpec from octopal.runtime.workers.runtime import WorkerRuntime +from octopal.tools.workers.management import _tool_answer_worker_instruction class _Store: @@ -208,6 +210,104 @@ async def _run() -> str: assert store.audit_events[-1].event_type == "worker_instruction_answered" +def test_runtime_answer_instruction_rejects_non_parent_answerer(tmp_path: Path) -> None: + child = _worker_record("child-1", "awaiting_instruction").model_copy( + update={"parent_worker_id": "parent-1"} + ) + store = _Store({"child-1": child}) + runtime = WorkerRuntime( + store=store, + policy=_Policy(), + workspace_dir=tmp_path, + launcher=object(), + settings=Settings(), + ) + + async def _run() -> bool: + future = asyncio.get_running_loop().create_future() + runtime._instruction_waiters[("child-1", "req-1")] = future + answered = await runtime.answer_instruction( + worker_id="child-1", + request_id="req-1", + instruction="attacker instruction", + answerer_worker_id="unrelated-parent", + ) + assert future.done() is False + return answered + + answered = asyncio.run(_run()) + + assert answered is False + assert store.status_updates == [] + assert store.audit_events[-1].event_type == "worker_instruction_answer_denied" + assert store.audit_events[-1].data["answerer_worker_id"] == "unrelated-parent" + + +def test_answer_worker_instruction_requires_direct_parent_from_worker_context( + tmp_path: Path, +) -> None: + child = _worker_record("child-1", "awaiting_instruction").model_copy( + update={ + "parent_worker_id": "parent-1", + "output": { + "instruction_request": { + "request_id": "req-1", + "worker_id": "child-1", + "target": "parent", + "question": "Which path?", + "created_at": "2026-04-18T12:00:00+00:00", + } + }, + } + ) + store = _Store({"child-1": child}) + runtime = WorkerRuntime( + store=store, + policy=_Policy(), + workspace_dir=tmp_path, + launcher=object(), + settings=Settings(), + ) + octo = SimpleNamespace(store=store, runtime=runtime) + attacker_spec = WorkerSpec( + id="attacker-worker", + task="coordinate unrelated work", + inputs={}, + system_prompt="s", + available_tools=["start_child_worker"], + granted_capabilities=[], + timeout_seconds=30, + max_thinking_steps=5, + run_id="attacker-worker", + ) + parent_spec = attacker_spec.model_copy(update={"id": "parent-1", "run_id": "parent-1"}) + + async def _run() -> tuple[dict[str, object], dict[str, object], str]: + future = asyncio.get_running_loop().create_future() + runtime._instruction_waiters[("child-1", "req-1")] = future + attacker_result = json.loads( + await _tool_answer_worker_instruction( + {"worker_id": "child-1", "instruction": "attacker instruction"}, + {"octo": octo, "worker": SimpleNamespace(spec=attacker_spec)}, + ) + ) + assert future.done() is False + parent_result = json.loads( + await _tool_answer_worker_instruction( + {"worker_id": "child-1", "instruction": "parent instruction"}, + {"octo": octo, "worker": SimpleNamespace(spec=parent_spec)}, + ) + ) + return attacker_result, parent_result, await future + + attacker_result, parent_result, instruction = asyncio.run(_run()) + + assert attacker_result["status"] == "unauthorized" + assert parent_result["status"] == "answered" + assert instruction == "parent instruction" + assert store.status_updates == [("child-1", "running")] + + def test_runtime_enqueues_octo_instruction_request_and_resumes(tmp_path: Path) -> None: store = _Store({}) runtime = WorkerRuntime(