Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/octopal/runtime/workers/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,16 +865,36 @@ async def answer_instruction(
worker_id: str,
request_id: str,
instruction: str,
answerer_worker_id: str | None = None,
) -> bool:
key = (str(worker_id).strip(), str(request_id).strip())
answerer_id = str(answerer_worker_id or "").strip()
if answerer_id:
worker = await asyncio.to_thread(self.store.get_worker, key[0])
parent_worker_id = str(getattr(worker, "parent_worker_id", "") or "").strip()
if worker is None or parent_worker_id != answerer_id:
await self._append_audit(
"worker_instruction_answer_denied",
correlation_id=key[0],
data={
"worker_id": key[0],
"request_id": key[1],
"answerer_worker_id": answerer_id,
"reason": "not_parent_worker",
},
)
return False
future = self._instruction_waiters.get(key)
if future is None or future.done():
return False
await asyncio.to_thread(self.store.update_worker_status, key[0], "running")
audit_data = {"worker_id": key[0], "request_id": key[1]}
if answerer_id:
audit_data["answerer_worker_id"] = answerer_id
await self._append_audit(
"worker_instruction_answered",
correlation_id=key[0],
data={"worker_id": key[0], "request_id": key[1]},
data=audit_data,
)
future.set_result(str(instruction or "").strip())
return True
Expand Down
27 changes: 27 additions & 0 deletions src/octopal/tools/workers/management.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,19 @@ async def _tool_answer_worker_instruction(args: dict[str, object], ctx: dict[str
worker = octo.store.get_worker(worker_id)
if worker is None:
return json.dumps({"status": "not_found", "worker_id": worker_id}, ensure_ascii=False)
answerer_worker_id = _answerer_worker_id(ctx)
if answerer_worker_id is not None and not _is_direct_child_worker(
worker,
parent_worker_id=answerer_worker_id,
):
return json.dumps(
{
"status": "unauthorized",
"worker_id": worker_id,
"message": "Only a worker's direct parent can answer its instruction request.",
},
ensure_ascii=False,
)
if not request_id:
request = _extract_worker_instruction_request(worker.output)
request_id = str(request.get("request_id") or "").strip() if request else ""
Expand All @@ -1215,6 +1228,7 @@ async def _tool_answer_worker_instruction(args: dict[str, object], ctx: dict[str
worker_id=worker_id,
request_id=request_id,
instruction=instruction,
answerer_worker_id=answerer_worker_id,
)
return json.dumps(
{
Expand All @@ -1226,6 +1240,19 @@ async def _tool_answer_worker_instruction(args: dict[str, object], ctx: dict[str
)


def _answerer_worker_id(ctx: dict[str, object]) -> str | None:
caller_worker = ctx.get("worker")
spec = getattr(caller_worker, "spec", None)
if spec is None:
return None
answerer_id = str(getattr(spec, "run_id", "") or getattr(spec, "id", "") or "").strip()
return answerer_id or None


def _is_direct_child_worker(worker: object, *, parent_worker_id: str) -> bool:
return str(getattr(worker, "parent_worker_id", "") or "").strip() == parent_worker_id


def _tool_get_worker_status(args: dict[str, object], ctx: dict[str, object]) -> str:
octo: Octo = ctx["octo"]
worker_id = str(args.get("worker_id", "")).strip()
Expand Down
100 changes: 100 additions & 0 deletions tests/test_worker_suspend_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import json
from datetime import UTC, datetime
from pathlib import Path
from types import SimpleNamespace

from octopal.infrastructure.config.settings import Settings
from octopal.infrastructure.store.models import AuditEvent, WorkerRecord
from octopal.runtime.workers.contracts import WorkerResult, WorkerSpec
from octopal.runtime.workers.runtime import WorkerRuntime
from octopal.tools.workers.management import _tool_answer_worker_instruction


class _Store:
Expand Down Expand Up @@ -208,6 +210,104 @@ async def _run() -> str:
assert store.audit_events[-1].event_type == "worker_instruction_answered"


def test_runtime_answer_instruction_rejects_non_parent_answerer(tmp_path: Path) -> None:
child = _worker_record("child-1", "awaiting_instruction").model_copy(
update={"parent_worker_id": "parent-1"}
)
store = _Store({"child-1": child})
runtime = WorkerRuntime(
store=store,
policy=_Policy(),
workspace_dir=tmp_path,
launcher=object(),
settings=Settings(),
)

async def _run() -> bool:
future = asyncio.get_running_loop().create_future()
runtime._instruction_waiters[("child-1", "req-1")] = future
answered = await runtime.answer_instruction(
worker_id="child-1",
request_id="req-1",
instruction="attacker instruction",
answerer_worker_id="unrelated-parent",
)
assert future.done() is False
return answered

answered = asyncio.run(_run())

assert answered is False
assert store.status_updates == []
assert store.audit_events[-1].event_type == "worker_instruction_answer_denied"
assert store.audit_events[-1].data["answerer_worker_id"] == "unrelated-parent"


def test_answer_worker_instruction_requires_direct_parent_from_worker_context(
tmp_path: Path,
) -> None:
child = _worker_record("child-1", "awaiting_instruction").model_copy(
update={
"parent_worker_id": "parent-1",
"output": {
"instruction_request": {
"request_id": "req-1",
"worker_id": "child-1",
"target": "parent",
"question": "Which path?",
"created_at": "2026-04-18T12:00:00+00:00",
}
},
}
)
store = _Store({"child-1": child})
runtime = WorkerRuntime(
store=store,
policy=_Policy(),
workspace_dir=tmp_path,
launcher=object(),
settings=Settings(),
)
octo = SimpleNamespace(store=store, runtime=runtime)
attacker_spec = WorkerSpec(
id="attacker-worker",
task="coordinate unrelated work",
inputs={},
system_prompt="s",
available_tools=["start_child_worker"],
granted_capabilities=[],
timeout_seconds=30,
max_thinking_steps=5,
run_id="attacker-worker",
)
parent_spec = attacker_spec.model_copy(update={"id": "parent-1", "run_id": "parent-1"})

async def _run() -> tuple[dict[str, object], dict[str, object], str]:
future = asyncio.get_running_loop().create_future()
runtime._instruction_waiters[("child-1", "req-1")] = future
attacker_result = json.loads(
await _tool_answer_worker_instruction(
{"worker_id": "child-1", "instruction": "attacker instruction"},
{"octo": octo, "worker": SimpleNamespace(spec=attacker_spec)},
)
)
assert future.done() is False
parent_result = json.loads(
await _tool_answer_worker_instruction(
{"worker_id": "child-1", "instruction": "parent instruction"},
{"octo": octo, "worker": SimpleNamespace(spec=parent_spec)},
)
)
return attacker_result, parent_result, await future

attacker_result, parent_result, instruction = asyncio.run(_run())

assert attacker_result["status"] == "unauthorized"
assert parent_result["status"] == "answered"
assert instruction == "parent instruction"
assert store.status_updates == [("child-1", "running")]


def test_runtime_enqueues_octo_instruction_request_and_resumes(tmp_path: Path) -> None:
store = _Store({})
runtime = WorkerRuntime(
Expand Down