diff --git a/backend/agentpal/agents/sub_agent.py b/backend/agentpal/agents/sub_agent.py index e4d0a61..e03212b 100644 --- a/backend/agentpal/agents/sub_agent.py +++ b/backend/agentpal/agents/sub_agent.py @@ -366,18 +366,47 @@ async def _update_status( result: str | None = None, error: str | None = None, ) -> None: + from agentpal.services.task_event_bus import task_event_bus + self._task.status = status if result is not None: self._task.result = result if error is not None: self._task.error = error - if status in (TaskStatus.DONE, TaskStatus.FAILED): + + # 设置时间戳 + if status == TaskStatus.RUNNING and not self._task.started_at: + self._task.started_at = datetime.now(timezone.utc) + if status in (TaskStatus.DONE, TaskStatus.FAILED, TaskStatus.CANCELLED): self._task.finished_at = datetime.now(timezone.utc) + self._task.completed_at = self._task.finished_at + try: await self._db.flush() except Exception: pass + # 发射任务状态事件 + event_type_map = { + TaskStatus.RUNNING: "task.started", + TaskStatus.DONE: "task.completed", + TaskStatus.FAILED: "task.failed", + TaskStatus.CANCELLED: "task.cancelled", + TaskStatus.INPUT_REQUIRED: "task.input_required", + TaskStatus.PAUSED: "task.paused", + } + if status in event_type_map: + await task_event_bus.emit( + self._task.id, + event_type_map[status], + { + "status": status.value, + "result": result[:500] if result else None, + "error": error[:500] if error else None, + }, + f"任务状态变更为 {status.value}", + ) + # 任务终态时发布 WebSocket 通知 if status in (TaskStatus.DONE, TaskStatus.FAILED): try: @@ -407,9 +436,184 @@ async def _update_status( pass # 通知失败不影响主流程 def _log(self, event_type: str, data: dict[str, Any]) -> None: - """向执行日志追加一条记录。""" + """向执行日志追加一条记录,并同步发射到 TaskEventBus。""" + import asyncio + + from agentpal.services.task_event_bus import task_event_bus + self._execution_log.append({ "type": event_type, "timestamp": datetime.now(timezone.utc).isoformat(), **data, }) + + # 同步发射到事件总线(用于实时 SSE 推送) + event_type_map = { + "tool_start": "tool.start", + "tool_done": "tool.complete", + "llm_response": "llm.message", + "user_message": "user.message", + } + if event_type in event_type_map: + message_map = { + "tool_start": f"开始执行工具 {data.get('name', '')}", + "tool_done": f"工具 {data.get('name', '')} 执行完成", + "llm_response": "LLM 响应", + "user_message": "用户消息", + } + # 不要在这里 await,避免阻塞主流程 + asyncio.create_task( + task_event_bus.emit( + self._task.id, + event_type_map[event_type], + data, + message_map.get(event_type, ""), + ) + ) + + async def _emit_progress(self, pct: int, message: str) -> None: + """发射进度更新事件。""" + import asyncio + + from agentpal.services.task_event_bus import task_event_bus + + self._task.progress_pct = pct + self._task.progress_message = message + + try: + await self._db.flush() + except Exception: + pass + + asyncio.create_task( + task_event_bus.emit( + self._task.id, + "task.progress", + {"pct": pct, "message": message}, + message, + ) + ) + + async def produce_artifact( + self, + artifact_type: str, + content: str, + title: str | None = None, + extra: dict[str, Any] | None = None, + ) -> str: + """生成任务产出物并保存到数据库。 + + Args: + artifact_type: 产出物类型(如 "code", "doc", "analysis", "summary") + content: 产出物内容(文本或 JSON) + title: 人类可读标题 + extra: 额外元数据 + + Returns: + 产出物 ID + """ + from agentpal.models.session import TaskArtifact + + artifact_id = f"art_{uuid.uuid4().hex[:12]}" + artifact = TaskArtifact( + id=artifact_id, + task_id=self._task.id, + artifact_type=artifact_type, + content=content, + title=title or f"{artifact_type}_{artifact_id[:8]}", + extra=extra or {}, + ) + self._db.add(artifact) + await self._db.flush() + + # 发射事件 + from agentpal.services.task_event_bus import task_event_bus + + asyncio.create_task( + task_event_bus.emit( + self._task.id, + "artifact.created", + { + "artifact_id": artifact_id, + "artifact_type": artifact_type, + "title": artifact.title, + }, + f"生成产出物:{artifact.title}", + ) + ) + + return artifact_id + + async def request_user_input( + self, + question: str, + context: str | None = None, + ) -> str: + """请求用户输入,将任务状态设置为 INPUT_REQUIRED 并等待。 + + Args: + question: 向用户提出的问题 + context: 可选的上下文信息 + + Returns: + 用户提供的输入内容 + """ + # 保存当前问题到 meta + if self._task.meta is None: + self._task.meta = {} + self._task.meta["input_request"] = { + "question": question, + "context": context, + "requested_at": datetime.now(timezone.utc).isoformat(), + } + + # 记录到执行日志 + self._log("input_requested", {"question": question, "context": context}) + + # 更新任务状态为 INPUT_REQUIRED + await self._update_status(TaskStatus.INPUT_REQUIRED) + + # 等待用户输入(轮询检查 meta) + while True: + await asyncio.sleep(2) # 每 2 秒检查一次 + # 刷新任务数据 + await self._db.refresh(self._task) + if self._task.meta and "user_input" in self._task.meta: + user_input = self._task.meta["user_input"] + self._log("input_received", {"input": user_input[:500]}) + # 清除输入请求 + del self._task.meta["input_request"] + await self._db.flush() + return user_input + # 检查任务是否被恢复执行 + if self._task.status == TaskStatus.PENDING: + # 任务已恢复,获取用户输入 + user_input = self._task.meta.get("user_input", "") + self._log("input_received", {"input": user_input[:500]}) + return user_input + # 检查任务是否被取消 + if self._task.status == TaskStatus.CANCELLED: + self._log("task_cancelled", {"reason": self._task.meta.get("cancel_reason", "用户取消")}) + raise asyncio.CancelledError("任务已被取消") + + async def cancel(self, reason: str = "用户取消") -> None: + """取消正在运行的任务。 + + Args: + reason: 取消原因 + """ + self._log("cancelling", {"reason": reason}) + await self._update_status(TaskStatus.CANCELLED) + + # 保存取消原因 + if self._task.meta is None: + self._task.meta = {} + self._task.meta["cancel_reason"] = reason + self._task.meta["cancelled_at"] = datetime.now(timezone.utc).isoformat() + + try: + await self._db.flush() + except Exception: + pass + + logger.info(f"SubAgent task {self._task.id} cancelled: {reason}") diff --git a/backend/agentpal/api/v1/endpoints/agent.py b/backend/agentpal/api/v1/endpoints/agent.py index f64ad1d..920c5d7 100644 --- a/backend/agentpal/api/v1/endpoints/agent.py +++ b/backend/agentpal/api/v1/endpoints/agent.py @@ -60,6 +60,8 @@ class DispatchRequest(BaseModel): agent_name: str | None = None priority: int = Field(default=5, ge=1, le=10, description="任务优先级 1-10(10 最高)") max_retries: int = Field(default=3, ge=0, le=10, description="最大重试次数 0-10") + blocking: bool = Field(default=False, description="是否阻塞等待任务完成") + wait_seconds: int = Field(default=120, ge=0, description="阻塞模式下的最大等待时间(秒)") class TaskStatusResponse(BaseModel): @@ -149,31 +151,76 @@ async def event_stream() -> AsyncGenerator[str, None]: @router.post("/dispatch", response_model=TaskStatusResponse) async def dispatch_sub_agent(req: DispatchRequest, db: AsyncSession = Depends(get_db)): - """派遣 SubAgent 异步执行任务。""" + """派遣 SubAgent 异步执行任务。 + + 支持两种模式: + - 非阻塞模式(blocking=False):立即返回任务 ID 和初始状态 + - 阻塞模式(blocking=True):等待任务完成后返回最终结果 + """ settings = get_settings() memory = MemoryFactory.create(settings.memory_backend, db=db) assistant = PersonalAssistant(session_id=req.parent_session_id, memory=memory, db=db) - task = await assistant.dispatch_sub_agent( - task_prompt=req.task_prompt, - db=db, - context=req.context, - task_type=req.task_type, - agent_name=req.agent_name, - priority=req.priority, - max_retries=req.max_retries, - ) - return TaskStatusResponse( - task_id=task.id, - status=task.status, - result=task.result, - error=task.error, - agent_name=task.agent_name, - task_type=task.task_type, - priority=task.priority, - retry_count=task.retry_count, - max_retries=task.max_retries, - created_at=utc_isoformat(task.created_at), - ) + + if req.blocking: + # 阻塞模式:使用工具层面的 dispatch_sub_agent + from agentpal.tools.builtin import dispatch_sub_agent as builtin_dispatch + + result_text = await builtin_dispatch( + task_prompt=req.task_prompt, + parent_session_id=req.parent_session_id, + task_type=req.task_type or "", + agent_name=req.agent_name or "", + wait_seconds=req.wait_seconds, + blocking=True, + ) + + # 从数据库获取最新的 task 记录 + from sqlalchemy import select + from agentpal.models.agent import SubAgentTask + + result = await db.execute( + select(SubAgentTask) + .where(SubAgentTask.parent_session_id == req.parent_session_id) + .order_by(SubAgentTask.created_at.desc()) + .limit(1) + ) + task = result.scalars().first() + + return TaskStatusResponse( + task_id=task.id if task else "unknown", + status=task.status.value if task else "unknown", + result=result_text, + error=getattr(task, "error", None), + agent_name=getattr(task, "agent_name", None), + task_type=getattr(task, "task_type", None), + priority=getattr(task, "priority", 5), + retry_count=getattr(task, "retry_count", 0), + max_retries=getattr(task, "max_retries", 3), + created_at=utc_isoformat(task.created_at) if task else None, + ) + else: + # 非阻塞模式:立即返回 + task = await assistant.dispatch_sub_agent( + task_prompt=req.task_prompt, + db=db, + context=req.context, + task_type=req.task_type, + agent_name=req.agent_name, + priority=req.priority, + max_retries=req.max_retries, + ) + return TaskStatusResponse( + task_id=task.id, + status=task.status, + result=task.result, + error=task.error, + agent_name=task.agent_name, + task_type=task.task_type, + priority=task.priority, + retry_count=task.retry_count, + max_retries=task.max_retries, + created_at=utc_isoformat(task.created_at), + ) @router.get("/tasks/{task_id}", response_model=TaskStatusResponse) diff --git a/backend/agentpal/api/v1/endpoints/tasks.py b/backend/agentpal/api/v1/endpoints/tasks.py new file mode 100644 index 0000000..82cb265 --- /dev/null +++ b/backend/agentpal/api/v1/endpoints/tasks.py @@ -0,0 +1,371 @@ +"""SubAgent 任务管理 API 端点。""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from typing import AsyncGenerator + +from fastapi import APIRouter, Body, Depends, HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from agentpal.database import get_db +from agentpal.models.session import SubAgentTask, TaskArtifact, TaskEvent, TaskStatus + +router = APIRouter() + + +class SubmitUserInputRequest(BaseModel): + """提交用户输入的请求体。""" + + user_input: str + continue_execution: bool = True + + +@router.get("/{task_id}") +async def get_task(task_id: str, db: AsyncSession = Depends(get_db)): + """获取单个任务详情。""" + result = await db.execute(select(SubAgentTask).where(SubAgentTask.id == task_id)) + task = result.scalar_one_or_none() + if task is None: + raise HTTPException(status_code=404, detail=f"任务 '{task_id}' 不存在") + return task + + +@router.get("/{task_id}/events") +async def stream_task_events(task_id: str, db: AsyncSession = Depends(get_db)): + """SSE 流:实时推送 SubAgent 任务事件。 + + 客户端连接后,会立即收到历史事件,然后持续接收新事件直到任务结束。 + """ + # 验证任务是否存在 + result = await db.execute(select(SubAgentTask).where(SubAgentTask.id == task_id)) + task = result.scalar_one_or_none() + if task is None: + raise HTTPException(status_code=404, detail=f"任务 '{task_id}' 不存在") + + from agentpal.services.task_event_bus import task_event_bus + + # 订阅事件总线 + queue = task_event_bus.subscribe(task_id) + + async def event_generator() -> AsyncGenerator[str, None]: + try: + # 先发送历史事件 + history_result = await db.execute( + select(TaskEvent) + .where(TaskEvent.task_id == task_id) + .order_by(TaskEvent.created_at.asc()) + ) + history_events = history_result.scalars().all() + for event in history_events: + yield f"data: {json.dumps({'event_type': event.event_type, 'event_data': event.event_data or {}, 'message': event.message, 'created_at': event.created_at.isoformat()})}\n\n" + + # 持续监听新事件 + while True: + try: + import asyncio + + event = await asyncio.wait_for(queue.get(), timeout=30.0) + # 从数据库获取完整的 event 记录(包含 created_at) + latest_result = await db.execute( + select(TaskEvent) + .where(TaskEvent.task_id == task_id, TaskEvent.event_type == event["event_type"]) + .order_by(TaskEvent.created_at.desc()) + .limit(1) + ) + latest_event = latest_result.scalar_one_or_none() + event_with_time = { + "event_type": event["event_type"], + "event_data": event["event_data"], + "message": event["message"], + "created_at": latest_event.created_at.isoformat() if latest_event else None, + } + yield f"data: {json.dumps(event_with_time)}\n\n" + except TimeoutError: + # 心跳:检查任务状态 + if task.status in (TaskStatus.DONE, TaskStatus.FAILED, TaskStatus.CANCELLED): + break + # 重新获取任务状态 + refresh_result = await db.execute(select(SubAgentTask).where(SubAgentTask.id == task_id)) + refreshed_task = refresh_result.scalar_one_or_none() + if refreshed_task and refreshed_task.status in (TaskStatus.DONE, TaskStatus.FAILED, TaskStatus.CANCELLED): + break + continue + except asyncio.CancelledError: + break + + finally: + # 取消订阅 + task_event_bus.unsubscribe(task_id, queue) + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", # Nginx: 禁用缓冲 + }, + ) + + +@router.get("/{task_id}/artifacts") +async def list_task_artifacts(task_id: str, db: AsyncSession = Depends(get_db)): + """列出任务的所有产出物。""" + # 验证任务是否存在 + result = await db.execute(select(SubAgentTask).where(SubAgentTask.id == task_id)) + task = result.scalar_one_or_none() + if task is None: + raise HTTPException(status_code=404, detail=f"任务 '{task_id}' 不存在") + + artifacts_result = await db.execute( + select(TaskArtifact).where(TaskArtifact.task_id == task_id).order_by(TaskArtifact.created_at.asc()) + ) + artifacts = artifacts_result.scalars().all() + return artifacts + + +@router.get("/{task_id}/artifacts/{artifact_id}") +async def get_task_artifact(task_id: str, artifact_id: str, db: AsyncSession = Depends(get_db)): + """获取单个产出物内容。""" + result = await db.execute( + select(TaskArtifact).where(TaskArtifact.id == artifact_id, TaskArtifact.task_id == task_id) + ) + artifact = result.scalar_one_or_none() + if artifact is None: + raise HTTPException(status_code=404, detail=f"产出物 '{artifact_id}' 不存在") + return artifact + + +@router.post("/{task_id}/input") +async def submit_user_input( + task_id: str, + request: SubmitUserInputRequest, + db: AsyncSession = Depends(get_db), +): + """向处于 INPUT_REQUIRED 状态的任务提交用户输入并恢复执行。 + + Args: + task_id: 任务 ID + user_input: 用户提供的输入内容 + continue_execution: 是否继续执行任务(默认 True) + + Returns: + 任务当前状态 + """ + result = await db.execute(select(SubAgentTask).where(SubAgentTask.id == task_id)) + task = result.scalar_one_or_none() + if task is None: + raise HTTPException(status_code=404, detail=f"任务 '{task_id}' 不存在") + + if task.status != TaskStatus.INPUT_REQUIRED: + raise HTTPException( + status_code=400, + detail=f"任务当前状态为 '{task.status.value}',不需要用户输入", + ) + + # 将用户输入存储到任务的 meta 字段中 + if task.meta is None: + task.meta = {} + task.meta["user_input"] = request.user_input + task.meta["user_input_timestamp"] = json.dumps(__import__("datetime").datetime.now().isoformat()) + + # 如果 continue_execution 为 True,则将任务状态改为 PENDING 以恢复执行 + if request.continue_execution: + task.status = TaskStatus.PENDING + task.meta["resumed_at"] = json.dumps(__import__("datetime").datetime.now(timezone.utc).isoformat()) + + await db.commit() + + # 发射恢复事件 + from agentpal.services.task_event_bus import task_event_bus + + asyncio = __import__("asyncio") + asyncio.create_task( + task_event_bus.emit( + task_id, + "task.resumed", + {"user_input": request.user_input[:500]}, + "任务已恢复执行", + ) + ) + + return { + "task_id": task_id, + "status": task.status.value, + "message": "用户输入已提交,任务已恢复执行" if request.continue_execution else "用户输入已提交,任务保持暂停", + } + + +# ────────────────────────────────────────────────────────── +# Artifact 相关 API +# ────────────────────────────────────────────────────────── + + +class ArtifactCreate(BaseModel): + """创建产出物的请求体。""" + + task_id: str + name: str + artifact_type: str = "text" + content: str | None = None + file_path: str | None = None + mime_type: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +@router.post("/artifacts", response_model=dict[str, str], tags=["artifacts"]) +async def create_artifact(request: ArtifactCreate, db: AsyncSession = Depends(get_async_db)) -> dict[str, str]: + """为任务创建产出物(代码、报告、图表等)。""" + import uuid + + from agentpal.models.session import TaskArtifact + + artifact_id = str(uuid.uuid4()) + artifact = TaskArtifact( + id=artifact_id, + task_id=request.task_id, + name=request.name, + artifact_type=request.artifact_type, + content=request.content, + file_path=request.file_path, + mime_type=request.mime_type, + metadata=request.metadata, + ) + db.add(artifact) + await db.commit() + + # Emit event + from agentpal.services.task_event_bus import task_event_bus + + asyncio.create_task( + task_event_bus.emit( + request.task_id, + "task.artifact_created", + {"artifact_id": artifact_id, "name": request.name}, + f"已创建产出物:{request.name}", + ) + ) + + return {"artifact_id": artifact_id, "status": "created"} + + +@router.get("/{task_id}/artifacts", response_model=list[dict[str, Any]], tags=["artifacts"]) +async def list_artifacts(task_id: str, db: AsyncSession = Depends(get_async_db)) -> list[dict[str, Any]]: + """获取任务的所有产出物列表。""" + from sqlalchemy import select + + from agentpal.models.session import TaskArtifact + + result = await db.execute(select(TaskArtifact).where(TaskArtifact.task_id == task_id).order_by(TaskArtifact.created_at.desc())) + artifacts = result.scalars().all() + + return [ + { + "id": a.id, + "name": a.name, + "artifact_type": a.artifact_type, + "mime_type": a.mime_type, + "size_bytes": a.size_bytes, + "created_at": a.created_at.isoformat() if a.created_at else None, + } + for a in artifacts + ] + + +@router.get("/artifacts/{artifact_id}", response_model=dict[str, Any], tags=["artifacts"]) +async def get_artifact(artifact_id: str, db: AsyncSession = Depends(get_async_db)) -> dict[str, Any]: + """获取单个产出物的详细内容。""" + from sqlalchemy import select + + from agentpal.models.session import TaskArtifact + + result = await db.execute(select(TaskArtifact).where(TaskArtifact.id == artifact_id)) + artifact = result.scalar_one_or_none() + + if not artifact: + raise HTTPException(status_code=404, detail="Artifact not found") + + return { + "id": artifact.id, + "task_id": artifact.task_id, + "name": artifact.name, + "artifact_type": artifact.artifact_type, + "content": artifact.content, + "file_path": artifact.file_path, + "mime_type": artifact.mime_type, + "size_bytes": artifact.size_bytes, + "extra": artifact.extra, + "created_at": artifact.created_at.isoformat() if artifact.created_at else None, + } + + +# ────────────────────────────────────────────────────────── +# Task Cancel API +# ────────────────────────────────────────────────────────── + + +class CancelTaskRequest(BaseModel): + """取消任务的请求体。""" + + reason: str | None = "用户取消" + + +@router.post("/{task_id}/cancel", response_model=dict[str, Any], tags=["tasks"]) +async def cancel_task( + task_id: str, + request: CancelTaskRequest | None = None, + db: AsyncSession = Depends(get_async_db), +) -> dict[str, Any]: + """取消正在运行的 SubAgent 任务。""" + from sqlalchemy import select + + from agentpal.models.session import SubAgentTask, TaskStatus + + result = await db.execute(select(SubAgentTask).where(SubAgentTask.id == task_id)) + task = result.scalar_one_or_none() + + if not task: + raise HTTPException(status_code=404, detail="Task not found") + + # 只有运行中或暂停的任务可以取消 + if task.status not in (TaskStatus.RUNNING, TaskStatus.PAUSED, TaskStatus.PENDING, TaskStatus.INPUT_REQUIRED): + return { + "task_id": task_id, + "status": task.status.value, + "message": f"任务状态为 {task.status.value},无需取消", + } + + # 更新任务状态 + task.status = TaskStatus.CANCELLED + if task.meta is None: + task.meta = {} + task.meta["cancel_reason"] = request.reason if request else "用户取消" + task.meta["cancelled_at"] = datetime.now(timezone.utc).isoformat() + task.finished_at = datetime.now(timezone.utc) + + await db.commit() + + # 发射事件 + from agentpal.services.task_event_bus import task_event_bus + + asyncio.create_task( + task_event_bus.emit( + task_id, + "task.cancelled", + {"reason": task.meta["cancel_reason"]}, + f"任务已取消:{task.meta['cancel_reason']}", + ) + ) + + return { + "task_id": task_id, + "status": TaskStatus.CANCELLED.value, + "message": "任务已成功取消", + } + + diff --git a/backend/agentpal/api/v1/router.py b/backend/agentpal/api/v1/router.py index c73a253..10f5560 100644 --- a/backend/agentpal/api/v1/router.py +++ b/backend/agentpal/api/v1/router.py @@ -1,6 +1,6 @@ from fastapi import APIRouter -from agentpal.api.v1.endpoints import agent, channel, config, cron, dashboard, memory, notifications, providers, session, skills, sub_agents, tools, workspace +from agentpal.api.v1.endpoints import agent, channel, config, cron, dashboard, memory, notifications, providers, session, skills, sub_agents, tasks, tools, workspace router = APIRouter() router.include_router(agent.router, prefix="/agent", tags=["agent"]) @@ -12,6 +12,7 @@ router.include_router(config.router, prefix="/config", tags=["config"]) router.include_router(providers.router, prefix="/providers", tags=["providers"]) router.include_router(sub_agents.router, prefix="/sub-agents", tags=["sub-agents"]) +router.include_router(tasks.router, prefix="/tasks", tags=["tasks"]) router.include_router(cron.router, prefix="/cron", tags=["cron"]) router.include_router(dashboard.router, prefix="/dashboard", tags=["dashboard"]) router.include_router(memory.router, prefix="/memory", tags=["memory"]) diff --git a/backend/agentpal/database.py b/backend/agentpal/database.py index 5bdf4ee..e253f31 100644 --- a/backend/agentpal/database.py +++ b/backend/agentpal/database.py @@ -94,6 +94,15 @@ async def run_migrations() -> None: ("sub_agent_tasks", "agent_name", "ALTER TABLE sub_agent_tasks ADD COLUMN agent_name VARCHAR(64)"), ("sub_agent_tasks", "task_type", "ALTER TABLE sub_agent_tasks ADD COLUMN task_type VARCHAR(64)"), ("sub_agent_tasks", "execution_log", "ALTER TABLE sub_agent_tasks ADD COLUMN execution_log JSON NOT NULL DEFAULT '[]'"), + # sub_agent_tasks: input_prompt / input_response / progress_pct / progress_message / started_at / completed_at (added for bidirectional comm) + ("sub_agent_tasks", "input_prompt", "ALTER TABLE sub_agent_tasks ADD COLUMN input_prompt TEXT"), + ("sub_agent_tasks", "input_response", "ALTER TABLE sub_agent_tasks ADD COLUMN input_response TEXT"), + ("sub_agent_tasks", "progress_pct", "ALTER TABLE sub_agent_tasks ADD COLUMN progress_pct INTEGER DEFAULT 0"), + ("sub_agent_tasks", "progress_message", "ALTER TABLE sub_agent_tasks ADD COLUMN progress_message TEXT"), + ("sub_agent_tasks", "started_at", "ALTER TABLE sub_agent_tasks ADD COLUMN started_at DATETIME"), + ("sub_agent_tasks", "completed_at", "ALTER TABLE sub_agent_tasks ADD COLUMN completed_at DATETIME"), + # Phase 6: Rename metadata to extra in task_artifacts (metadata is reserved) + ("task_artifacts", "extra", "ALTER TABLE task_artifacts ADD COLUMN extra JSON"), ] async with engine.begin() as conn: for table, column, sql in migrations: @@ -105,6 +114,11 @@ async def run_migrations() -> None: if column not in existing_cols: await conn.execute(__import__("sqlalchemy").text(sql)) + # 创建新表(task_artifacts, task_events) + from agentpal.models.session import TaskArtifact, TaskEvent + await conn.run_sync(TaskArtifact.metadata.create_all) + await conn.run_sync(TaskEvent.metadata.create_all) + async def get_db() -> AsyncGenerator[AsyncSession, None]: """FastAPI Depends 注入用。""" diff --git a/backend/agentpal/models/__init__.py b/backend/agentpal/models/__init__.py index d1ac652..6a49136 100644 --- a/backend/agentpal/models/__init__.py +++ b/backend/agentpal/models/__init__.py @@ -2,7 +2,14 @@ from agentpal.models.cron import CronJob, CronJobExecution from agentpal.models.memory import MemoryRecord from agentpal.models.message import AgentMessage -from agentpal.models.session import SessionRecord, SessionStatus, SubAgentTask, TaskStatus +from agentpal.models.session import ( + SessionRecord, + SessionStatus, + SubAgentTask, + TaskStatus, + TaskArtifact, + TaskEvent, +) from agentpal.models.skill import SkillRecord from agentpal.models.tool import ToolCallLog, ToolConfig @@ -17,6 +24,8 @@ "SubAgentDefinition", "SubAgentTask", "TaskStatus", + "TaskArtifact", + "TaskEvent", "ToolConfig", "ToolCallLog", ] diff --git a/backend/agentpal/models/session.py b/backend/agentpal/models/session.py index 239ffd9..cad1d49 100644 --- a/backend/agentpal/models/session.py +++ b/backend/agentpal/models/session.py @@ -24,6 +24,8 @@ class TaskStatus(StrEnum): DONE = "done" FAILED = "failed" CANCELLED = "cancelled" + INPUT_REQUIRED = "input-required" # 需要用户提供输入 + PAUSED = "paused" # 暂停状态 class SessionRecord(Base): @@ -77,6 +79,12 @@ class SubAgentTask(Base): - agent_name: 执行此任务的 SubAgent 角色名 - task_type: 任务类型(用于 SubAgent 角色路由) - execution_log: 完整执行日志(LLM 对话 + 工具调用) + - input_prompt: 请求用户输入时的提示语(INPUT_REQUIRED 状态时使用) + - input_response: 用户提供的输入内容 + - progress_pct: 进度百分比(0-100) + - progress_message: 进度描述信息 + - started_at: 实际开始执行时间 + - completed_at: 执行完成时间(无论成功失败) """ __tablename__ = "sub_agent_tasks" @@ -102,7 +110,93 @@ class SubAgentTask(Base): retry_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False, server_default="0") max_retries: Mapped[int] = mapped_column(Integer, default=3, nullable=False, server_default="3") + # ── Input-Required 协议 ─────────────────────────────── + input_prompt: Mapped[str | None] = mapped_column(Text, nullable=True) + input_response: Mapped[str | None] = mapped_column(Text, nullable=True) + + # ── 进度跟踪 ────────────────────────────────────────── + progress_pct: Mapped[int | None] = mapped_column(Integer, nullable=True, server_default="0") + progress_message: Mapped[str | None] = mapped_column(Text, nullable=True) + + # ── 时间戳 ──────────────────────────────────────────── + started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + __table_args__ = ( Index("ix_task_parent_status", "parent_session_id", "status"), Index("ix_task_priority", "status", "priority"), ) + + +class TaskArtifact(Base): + """SubAgent 任务产出物。 + + SubAgent 在执行过程中可以产生多个中间产物或最终产物, + 例如:生成的代码文件、分析报告、图表等。 + + 字段说明: + - task_id: 关联的任务 ID + - name: 产出物名称(例如:"analysis_report.md") + - artifact_type: 产出物类型:file/text/image/data + - content: 文本内容(text 类型时使用) + - file_path: 文件路径(file 类型时使用) + - mime_type: MIME 类型(image/png, text/markdown 等) + - size_bytes: 文件大小(字节) + - extra: 额外元数据 + """ + + __tablename__ = "task_artifacts" + + id: Mapped[str] = mapped_column(String(36), primary_key=True) + task_id: Mapped[str] = mapped_column(String(36), nullable=False, index=True) + name: Mapped[str] = mapped_column(String(256), nullable=False) + artifact_type: Mapped[str] = mapped_column(String(32), nullable=False) # file/text/image/data + content: Mapped[str | None] = mapped_column(Text, nullable=True) + file_path: Mapped[str | None] = mapped_column(Text, nullable=True) + mime_type: Mapped[str | None] = mapped_column(String(128), nullable=True) + size_bytes: Mapped[int | None] = mapped_column(Integer, nullable=True) + extra: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + + __table_args__ = ( + Index("ix_artifact_task", "task_id", "created_at"), + ) + + +class TaskEvent(Base): + """SubAgent 任务事件日志。 + + 记录 SubAgent 执行过程中的关键事件,用于: + - 实时进度推送(SSE) + - 调试和审计 + - 前端时间线展示 + + 事件类型: + - task.started: 任务开始执行 + - task.progress: 进度更新 + - task.input_required: 请求用户输入 + - task.artifact_created: 产出物生成 + - task.completed: 任务完成 + - task.failed: 任务失败 + - task.cancelled: 任务取消 + - tool.start: 工具调用开始 + - tool.complete: 工具调用完成 + - llm.message: LLM 消息 + """ + + __tablename__ = "task_events" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + task_id: Mapped[str] = mapped_column(String(36), nullable=False, index=True) + event_type: Mapped[str] = mapped_column(String(64), nullable=False) + event_data: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) + message: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), index=True + ) + + __table_args__ = ( + Index("ix_event_task_created", "task_id", "created_at"), + ) diff --git a/backend/agentpal/runtimes/__init__.py b/backend/agentpal/runtimes/__init__.py new file mode 100644 index 0000000..d1cd02c --- /dev/null +++ b/backend/agentpal/runtimes/__init__.py @@ -0,0 +1,39 @@ +"""Agent Runtime 包 — 统一的 Agent 运行时抽象层。 + +提供多种运行时实现: +- InternalSubAgentRuntime: 内置 SubAgent(本地执行) +- HTTPAgentRuntime: 远程 HTTP Agent 服务 +- LangGraphRuntime: LangGraph 工作流引擎(预留) +""" + +from agentpal.runtimes.base import ( + BaseAgentRuntime, + ExecutionResult, + RuntimeConfig, + RuntimeStatus, + ToolCall, + ToolResult, +) +from agentpal.runtimes.http import HTTPAgentRuntime +from agentpal.runtimes.internal import InternalSubAgentRuntime +from agentpal.runtimes.registry import ( + RuntimeRegistry, + get_runtime, + list_available_runtimes, + runtime_registry, +) + +__all__ = [ + "BaseAgentRuntime", + "ExecutionResult", + "RuntimeConfig", + "RuntimeStatus", + "ToolCall", + "ToolResult", + "InternalSubAgentRuntime", + "HTTPAgentRuntime", + "RuntimeRegistry", + "runtime_registry", + "get_runtime", + "list_available_runtimes", +] diff --git a/backend/agentpal/runtimes/base.py b/backend/agentpal/runtimes/base.py new file mode 100644 index 0000000..36e303a --- /dev/null +++ b/backend/agentpal/runtimes/base.py @@ -0,0 +1,304 @@ +"""BaseAgentRuntime — Agent 运行时抽象基类。 + +定义统一的 Agent 运行时接口,支持多种 Agent 提供者: +- InternalSubAgentRuntime: 内置 SubAgent(本地执行) +- HTTPAgentRuntime: 远程 HTTP Agent 服务(如 pi-mono/OpenClaw) +- LangGraphRuntime: LangGraph 工作流引擎 + +核心概念: +- Runtime: 负责 Agent 的生命周期管理和执行 +- 每个 Runtime 持有独立的 session/memory/db 上下文 +- 支持流式和非流式两种执行模式 +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, AsyncGenerator + + +class RuntimeStatus(Enum): + """运行时状态。""" + + IDLE = "idle" # 空闲,可接受新任务 + RUNNING = "running" # 正在执行任务 + PAUSED = "paused" # 暂停(等待输入或其他事件) + ERROR = "error" # 错误状态 + + +@dataclass +class RuntimeConfig: + """运行时配置。 + + Attributes: + runtime_type: 运行时类型标识符(如 "internal", "http", "langgraph") + model_config: 模型配置 dict + max_tool_rounds: 最大工具调用轮次 + timeout_seconds: 执行超时(秒) + extra: 额外配置参数 + """ + + runtime_type: str + model_config: dict[str, Any] | None = None + max_tool_rounds: int = 16 + timeout_seconds: float = 300.0 + extra: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ExecutionResult: + """执行结果。 + + Attributes: + success: 是否成功 + output: 输出文本 + error: 错误信息(如果有) + metadata: 元数据(token 用量、执行时间等) + """ + + success: bool + output: str = "" + error: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +class BaseAgentRuntime(ABC): + """Agent 运行时抽象基类。 + + 职责: + 1. 管理 Agent 生命周期(初始化、执行、清理) + 2. 提供统一的执行接口(同步/流式) + 3. 处理状态转换和错误恢复 + 4. 支持取消操作 + + 子类需要实现: + - _initialize(): 初始化运行时 + - _execute_core(): 核心执行逻辑 + - _stream_core(): 流式执行逻辑 + - _cleanup(): 清理资源 + - _cancel(): 取消当前执行 + """ + + def __init__( + self, + session_id: str, + config: RuntimeConfig, + db: Any | None = None, + memory: Any | None = None, + parent_session_id: str | None = None, + ) -> None: + """初始化运行时。 + + Args: + session_id: 会话 ID + config: 运行时配置 + db: 数据库 session(可选) + memory: 记忆模块(可选) + parent_session_id: 父会话 ID(用于 Agent 间通信) + """ + self.session_id = session_id + self.config = config + self.db = db + self.memory = memory + self.parent_session_id = parent_session_id + self._status = RuntimeStatus.IDLE + self._current_task_id: str | None = None + + # ── 公共接口 ──────────────────────────────────────────── + + @property + def status(self) -> RuntimeStatus: + """获取当前状态。""" + return self._status + + @property + def current_task_id(self) -> str | None: + """获取当前执行的任务 ID。""" + return self._current_task_id + + async def execute( + self, + task_prompt: str, + task_id: str | None = None, + **kwargs: Any, + ) -> ExecutionResult: + """执行任务(非流式)。 + + Args: + task_prompt: 任务提示词 + task_id: 任务 ID(可选) + **kwargs: 额外参数 + + Returns: + ExecutionResult: 执行结果 + """ + self._status = RuntimeStatus.RUNNING + self._current_task_id = task_id + + try: + await self._initialize() + result = await self._execute_core(task_prompt, **kwargs) + self._status = RuntimeStatus.IDLE + return result + except asyncio.CancelledError: + self._status = RuntimeStatus.IDLE + raise + except Exception as e: + self._status = RuntimeStatus.ERROR + return ExecutionResult( + success=False, + error=f"{type(e).__name__}: {e}", + ) + finally: + await self._cleanup() + self._current_task_id = None + + async def stream( + self, + task_prompt: str, + task_id: str | None = None, + **kwargs: Any, + ) -> AsyncGenerator[dict[str, Any], None]: + """流式执行任务。 + + Args: + task_prompt: 任务提示词 + task_id: 任务 ID(可选) + **kwargs: 额外参数 + + Yields: + SSE 事件 dict,包含: + - {"type": "thinking_delta", "delta": "..."} + - {"type": "tool_start", "id": "...", "name": "...", "input": {...}} + - {"type": "tool_done", "id": "...", "output": "..."} + - {"type": "text_delta", "delta": "..."} + - {"type": "done", "result": "..."} + - {"type": "error", "message": "..."} + """ + self._status = RuntimeStatus.RUNNING + self._current_task_id = task_id + + try: + await self._initialize() + async for event in self._stream_core(task_prompt, **kwargs): + yield event + self._status = RuntimeStatus.IDLE + except asyncio.CancelledError: + self._status = RuntimeStatus.IDLE + raise + except Exception as e: + self._status = RuntimeStatus.ERROR + yield {"type": "error", "message": f"{type(e).__name__}: {e}"} + finally: + await self._cleanup() + self._current_task_id = None + + async def cancel(self) -> None: + """取消当前执行。""" + if self._status == RuntimeStatus.RUNNING: + await self._cancel() + self._status = RuntimeStatus.IDLE + + # ── 抽象方法(子类必须实现) ───────────────────────────── + + @abstractmethod + async def _initialize(self) -> None: + """初始化运行时。 + + 子类在此处进行资源初始化,如: + - 加载模型 + - 建立连接 + - 准备工具集 + """ + + @abstractmethod + async def _execute_core(self, task_prompt: str, **kwargs: Any) -> ExecutionResult: + """核心执行逻辑(非流式)。 + + Args: + task_prompt: 任务提示词 + **kwargs: 额外参数 + + Returns: + ExecutionResult: 执行结果 + """ + + @abstractmethod + def _stream_core( + self, task_prompt: str, **kwargs: Any + ) -> AsyncGenerator[dict[str, Any], None]: + """核心流式执行逻辑。 + + Args: + task_prompt: 任务提示词 + **kwargs: 额外参数 + + Yields: + SSE 事件 dict + """ + + @abstractmethod + async def _cleanup(self) -> None: + """清理资源。 + + 子类在此处释放资源,如: + - 关闭连接 + - 释放内存 + - 保存状态 + """ + + @abstractmethod + async def _cancel(self) -> None: + """取消当前执行。 + + 子类实现优雅的中断逻辑。 + """ + + # ── 辅助方法 ──────────────────────────────────────────── + + def _log(self, event_type: str, data: dict[str, Any]) -> None: + """记录执行日志(可选,子类可重写)。 + + Args: + event_type: 事件类型 + data: 事件数据 + """ + # 默认实现:打印到控制台 + print(f"[{self.session_id}] {event_type}: {data}") + + +# ── 工具调用相关数据结构 ─────────────────────────────────── + + +@dataclass +class ToolCall: + """工具调用请求。 + + Attributes: + id: 调用 ID + name: 工具名称 + arguments: 参数字典 + """ + + id: str + name: str + arguments: dict[str, Any] + + +@dataclass +class ToolResult: + """工具调用结果。 + + Attributes: + id: 调用 ID + output: 输出内容 + error: 错误信息(如果有) + duration_ms: 执行时长(毫秒) + """ + + id: str + output: str = "" + error: str | None = None + duration_ms: float = 0.0 diff --git a/backend/agentpal/runtimes/http.py b/backend/agentpal/runtimes/http.py new file mode 100644 index 0000000..4be9a24 --- /dev/null +++ b/backend/agentpal/runtimes/http.py @@ -0,0 +1,335 @@ +"""HTTPAgentRuntime — 远程 HTTP Agent 运行时适配器。 + +支持连接远程 Agent 服务,如: +- pi-mono: 统一的 Agent API 网关 +- OpenClaw: 开源 Agent 框架 +- 其他兼容的 HTTP Agent 服务 + +协议规范: +- POST /chat: 非流式对话 +- POST /chat/stream: 流式对话(SSE) +- POST /cancel: 取消任务 +""" + +from __future__ import annotations + +import asyncio +import json +from typing import Any, AsyncGenerator + +import aiohttp +from loguru import logger + +from agentpal.runtimes.base import ( + BaseAgentRuntime, + ExecutionResult, + RuntimeConfig, + RuntimeStatus, +) + + +class HTTPAgentRuntime(BaseAgentRuntime): + """远程 HTTP Agent 运行时适配器。 + + 通过 HTTP 协议与远程 Agent 服务通信, + 支持流式和非流式两种模式。 + """ + + def __init__( + self, + session_id: str, + config: RuntimeConfig, + db: Any | None = None, + memory: Any | None = None, + parent_session_id: str | None = None, + base_url: str | None = None, + api_key: str | None = None, + ) -> None: + """初始化运行时。 + + Args: + session_id: 会话 ID + config: 运行时配置 + db: 数据库 session(可选) + memory: 记忆模块(可选) + parent_session_id: 父会话 ID + base_url: 远程服务基础 URL + api_key: API 密钥(可选) + """ + super().__init__( + session_id=session_id, + config=config, + db=db, + memory=memory, + parent_session_id=parent_session_id, + ) + + # 从配置或参数中获取 URL + self.base_url = base_url or config.extra.get("base_url", "http://localhost:8000") + self.api_key = api_key or config.extra.get("api_key") + + # HTTP 会话 + self._session: aiohttp.ClientSession | None = None + self._cancel_flag = False + self._current_request_ctx: aiohttp.ClientResponse | None = None + + async def _initialize(self) -> None: + """初始化 HTTP 会话。""" + connector = aiohttp.TCPConnector(limit=10, ttl_dns_cache=300) + self._session = aiohttp.ClientSession(connector=connector) + + headers = { + "Content-Type": "application/json", + } + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + self._session._default_headers.update(headers) + + logger.info(f"HTTPAgentRuntime initialized for {self.base_url}") + + async def _execute_core(self, task_prompt: str, **kwargs: Any) -> ExecutionResult: + """执行任务(非流式)。 + + Args: + task_prompt: 任务提示词 + **kwargs: 额外参数 + + Returns: + ExecutionResult: 执行结果 + """ + if self._session is None: + raise RuntimeError("HTTP session not initialized") + + start_time = asyncio.get_event_loop().time() + + # 构建请求体 + payload = { + "session_id": self.session_id, + "messages": [ + {"role": "user", "content": task_prompt} + ], + "model": self.config.model_config.get("model") if self.config.model_config else None, + "max_tokens": kwargs.get("max_tokens"), + "temperature": kwargs.get("temperature", 0.7), + } + + # 添加工具定义(如果有) + if "tools" in kwargs: + payload["tools"] = kwargs["tools"] + + try: + async with self._session.post( + f"{self.base_url}/chat", + json=payload, + timeout=aiohttp.ClientTimeout(total=self.config.timeout_seconds), + ) as resp: + self._current_request_ctx = resp + + if resp.status != 200: + error_text = await resp.text() + return ExecutionResult( + success=False, + error=f"HTTP {resp.status}: {error_text}", + metadata={"status_code": resp.status}, + ) + + result_data = await resp.json() + + elapsed = asyncio.get_event_loop().time() - start_time + + # 提取回复文本 + output_text = "" + if "choices" in result_data and result_data["choices"]: + output_text = result_data["choices"][0].get("message", {}).get("content", "") + elif "content" in result_data: + output_text = result_data["content"] + + return ExecutionResult( + success=True, + output=output_text, + metadata={ + "elapsed_seconds": elapsed, + "usage": result_data.get("usage", {}), + "model": result_data.get("model"), + }, + ) + + except asyncio.TimeoutError: + return ExecutionResult( + success=False, + error=f"Request timeout after {self.config.timeout_seconds}s", + ) + except asyncio.CancelledError: + raise + except Exception as e: + logger.exception(f"HTTPAgentRuntime execution failed: {e}") + return ExecutionResult( + success=False, + error=f"{type(e).__name__}: {e}", + ) + finally: + self._current_request_ctx = None + + async def _stream_core( + self, task_prompt: str, **kwargs: Any + ) -> AsyncGenerator[dict[str, Any], None]: + """流式执行任务(SSE)。 + + Args: + task_prompt: 任务提示词 + **kwargs: 额外参数 + + Yields: + SSE 事件 dict + """ + if self._session is None: + raise RuntimeError("HTTP session not initialized") + + # 构建请求体 + payload = { + "session_id": self.session_id, + "messages": [ + {"role": "user", "content": task_prompt} + ], + "stream": True, + "model": self.config.model_config.get("model") if self.config.model_config else None, + "temperature": kwargs.get("temperature", 0.7), + } + + try: + async with self._session.post( + f"{self.base_url}/chat/stream", + json=payload, + timeout=aiohttp.ClientTimeout(total=self.config.timeout_seconds), + ) as resp: + self._current_request_ctx = resp + + if resp.status != 200: + error_text = await resp.text() + yield {"type": "error", "message": f"HTTP {resp.status}: {error_text}"} + return + + # 解析 SSE 流 + async for line in resp.content.iter_lines(): + if self._cancel_flag: + break + + if not line: + continue + + line_str = line.decode("utf-8").strip() + + if line_str.startswith("data: "): + data_str = line_str[6:] + + if data_str.strip() == "[DONE]": + break + + try: + data = json.loads(data_str) + sse_event = self._parse_sse_data(data) + if sse_event: + yield sse_event + except json.JSONDecodeError: + logger.warning(f"Invalid SSE data: {data_str}") + + except asyncio.TimeoutError: + yield {"type": "error", "message": f"Stream timeout after {self.config.timeout_seconds}s"} + except asyncio.CancelledError: + raise + except Exception as e: + logger.exception(f"HTTPAgentRuntime stream failed: {e}") + yield {"type": "error", "message": f"{type(e).__name__}: {e}"} + finally: + self._current_request_ctx = None + + def _parse_sse_data(self, data: dict[str, Any]) -> dict[str, Any] | None: + """解析 SSE 数据块。 + + Args: + data: SSE 数据 dict + + Returns: + 标准化的 SSE 事件 dict + """ + # OpenAI 兼容格式 + if "choices" in data and data["choices"]: + delta = data["choices"][0].get("delta", {}) + + if "content" in delta and delta["content"]: + return {"type": "text_delta", "delta": delta["content"]} + + if "tool_calls" in delta and delta["tool_calls"]: + tool_call = delta["tool_calls"][0] + return { + "type": "tool_start", + "id": tool_call.get("id", ""), + "name": tool_call.get("function", {}).get("name", ""), + "input": tool_call.get("function", {}).get("arguments", {}), + } + + # 通用格式 + if "type" in data: + return data + + if "content" in data: + return {"type": "text_delta", "delta": data["content"]} + + return None + + async def _cleanup(self) -> None: + """清理 HTTP 会话。""" + if self._session: + await self._session.close() + self._session = None + logger.debug(f"HTTPAgentRuntime cleanup for session {self.session_id}") + + async def _cancel(self) -> None: + """取消当前请求。 + + 设置取消标志并关闭当前 HTTP 连接。 + """ + self._cancel_flag = True + + if self._current_request_ctx: + try: + await self._current_request_ctx.close() + except Exception: + pass + + # 尝试调用远程取消端点 + if self._session: + try: + await self._session.post( + f"{self.base_url}/cancel", + json={"session_id": self.session_id}, + timeout=aiohttp.ClientTimeout(total=5), + ) + except Exception: + pass # 忽略取消端点失败 + + logger.info(f"HTTPAgentRuntime cancelled for session {self.session_id}") + + # ── HTTPAgentRuntime 特有方法 ─────────────────────────── + + async def health_check(self) -> dict[str, Any]: + """检查远程服务健康状态。 + + Returns: + 健康检查结果 dict + """ + if self._session is None: + return {"healthy": False, "error": "Session not initialized"} + + try: + async with self._session.get( + f"{self.base_url}/health", + timeout=aiohttp.ClientTimeout(total=5), + ) as resp: + if resp.status == 200: + data = await resp.json() + return {"healthy": True, **data} + return {"healthy": False, "status": resp.status} + except Exception as e: + return {"healthy": False, "error": str(e)} diff --git a/backend/agentpal/runtimes/internal.py b/backend/agentpal/runtimes/internal.py new file mode 100644 index 0000000..7e38fa9 --- /dev/null +++ b/backend/agentpal/runtimes/internal.py @@ -0,0 +1,300 @@ +"""InternalSubAgentRuntime — 内置 SubAgent 运行时适配器。 + +将现有的 SubAgent 逻辑适配到 BaseAgentRuntime 接口, +保持向后兼容的同时提供统一的运行时抽象。 +""" + +from __future__ import annotations + +import asyncio +from typing import Any, AsyncGenerator + +from loguru import logger + +from agentpal.agents.sub_agent import SubAgent +from agentpal.memory.buffer import BufferMemory +from agentpal.models.session import SubAgentTask, TaskStatus +from agentpal.runtimes.base import ( + BaseAgentRuntime, + ExecutionResult, + RuntimeConfig, + RuntimeStatus, +) + + +class InternalSubAgentRuntime(BaseAgentRuntime): + """内置 SubAgent 运行时适配器。 + + 封装现有的 SubAgent 类,提供统一的 Runtime 接口。 + 支持: + - 非流式执行(execute) + - 流式执行(stream) + - 任务取消 + - 状态管理 + """ + + def __init__( + self, + session_id: str, + config: RuntimeConfig, + db: Any | None = None, + memory: Any | None = None, + parent_session_id: str | None = None, + task: SubAgentTask | None = None, + role_prompt: str = "", + ) -> None: + """初始化运行时。 + + Args: + session_id: 会话 ID + config: 运行时配置 + db: 数据库 session + memory: 记忆模块(可选,默认创建 BufferMemory) + parent_session_id: 父会话 ID + task: SubAgentTask 数据库记录 + role_prompt: 角色系统提示词 + """ + super().__init__( + session_id=session_id, + config=config, + db=db, + memory=memory, + parent_session_id=parent_session_id, + ) + self._task = task + self._role_prompt = role_prompt + self._sub_agent: SubAgent | None = None + self._cancel_flag = False + + async def _initialize(self) -> None: + """初始化 SubAgent。""" + if self.db is None: + raise RuntimeError("Database session is required for InternalSubAgentRuntime") + + # 如果没有传入 task,创建一个临时的 + if self._task is None: + self._task = SubAgentTask( + id=f"temp_{self.session_id}", + parent_session_id=self.parent_session_id or self.session_id, + sub_session_id=f"sub:{self.parent_session_id or self.session_id}:temp", + task_prompt="", + status=TaskStatus.PENDING, + agent_name=None, + ) + + # 创建记忆模块 + if self.memory is None: + self.memory = BufferMemory(session_id=self.session_id, db=self.db) + + # 从配置中提取模型配置 + model_config = self.config.model_config or {} + + # 创建 SubAgent 实例 + self._sub_agent = SubAgent( + session_id=self.session_id, + memory=self.memory, + task=self._task, + db=self.db, + model_config=model_config, + role_prompt=self._role_prompt, + max_tool_rounds=self.config.max_tool_rounds, + parent_session_id=self.parent_session_id or "", + ) + + self._status = RuntimeStatus.IDLE + logger.debug(f"InternalSubAgentRuntime initialized for session {self.session_id}") + + async def _execute_core(self, task_prompt: str, **kwargs: Any) -> ExecutionResult: + """执行任务(非流式)。 + + Args: + task_prompt: 任务提示词 + **kwargs: 额外参数 + + Returns: + ExecutionResult: 执行结果 + """ + if self._sub_agent is None: + raise RuntimeError("SubAgent not initialized") + + start_time = asyncio.get_event_loop().time() + + try: + # 执行任务 + result_text = await self._sub_agent.run(task_prompt) + + elapsed = asyncio.get_event_loop().time() - start_time + + return ExecutionResult( + success=True, + output=result_text, + metadata={ + "elapsed_seconds": elapsed, + "session_id": self.session_id, + "task_id": self._task.id if self._task else None, + }, + ) + except asyncio.CancelledError: + raise + except Exception as e: + elapsed = asyncio.get_event_loop().time() - start_time + logger.exception(f"InternalSubAgentRuntime execution failed: {e}") + return ExecutionResult( + success=False, + error=f"{type(e).__name__}: {e}", + metadata={ + "elapsed_seconds": elapsed, + "session_id": self.session_id, + }, + ) + + async def _stream_core( + self, task_prompt: str, **kwargs: Any + ) -> AsyncGenerator[dict[str, Any], None]: + """流式执行任务。 + + 由于 SubAgent 本身不支持真正的流式输出, + 这里采用事件监听的方式,将执行日志转换为 SSE 事件。 + + Args: + task_prompt: 任务提示词 + **kwargs: 额外参数 + + Yields: + SSE 事件 dict + """ + if self._sub_agent is None: + raise RuntimeError("SubAgent not initialized") + + # 订阅任务事件 + from agentpal.services.task_event_bus import task_event_bus + + # 创建队列接收事件 + event_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + + def on_event(event: dict[str, Any]) -> None: + """事件回调。""" + event_queue.put_nowait(event) + + # 订阅该任务的事件 + task_id = self._task.id if self._task else None + if task_id: + task_event_bus.subscribe(task_id, on_event) + + try: + # 启动任务(在后台运行) + task = asyncio.create_task(self._sub_agent.run(task_prompt)) + + # 持续监听事件 + while True: + try: + # 等待事件或任务完成 + event_future = asyncio.ensure_future(event_queue.get()) + done, pending = await asyncio.wait( + [event_future, task], + return_when=asyncio.FIRST_COMPLETED, + ) + + # 处理已完成的事件 + for fut in done: + if fut is event_future and not task.done(): + # 收到事件 + event = event_future.result() + yield self._convert_event_to_sse(event) + elif fut is task: + # 任务完成 + result = fut.result() + yield {"type": "done", "result": result} + break + + # 检查是否有更多事件 + if not task.done(): + continue + + except asyncio.CancelledError: + task.cancel() + raise + + except asyncio.CancelledError: + logger.info(f"Stream cancelled for session {self.session_id}") + raise + except Exception as e: + logger.exception(f"Stream error: {e}") + yield {"type": "error", "message": f"{type(e).__name__}: {e}"} + finally: + # 取消订阅 + if task_id: + task_event_bus.unsubscribe(task_id, on_event) + + def _convert_event_to_sse(self, event: dict[str, Any]) -> dict[str, Any]: + """将任务事件转换为 SSE 格式。 + + Args: + event: 任务事件 dict + + Returns: + SSE 事件 dict + """ + event_type = event.get("event_type", "") + data = event.get("data", {}) + + mapping = { + "tool.start": lambda d: { + "type": "tool_start", + "id": d.get("id", ""), + "name": d.get("name", ""), + "input": d.get("input", {}), + }, + "tool.complete": lambda d: { + "type": "tool_done", + "id": d.get("id", ""), + "output": d.get("output", "")[:2000], + "duration_ms": d.get("duration_ms", 0), + }, + "llm.message": lambda d: { + "type": "text_delta", + "delta": str(d.get("content", ""))[:500], + }, + "task.progress": lambda d: { + "type": "progress", + "pct": d.get("pct", 0), + "message": d.get("message", ""), + }, + "artifact.created": lambda d: { + "type": "artifact", + "artifact_id": d.get("artifact_id", ""), + "artifact_type": d.get("artifact_type", ""), + "title": d.get("title", ""), + }, + } + + converter = mapping.get(event_type, lambda d: {"type": "unknown", "raw": d}) + return converter(data) + + async def _cleanup(self) -> None: + """清理资源。 + + 目前 SubAgent 不需要特殊清理,保留接口供未来扩展。 + """ + logger.debug(f"InternalSubAgentRuntime cleanup for session {self.session_id}") + self._sub_agent = None + + async def _cancel(self) -> None: + """取消当前执行。 + + 设置取消标志并调用 SubAgent 的 cancel 方法。 + """ + self._cancel_flag = True + if self._sub_agent and self._task: + await self._sub_agent.cancel(reason="Runtime cancelled") + logger.info(f"InternalSubAgentRuntime cancelled for session {self.session_id}") + + # ── InternalSubAgentRuntime 特有方法 ───────────────────── + + def get_sub_agent(self) -> SubAgent | None: + """获取内部的 SubAgent 实例(用于高级用法)。""" + return self._sub_agent + + def get_task(self) -> SubAgentTask | None: + """获取关联的 SubAgentTask 记录。""" + return self._task diff --git a/backend/agentpal/runtimes/registry.py b/backend/agentpal/runtimes/registry.py new file mode 100644 index 0000000..2b43e55 --- /dev/null +++ b/backend/agentpal/runtimes/registry.py @@ -0,0 +1,269 @@ +"""Runtime Registry — Agent 运行时注册表和工厂。 + +提供运行时类型的注册、查找和创建功能, +支持通过配置文件动态切换运行时实现。 +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Type + +from loguru import logger + +from agentpal.runtimes.base import BaseAgentRuntime, RuntimeConfig + + +@dataclass +class RuntimeDescriptor: + """运行时描述符。 + + Attributes: + name: 运行时名称(如 "internal", "http") + runtime_class: 运行时类 + description: 描述信息 + config_schema: 配置 Schema(JSON Schema 格式) + """ + + name: str + runtime_class: Type[BaseAgentRuntime] + description: str = "" + config_schema: dict[str, Any] = field(default_factory=dict) + + +class RuntimeRegistry: + """运行时注册表(单例)。 + + 用法: + # 注册运行时 + registry.register("internal", InternalSubAgentRuntime, "内置 SubAgent") + + # 获取运行时 + runtime = registry.create("internal", session_id="xxx", config=...) + """ + + _instance: RuntimeRegistry | None = None + _runtimes: dict[str, RuntimeDescriptor] + + def __new__(cls) -> RuntimeRegistry: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._runtimes = {} + return cls._instance + + def register( + self, + name: str, + runtime_class: Type[BaseAgentRuntime], + description: str = "", + config_schema: dict[str, Any] | None = None, + ) -> None: + """注册运行时类型。 + + Args: + name: 运行时名称 + runtime_class: 运行时类 + description: 描述信息 + config_schema: 配置 Schema + """ + self._runtimes[name] = RuntimeDescriptor( + name=name, + runtime_class=runtime_class, + description=description, + config_schema=config_schema or {}, + ) + logger.info(f"Registered runtime: {name}") + + def unregister(self, name: str) -> None: + """注销运行时类型。 + + Args: + name: 运行时名称 + """ + if name in self._runtimes: + del self._runtimes[name] + logger.info(f"Unregistered runtime: {name}") + + def get(self, name: str) -> RuntimeDescriptor | None: + """获取运行时描述符。 + + Args: + name: 运行时名称 + + Returns: + 运行时描述符,不存在则返回 None + """ + return self._runtimes.get(name) + + def list_runtimes(self) -> list[dict[str, Any]]: + """列出所有已注册的运行时。 + + Returns: + 运行时信息列表 + """ + return [ + { + "name": desc.name, + "description": desc.description, + "config_schema": desc.config_schema, + } + for desc in self._runtimes.values() + ] + + def create( + self, + name: str, + session_id: str, + config: RuntimeConfig, + **kwargs: Any, + ) -> BaseAgentRuntime: + """创建运行时实例。 + + Args: + name: 运行时名称 + session_id: 会话 ID + config: 运行时配置 + **kwargs: 额外参数 + + Returns: + 运行时实例 + + Raises: + ValueError: 运行时未注册 + """ + descriptor = self.get(name) + if descriptor is None: + available = ", ".join(self._runtimes.keys()) + raise ValueError( + f"Unknown runtime '{name}'. Available: {available}" + ) + + logger.info(f"Creating runtime: {name} for session {session_id}") + + return descriptor.runtime_class( + session_id=session_id, + config=config, + **kwargs, + ) + + def exists(self, name: str) -> bool: + """检查运行时是否已注册。 + + Args: + name: 运行时名称 + + Returns: + 是否存在 + """ + return name in self._runtimes + + +# ── 全局注册表实例 ───────────────────────────────────────── + +runtime_registry = RuntimeRegistry() + + +# ── 自动注册内置运行时 ──────────────────────────────────── + +def _register_builtin_runtimes() -> None: + """注册内置运行时。""" + # InternalSubAgentRuntime + try: + from agentpal.runtimes.internal import InternalSubAgentRuntime + + runtime_registry.register( + name="internal", + runtime_class=InternalSubAgentRuntime, + description="Built-in SubAgent runtime (local execution)", + config_schema={ + "type": "object", + "properties": { + "max_tool_rounds": { + "type": "integer", + "default": 16, + "description": "Maximum tool call rounds", + }, + "timeout_seconds": { + "type": "number", + "default": 300, + "description": "Execution timeout in seconds", + }, + }, + }, + ) + except ImportError as e: + logger.warning(f"Failed to register internal runtime: {e}") + + # HTTPAgentRuntime + try: + from agentpal.runtimes.http import HTTPAgentRuntime + + runtime_registry.register( + name="http", + runtime_class=HTTPAgentRuntime, + description="Remote HTTP Agent service (pi-mono, OpenClaw, etc.)", + config_schema={ + "type": "object", + "properties": { + "base_url": { + "type": "string", + "description": "Base URL of the remote service", + }, + "api_key": { + "type": "string", + "description": "API key for authentication", + }, + "timeout_seconds": { + "type": "number", + "default": 300, + "description": "Request timeout in seconds", + }, + }, + "required": ["base_url"], + }, + ) + except ImportError as e: + logger.warning(f"Failed to register http runtime: {e}") + + +# 自动注册 +_register_builtin_runtimes() + + +# ── 便捷函数 ────────────────────────────────────────────── + +def get_runtime( + runtime_type: str, + session_id: str, + config: RuntimeConfig | None = None, + **kwargs: Any, +) -> BaseAgentRuntime: + """获取运行时实例的便捷函数。 + + Args: + runtime_type: 运行时类型 + session_id: 会话 ID + config: 运行时配置 + **kwargs: 额外参数 + + Returns: + 运行时实例 + """ + if config is None: + config = RuntimeConfig(runtime_type=runtime_type) + + return runtime_registry.create( + name=runtime_type, + session_id=session_id, + config=config, + **kwargs, + ) + + +def list_available_runtimes() -> list[dict[str, Any]]: + """列出所有可用的运行时。 + + Returns: + 运行时信息列表 + """ + return runtime_registry.list_runtimes() diff --git a/backend/agentpal/services/__init__.py b/backend/agentpal/services/__init__.py index e69de29..11c0228 100644 --- a/backend/agentpal/services/__init__.py +++ b/backend/agentpal/services/__init__.py @@ -0,0 +1,15 @@ +from agentpal.services.config_file import ConfigFileService +from agentpal.services.cron_scheduler import cron_scheduler +from agentpal.services.notification_bus import notification_bus +from agentpal.services.session_event_bus import session_event_bus +from agentpal.services.skill_event_bus import skill_event_bus +from agentpal.services.task_event_bus import task_event_bus + +__all__ = [ + "ConfigFileService", + "cron_scheduler", + "notification_bus", + "session_event_bus", + "skill_event_bus", + "task_event_bus", +] diff --git a/backend/agentpal/services/task_event_bus.py b/backend/agentpal/services/task_event_bus.py new file mode 100644 index 0000000..1f789c2 --- /dev/null +++ b/backend/agentpal/services/task_event_bus.py @@ -0,0 +1,90 @@ +"""TaskEventBus — SubAgent 任务事件总线。 + +向 SSE 订阅方广播 SubAgent 任务执行过程中的事件。 + +使用方式: + # 订阅者(SSE 端点) + queue = task_event_bus.subscribe(task_id) + event = await asyncio.wait_for(queue.get(), timeout=30.0) + + # 发布者(SubAgent 执行器) + await task_event_bus.emit(task_id, "task.progress", {"pct": 50, "message": "Processing..."}) +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +from loguru import logger + + +class TaskEventBus: + """Per-task asyncio 事件总线。 + + 每个 task 可有多个并发 SSE 订阅者。 + 当 SubAgent 任务执行过程中产生事件时,向所有订阅者广播。 + """ + + def __init__(self) -> None: + # task_id -> list of subscriber queues + self._subs: dict[str, list[asyncio.Queue]] = {} + + def subscribe(self, task_id: str) -> asyncio.Queue: + """订阅指定 task 的事件,返回消息队列。""" + q: asyncio.Queue = asyncio.Queue(maxsize=256) + self._subs.setdefault(task_id, []).append(q) + logger.debug(f"TaskEventBus: 新订阅者加入 task={task_id}") + return q + + def unsubscribe(self, task_id: str, queue: asyncio.Queue) -> None: + """取消订阅,客户端断开时调用。""" + subs = self._subs.get(task_id, []) + try: + subs.remove(queue) + except ValueError: + pass + if not subs: + self._subs.pop(task_id, None) + logger.debug(f"TaskEventBus: 无订阅者,清理 task={task_id}") + + async def emit(self, task_id: str, event_type: str, event_data: dict[str, Any] | None = None, message: str | None = None) -> None: + """向指定 task 的所有订阅者广播事件。 + + Args: + task_id: 任务 ID + event_type: 事件类型(如 "task.progress", "tool.start") + event_data: 事件负载数据 + message: 人类可读的消息描述 + + 队列满时丢弃本次事件(防止慢速客户端阻塞发布者)。 + """ + event = { + "event_type": event_type, + "event_data": event_data or {}, + "message": message, + } + for q in list(self._subs.get(task_id, [])): + try: + q.put_nowait(event) + except asyncio.QueueFull: + logger.warning(f"TaskEventBus: 队列已满,丢弃事件 task={task_id} event={event_type}") + pass + + async def emit_to_many(self, task_ids: list[str], event_type: str, event_data: dict[str, Any] | None = None, message: str | None = None) -> None: + """向多个任务广播同一事件(用于群发场景)。""" + for tid in task_ids: + await self.emit(tid, event_type, event_data, message) + + @property + def subscriber_count(self) -> int: + """当前全部 task 的订阅者总数(调试用)。""" + return sum(len(v) for v in self._subs.values()) + + def get_subscriber_count_for_task(self, task_id: str) -> int: + """获取指定 task 的订阅者数量(调试用)。""" + return len(self._subs.get(task_id, [])) + + +# 全局单例 +task_event_bus = TaskEventBus() diff --git a/backend/agentpal/tools/builtin.py b/backend/agentpal/tools/builtin.py index 0a21103..52c0fe8 100644 --- a/backend/agentpal/tools/builtin.py +++ b/backend/agentpal/tools/builtin.py @@ -730,6 +730,9 @@ async def dispatch_sub_agent( task_type: str = "", agent_name: str = "", wait_seconds: int = 120, + blocking: bool = False, + runtime_type: str = "internal", + runtime_config: dict[str, Any] | None = None, ) -> ToolResponse: """将子任务委托给专业 SubAgent 执行,并等待结果返回。 @@ -750,11 +753,11 @@ async def dispatch_sub_agent( from agentpal.agents.personal_assistant import _default_model_config from agentpal.agents.registry import SubAgentRegistry - from agentpal.agents.sub_agent import SubAgent from agentpal.database import AsyncSessionLocal - from agentpal.memory.factory import MemoryFactory from agentpal.models.agent import SubAgentDefinition from agentpal.models.session import SubAgentTask, TaskStatus + from agentpal.runtimes.base import ExecutionResult, RuntimeConfig + from agentpal.runtimes.registry import get_runtime async def _run() -> str: task_id = str(uuid.uuid4()) @@ -790,32 +793,75 @@ async def _run() -> str: agent_name=resolved_agent_name, task_type=task_type or None, execution_log=[], - meta={}, + meta={"blocking": blocking, "wait_seconds": wait_seconds}, ) db.add(task) await db.commit() - # 3. 运行 SubAgent(在主 event loop 内协作执行,不阻塞其他请求) - sub_memory = MemoryFactory.create("buffer") - sub_agent = SubAgent( + # 3. 构建运行时配置 + rt_config_data = { + "runtime_type": runtime_type, + "model_config": model_config, + "max_tool_rounds": max_tool_rounds, + "timeout_seconds": float(wait_seconds), + } + if runtime_config: + rt_config_data["extra"] = runtime_config + + rt_config = RuntimeConfig(**rt_config_data) + + # 4. 获取运行时实例并使用其执行任务 + runtime = get_runtime( + runtime_type=runtime_type, session_id=sub_session_id, - memory=sub_memory, - task=task, + config=rt_config, db=db, - model_config=model_config, - role_prompt=role_prompt, - max_tool_rounds=max_tool_rounds, parent_session_id=parent_session_id, + task=task, ) - result = await sub_agent.run(task_prompt) - await db.commit() - agent_label = resolved_agent_name or "SubAgent" - if task.status == TaskStatus.DONE: - return f"[{agent_label} 执行完毕]\n\n{result}" - else: - error_info = f"\n\n错误: {task.error}" if task.error else "" - return f"[{agent_label} 执行失败]\n任务 ID: {task_id}{error_info}" + try: + # 初始化运行时 + await runtime._initialize() + + if blocking: + # Blocking mode: execute and wait for completion + result: ExecutionResult = await runtime.execute(task_prompt) + + # 从数据库重新加载最新状态 + await db.refresh(task) + + agent_label = resolved_agent_name or "SubAgent" + if result.success and task.status == TaskStatus.DONE: + return f"[{agent_label}] DONE\n\n{result.output}" + elif task.status == TaskStatus.INPUT_REQUIRED: + question = task.meta.get("input_request", {}).get("question", "需要您的输入") + return f"[{agent_label}] INPUT_REQUIRED\n任务 ID: {task_id}\n问题:{question}\n\n请提供所需输入以继续执行。" + else: + error_detail = f"\n\n错误:{task.error}" if task.error else (f"\n\n错误:{result.error}" if result.error else "") + return f"[{agent_label}] FAILED\n任务 ID: {task_id}{error_detail}" + else: + # Non-blocking mode: start background task and return immediately + async def run_in_background(): + try: + await runtime.execute(task_prompt) + except Exception as e: + task.status = TaskStatus.FAILED + task.error = str(e) + await db.commit() + + asyncio.create_task(run_in_background()) + + return ( + f"[SubAgent] STARTED\n" + f"任务 ID: {task_id}\n" + f"执行者:{resolved_agent_name or 'Auto'}\n" + f"状态:{TaskStatus.RUNNING.value}\n\n" + f"任务正在后台运行,可在 /tasks 页面查看进度。" + ) + finally: + # 清理运行时资源 + await runtime._cleanup() try: result_text = await asyncio.wait_for(_run(), timeout=wait_seconds) @@ -998,8 +1044,121 @@ def execute_python_code( { "name": "dispatch_sub_agent", "func": dispatch_sub_agent, - "description": "将子任务委托给专业 SubAgent(coder/researcher)执行,等待结果返回", + "description": "将子任务委托给专业 SubAgent(coder/researcher)执行,支持阻塞/非阻塞模式", "icon": "Bot", "dangerous": False, }, ] + +# ── 12. produce_artifact ────────────────────────────────── + + +def produce_artifact( + name: str, + content: str | None = None, + artifact_type: str = "text", + file_path: str | None = None, + mime_type: str | None = None, + extra: dict[str, Any] | None = None, +) -> ToolResponse: + """创建任务产出物(代码文件、报告、图表等)。 + + SubAgent 使用此工具将执行过程中的中间产物或最终成果保存下来, + 供用户查看和下载。 + + Args: + name: 产出物名称(例如:"analysis_report.md"、"generated_code.py") + content: 文本内容(text 类型时使用) + artifact_type: 产出物类型:file/text/image/data(默认"text") + file_path: 文件路径(file 类型时使用,可以是绝对路径或相对于工作空间的相对路径) + mime_type: MIME 类型(例如:"text/markdown"、"image/png"、"application/json") + extra: 额外元数据(JSON 对象) + + Returns: + 产出物创建结果,包括 artifact_id 和访问路径 + + Example: + # 创建文本报 + produce_artifact( + name="analysis_report.md", + content="# Analysis Report\\n\\n...", + artifact_type="text", + mime_type="text/markdown" + ) + + # 保存文件 + produce_artifact( + name="generated_script.py", + file_path="/path/to/script.py", + artifact_type="file", + mime_type="text/x-python" + ) + """ + import uuid + + from agentpal.database import get_sync_db + from agentpal.models.session import TaskArtifact + + try: + # 获取当前任务 ID(从系统环境变量或线程上下文) + import os + + task_id = os.environ.get("AGENTPAL_CURRENT_TASK_ID") + if not task_id: + return _text_response("无法获取当前任务 ID,请在 dispatch_sub_agent 回调中使用此工具") + + # 确定 artifact_type 和 mime_type + if artifact_type == "text" and not mime_type: + mime_type = "text/plain" + elif artifact_type == "file" and not mime_type and file_path: + mime_type, _ = mimetypes.guess_type(file_path) + + # 计算文件大小 + size_bytes = None + if content: + size_bytes = len(content.encode("utf-8")) + elif file_path: + p = Path(file_path).expanduser() + if p.exists(): + size_bytes = p.stat().st_size + + # 保存到数据库 + artifact_id = str(uuid.uuid4()) + with get_sync_db() as db: + artifact = TaskArtifact( + id=artifact_id, + task_id=task_id, + name=name, + artifact_type=artifact_type, + content=content, + file_path=file_path, + mime_type=mime_type, + size_bytes=size_bytes, + extra=extra or {}, + ) + db.add(artifact) + db.commit() + + # 发射事件 + import asyncio + + from agentpal.services.task_event_bus import task_event_bus + + asyncio.create_task( + task_event_bus.emit( + task_id, + "task.artifact_created", + {"artifact_id": artifact_id, "name": name, "type": artifact_type}, + f"已创建产出物:{name}", + ) + ) + + return _text_response( + f"[产出物已创建]\n名称:{name}\nID: {artifact_id}\n类型:{artifact_type}\n大小:{size_bytes or 0} 字节" + ) + + except Exception as e: + return _text_response(f"产出物创建失败:{e}") + + + diff --git a/backend/tests/integration/test_runtimes_integration.py b/backend/tests/integration/test_runtimes_integration.py new file mode 100644 index 0000000..3f2bdc9 --- /dev/null +++ b/backend/tests/integration/test_runtimes_integration.py @@ -0,0 +1,165 @@ +"""Agent Runtime 集成测试。 + +验证运行时架构与现有系统的集成: +- InternalSubAgentRuntime 与数据库会话的集成 +- dispatch_sub_agent 使用新运行时架构 +""" + +from __future__ import annotations + +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from agentpal.runtimes.base import ExecutionResult, RuntimeConfig +from agentpal.runtimes.internal import InternalSubAgentRuntime + + +@pytest.mark.asyncio +class TestInternalSubAgentRuntimeIntegration: + """InternalSubAgentRuntime 与数据库集成测试。""" + + @pytest.fixture + async def db_session(self): + """创建异步数据库会话。""" + from agentpal.database import AsyncSessionLocal, init_db + + # 确保数据库已初始化 + await init_db() + + async with AsyncSessionLocal() as db: + yield db + + @pytest.fixture + def runtime_config(self): + """创建运行时配置。""" + return RuntimeConfig( + runtime_type="internal", + model_config={"model": "claude-sonnet-4-5-20250929"}, + max_tool_rounds=5, + timeout_seconds=60.0, + ) + + async def test_runtime_initialization_with_db( + self, + db_session, + runtime_config, + ): + """运行时应该能够用数据库会话正确初始化。""" + from agentpal.models.session import SubAgentTask, TaskStatus + + # 创建一个测试任务 + task = SubAgentTask( + id="test-task-integration", + parent_session_id="test-parent", + sub_session_id="sub:test-parent:integration", + task_prompt="Test task prompt", + status=TaskStatus.PENDING, + agent_name=None, + ) + db_session.add(task) + await db_session.commit() + + # 创建运行时 + runtime = InternalSubAgentRuntime( + session_id=task.sub_session_id, + config=runtime_config, + db=db_session, + parent_session_id=task.parent_session_id, + task=task, + ) + + # 初始化和清理都应该正常工作 + await runtime._initialize() + assert runtime._status.value == "idle" + assert runtime._sub_agent is not None + + await runtime._cleanup() + + async def test_runtime_execute_with_mocked_subagent( + self, + db_session, + runtime_config, + ): + """运行时 execute 方法应与 Mock SubAgent 协同工作。""" + from agentpal.models.session import SubAgentTask, TaskStatus + + task = SubAgentTask( + id="test-execute-task", + parent_session_id="test-parent", + sub_session_id="sub:test-parent:execute", + task_prompt="Analyze this data", + status=TaskStatus.PENDING, + agent_name="researcher", + ) + db_session.add(task) + await db_session.commit() + + runtime = InternalSubAgentRuntime( + session_id=task.sub_session_id, + config=runtime_config, + db=db_session, + parent_session_id=task.parent_session_id, + task=task, + ) + + await runtime._initialize() + + # Mock _execute_core 方法(因为真实的执行需要 API key) + with patch.object(runtime, '_execute_core', new_callable=AsyncMock) as mock_execute: + mock_execute.return_value = ExecutionResult( + success=True, + output="Analysis complete", + metadata={"rounds": 3}, + ) + + result = await runtime.execute("Analyze this data") + + assert result.success is True + assert result.output == "Analysis complete" + assert result.metadata["rounds"] == 3 + + await runtime._cleanup() + + +class TestDispatchSubAgentWithRuntime: + """dispatch_sub_agent 与新运行时架构集成测试。""" + + @pytest.mark.asyncio + async def test_dispatch_uses_runtime_registry(self): + """dispatch_sub_agent 应通过 runtime_registry 获取运行时。""" + from agentpal.tools.builtin import dispatch_sub_agent + from agentpal.runtimes.registry import runtime_registry + + # 验证 internal 运行时已注册 + assert runtime_registry.exists("internal") is True + + # 验证可以通过 get_runtime 获取 internal 运行时 + from agentpal.runtimes.internal import InternalSubAgentRuntime + from agentpal.runtimes.base import RuntimeConfig + + with patch("agentpal.database.AsyncSessionLocal", autospec=True): + config = RuntimeConfig(runtime_type="internal") + runtime = runtime_registry.create( + name="internal", + session_id="test-session", + config=config, + db=MagicMock(), + parent_session_id="parent", + task=MagicMock(), + ) + assert isinstance(runtime, InternalSubAgentRuntime) + + def test_dispatch_signature_accepts_runtime_params(self): + """dispatch_sub_agent 函数签名应接受运行时相关参数。""" + import inspect + from agentpal.tools.builtin import dispatch_sub_agent + + sig = inspect.signature(dispatch_sub_agent) + params = list(sig.parameters.keys()) + + assert "runtime_type" in params + assert "runtime_config" in params + + # 验证默认值 + assert sig.parameters["runtime_type"].default == "internal" + assert sig.parameters["runtime_config"].default is None diff --git a/backend/tests/unit/test_agents/test_runtimes.py b/backend/tests/unit/test_agents/test_runtimes.py new file mode 100644 index 0000000..c8a7aeb --- /dev/null +++ b/backend/tests/unit/test_agents/test_runtimes.py @@ -0,0 +1,389 @@ +"""Agent Runtime 单元测试。 + +覆盖范围: +- BaseAgentRuntime 抽象基类接口 +- RuntimeConfig 和 ExecutionResult 数据类 +- RuntimeStatus 枚举 +- RuntimeRegistry 注册表 +- get_runtime 便捷函数 +""" + +from __future__ import annotations + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from agentpal.runtimes.base import ( + BaseAgentRuntime, + RuntimeConfig, + RuntimeStatus, + ExecutionResult, +) +from agentpal.runtimes.registry import ( + RuntimeDescriptor, + runtime_registry, + get_runtime, + list_available_runtimes, +) + + +# ── BaseAgentRuntime 抽象基类 ───────────────────────────────────── + + +class TestBaseAgentRuntime: + """BaseAgentRuntime 抽象基类接口测试。""" + + def test_abstract_methods_defined(self): + """BaseAgentRuntime 定义了必要的抽象方法。""" + from abc import ABC + + assert issubclass(BaseAgentRuntime, ABC) + + # 验证抽象方法存在 + assert hasattr(BaseAgentRuntime, '_initialize') + assert hasattr(BaseAgentRuntime, '_execute_core') + assert hasattr(BaseAgentRuntime, '_stream_core') + assert hasattr(BaseAgentRuntime, '_cleanup') + assert hasattr(BaseAgentRuntime, '_cancel') + + def test_cannot_instantiate_abstract_base(self): + """不能直接实例化 BaseAgentRuntime。""" + with pytest.raises(TypeError): + BaseAgentRuntime( + session_id="test", + config=RuntimeConfig(runtime_type="test"), + ) + + def test_concrete_class_can_extend(self): + """具体子类可以实现抽象方法。""" + + class ConcreteRuntime(BaseAgentRuntime): + async def _initialize(self) -> None: + self._status = RuntimeStatus.IDLE + + async def _execute_core(self, task_prompt: str, **kwargs): + return ExecutionResult(success=True, output="done") + + async def _stream_core(self, task_prompt: str, **kwargs): + yield {"type": "complete", "data": "done"} + + async def _cleanup(self) -> None: + pass + + async def _cancel(self) -> None: + self._status = RuntimeStatus.IDLE + + config = RuntimeConfig(runtime_type="concrete") + runtime = ConcreteRuntime(session_id="test", config=config) + + assert runtime.session_id == "test" + assert runtime.config.runtime_type == "concrete" + + +# ── RuntimeConfig 数据类 ────────────────────────────────────────── + + +class TestRuntimeConfig: + """RuntimeConfig 数据类测试。""" + + def test_minimal_config(self): + """最小配置。""" + config = RuntimeConfig(runtime_type="internal") + + assert config.runtime_type == "internal" + # model_config 和 extra 默认为空 dict,max_tool_rounds 和 timeout_seconds 有默认值 + assert config.model_config is None or config.model_config == {} + assert isinstance(config.max_tool_rounds, int) # 有默认值 + assert isinstance(config.timeout_seconds, float) # 有默认值 + + def test_full_config(self): + """完整配置。""" + config = RuntimeConfig( + runtime_type="http", + model_config={"model": "claude-sonnet-4-5-20250929"}, + max_tool_rounds=10, + timeout_seconds=600.0, + extra={"base_url": "http://localhost:8000"}, + ) + + assert config.runtime_type == "http" + assert config.model_config["model"] == "claude-sonnet-4-5-20250929" + assert config.max_tool_rounds == 10 + assert config.timeout_seconds == 600.0 + assert config.extra["base_url"] == "http://localhost:8000" + + +# ── ExecutionResult 数据类 ──────────────────────────────────────── + + +class TestExecutionResult: + """ExecutionResult 数据类测试。""" + + def test_success_result(self): + """成功的执行结果。""" + result = ExecutionResult( + success=True, + output="任务完成", + metadata={"elapsed": 1.5}, + ) + + assert result.success is True + assert result.output == "任务完成" + assert result.metadata["elapsed"] == 1.5 + assert result.error is None + + def test_error_result(self): + """失败的执行结果。""" + result = ExecutionResult( + success=False, + error="Something went wrong", + ) + + assert result.success is False + assert result.output == "" + assert result.error == "Something went wrong" + assert result.metadata == {} + + def test_default_values(self): + """默认值。""" + result = ExecutionResult(success=False) + + assert result.success is False + assert result.output == "" + assert result.error is None + assert result.metadata == {} + + +# ── RuntimeStatus 枚举 ──────────────────────────────────────────── + + +class TestRuntimeStatus: + """RuntimeStatus 枚举测试。""" + + def test_status_values(self): + """验证状态枚举值。""" + assert RuntimeStatus.IDLE.value == "idle" + assert RuntimeStatus.RUNNING.value == "running" + assert RuntimeStatus.PAUSED.value == "paused" + assert RuntimeStatus.ERROR.value == "error" + + +# ── RuntimeDescriptor ───────────────────────────────────────────── + + +class TestRuntimeDescriptor: + """RuntimeDescriptor 数据类测试。""" + + def test_descriptor_creation(self): + """RuntimeDescriptor 可以正常创建。""" + mock_class = MagicMock() + desc = RuntimeDescriptor( + name="test-runtime", + runtime_class=mock_class, + description="Test runtime", + ) + + assert desc.name == "test-runtime" + assert desc.runtime_class is mock_class + assert desc.description == "Test runtime" + + +# ── RuntimeRegistry ─────────────────────────────────────────────── + + +class TestRuntimeRegistry: + """运行时注册表测试。""" + + @pytest.fixture + def clean_registry(self): + """每个测试前清理注册表。""" + # 保存原有注册表 + saved = runtime_registry._runtimes.copy() + runtime_registry._runtimes.clear() + yield + # 恢复原有注册表 + runtime_registry._runtimes.clear() + runtime_registry._runtimes.update(saved) + + def test_register_runtime(self, clean_registry): + """register 应成功添加运行时。""" + mock_runtime_class = MagicMock(spec=type) + + runtime_registry.register( + name="mock-runtime", + runtime_class=mock_runtime_class, + description="Mock runtime for testing", + ) + + descriptor = runtime_registry.get("mock-runtime") + assert descriptor is not None + assert descriptor.name == "mock-runtime" + assert descriptor.runtime_class is mock_runtime_class + + def test_unregister_runtime(self, clean_registry): + """unregister 应移除运行时。""" + mock_runtime_class = MagicMock(spec=type) + runtime_registry.register("temp-runtime", mock_runtime_class) + + runtime_registry.unregister("temp-runtime") + + assert runtime_registry.get("temp-runtime") is None + + def test_get_existing_runtime(self, clean_registry): + """get 应返回已注册的运行时描述符。""" + mock_runtime_class = MagicMock(spec=type) + runtime_registry.register("test-runtime", mock_runtime_class) + + result = runtime_registry.get("test-runtime") + + assert result is not None + assert result.name == "test-runtime" + + def test_get_nonexistent_runtime_returns_none(self, clean_registry): + """get 未注册的运行时应返回 None。""" + result = runtime_registry.get("non-existent-runtime") + + assert result is None + + def test_list_runtimes(self, clean_registry): + """list_runtimes 应返回所有已注册的运行时信息。""" + mock_runtime1 = MagicMock(spec=type) + mock_runtime2 = MagicMock(spec=type) + + runtime_registry.register("runtime-a", mock_runtime1, description="First") + runtime_registry.register("runtime-b", mock_runtime2, description="Second") + + result = runtime_registry.list_runtimes() + + assert len(result) == 2 + names = [r["name"] for r in result] + assert "runtime-a" in names + assert "runtime-b" in names + + def test_exists_method(self, clean_registry): + """exists 应检查运行时是否已注册。""" + mock_runtime = MagicMock(spec=type) + runtime_registry.register("check-runtime", mock_runtime) + + assert runtime_registry.exists("check-runtime") is True + assert runtime_registry.exists("unknown") is False + + def test_create_runtime(self, clean_registry): + """create 应创建运行时实例。""" + mock_instance = MagicMock() + mock_runtime_class = MagicMock(return_value=mock_instance) + runtime_registry.register("factory-test", mock_runtime_class) + + config = RuntimeConfig(runtime_type="factory-test") + result = runtime_registry.create( + name="factory-test", + session_id="test-session", + config=config, + ) + + assert result is mock_instance + mock_runtime_class.assert_called_once() + + def test_create_unknown_runtime_raises_value_error(self, clean_registry): + """create 未知运行时应抛出 ValueError。""" + config = RuntimeConfig(runtime_type="unknown") + + with pytest.raises(ValueError, match="Unknown runtime"): + runtime_registry.create( + name="unknown-runtime", + session_id="test-session", + config=config, + ) + + def test_auto_registration_on_import(self): + """导入模块时应自动注册内置运行时。""" + # internal 和 http 已在 registry.py 导入时自动注册 + assert runtime_registry.exists("internal") is True + assert runtime_registry.exists("http") is True + + +# ── get_runtime 便捷函数 ────────────────────────────────────────── + + +class TestGetRuntime: + """get_runtime 便捷函数测试。""" + + def test_get_internal_runtime(self): + """get_runtime('internal') 应返回 InternalSubAgentRuntime 实例。""" + from agentpal.runtimes.internal import InternalSubAgentRuntime + + with patch("agentpal.database.AsyncSessionLocal", autospec=True): + config = RuntimeConfig(runtime_type="internal") + runtime = get_runtime( + runtime_type="internal", + session_id="test-session", + config=config, + ) + + assert isinstance(runtime, InternalSubAgentRuntime) + assert runtime.session_id == "test-session" + + def test_get_unknown_runtime_raises_value_error(self): + """get_runtime 未知运行时应抛出 ValueError。""" + config = RuntimeConfig(runtime_type="unknown") + + with pytest.raises(ValueError, match="Unknown runtime"): + get_runtime( + runtime_type="unknown", + session_id="test-session", + config=config, + ) + + +# ── list_available_runtimes 便捷函数 ────────────────────────────── + + +class TestListAvailableRuntimes: + """list_available_runtimes 便捷函数测试。""" + + def test_returns_list_of_dicts(self): + """应返回字典列表。""" + result = list_available_runtimes() + + assert isinstance(result, list) + # 至少应有 internal 运行时 + names = [r["name"] for r in result] + assert "internal" in names + + +# ── HTTPAgentRuntime 基础测试 ───────────────────────────────────── + + +class TestHTTPAgentRuntimeBasics: + """HTTPAgentRuntime 基础测试。""" + + def test_import_does_not_fail(self): + """HTTPAgentRuntime 应该可以导入。""" + from agentpal.runtimes.http import HTTPAgentRuntime + + assert HTTPAgentRuntime is not None + + def test_class_extends_base(self): + """HTTPAgentRuntime 应继承 BaseAgentRuntime。""" + from agentpal.runtimes.http import HTTPAgentRuntime + + assert issubclass(HTTPAgentRuntime, BaseAgentRuntime) + + +# ── InternalSubAgentRuntime 基础测试 ────────────────────────────── + + +class TestInternalSubAgentRuntimeBasics: + """InternalSubAgentRuntime 基础测试。""" + + def test_import_does_not_fail(self): + """InternalSubAgentRuntime 应该可以导入。""" + from agentpal.runtimes.internal import InternalSubAgentRuntime + + assert InternalSubAgentRuntime is not None + + def test_class_extends_base(self): + """InternalSubAgentRuntime 应继承 BaseAgentRuntime。""" + from agentpal.runtimes.internal import InternalSubAgentRuntime + + assert issubclass(InternalSubAgentRuntime, BaseAgentRuntime) diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 7a9b890..67e3167 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -27,6 +27,7 @@ export interface TaskStatusResponse { error: string | null; agent_name: string | null; task_type: string | null; + task_prompt: string | null; priority: number; retry_count: number; max_retries: number; @@ -64,13 +65,23 @@ export interface SubTaskSummary { id: string; sub_session_id: string; task_prompt: string; - status: "pending" | "running" | "done" | "failed" | "cancelled"; + status: "pending" | "running" | "done" | "failed" | "cancelled" | "input_required"; agent_name: string | null; task_type: string | null; created_at: string; finished_at: string | null; } +export interface TaskArtifact { + id: string; + task_id: string; + artifact_type: string; + content: string; + title: string; + metadata: Record | null; + created_at: string; +} + export interface HistoryMessageMeta { thinking?: string; tool_calls?: Array<{ @@ -160,6 +171,11 @@ export async function getSessionSubTasks(sessionId: string): Promise { + const { data } = await api.get(`/agent/tasks/${taskId}/artifacts`); + return data; +} + // ── Workspace API ───────────────────────────────────────── export interface WorkspaceFileContent { @@ -483,4 +499,11 @@ export async function resolveToolGuard( await api.post(`/agent/tool-guard/${requestId}/resolve`, { approved }); } +// ── Task Cancel API ──────────────────────────────────────── + +export async function cancelTask(taskId: string, reason?: string): Promise<{ task_id: string; status: string; message: string }> { + const { data } = await api.post(`/tasks/${taskId}/cancel`, { reason }); + return data; +} + export default api; diff --git a/frontend/src/components/TaskArtifactViewer.tsx b/frontend/src/components/TaskArtifactViewer.tsx new file mode 100644 index 0000000..531b536 --- /dev/null +++ b/frontend/src/components/TaskArtifactViewer.tsx @@ -0,0 +1,134 @@ +import { useState } from "react"; +import { FileText, Code, FileOutput, ChevronDown, ChevronRight, Copy, Check } from "lucide-react"; +import type { TaskArtifact } from "../api"; + +interface TaskArtifactViewerProps { + artifacts: TaskArtifact[]; +} + +const ARTIFACT_TYPE_CONFIG: Record< + string, + { icon: typeof FileText; color: string; bg: string; label: string } +> = { + code: { icon: Code, color: "text-blue-600", bg: "bg-blue-50", label: "代码" }, + doc: { icon: FileText, color: "text-green-600", bg: "bg-green-50", label: "文档" }, + analysis: { icon: FileOutput, color: "text-purple-600", bg: "bg-purple-50", label: "分析" }, + summary: { icon: FileText, color: "text-orange-600", bg: "bg-orange-50", label: "总结" }, + report: { icon: FileOutput, color: "text-indigo-600", bg: "bg-indigo-50", label: "报告" }, +}; + +function ArtifactIcon({ type }: { type: string }) { + const cfg = ARTIFACT_TYPE_CONFIG[type] ?? ARTIFACT_TYPE_CONFIG.doc; + const Icon = cfg.icon; + return ( + + + + ); +} + +function ArtifactTypeBadge({ type }: { type: string }) { + const cfg = ARTIFACT_TYPE_CONFIG[type] ?? ARTIFACT_TYPE_CONFIG.doc; + return ( + + {cfg.label} + + ); +} + +function ArtifactContent({ artifact }: { artifact: TaskArtifact }) { + const [copied, setCopied] = useState(false); + const [expanded, setExpanded] = useState(true); + + const handleCopy = async () => { + await navigator.clipboard.writeText(artifact.content); + setCopied(true); + setTimeout(() => setCopied(false), 2000); + }; + + const isCode = artifact.artifact_type === "code"; + + return ( + + + setExpanded(!expanded)} + className="flex items-center gap-1 text-xs text-gray-600 hover:text-gray-900" + > + {expanded ? : } + {expanded ? "收起" : "展开"} + + + {copied ? : } + {copied ? "已复制" : "复制"} + + + {expanded && ( + + {artifact.content} + + )} + + ); +} + +export function TaskArtifactViewer({ artifacts }: TaskArtifactViewerProps) { + const [selectedArtifact, setSelectedArtifact] = useState(null); + + if (artifacts.length === 0) { + return ( + + 暂无产出物 + + ); + } + + return ( + + + + 任务产出物 ({artifacts.length}) + + + {artifacts.map((artifact) => ( + setSelectedArtifact(selectedArtifact?.id === artifact.id ? null : artifact)} + > + + + + + + {artifact.title} + + + + + {artifact.content.length} 字符 · {new Date(artifact.created_at).toLocaleString("zh-CN")} + + + {selectedArtifact?.id === artifact.id ? ( + + ) : ( + + )} + + {selectedArtifact?.id === artifact.id && ( + + )} + + ))} + + + ); +} diff --git a/frontend/src/hooks/useTaskArtifacts.ts b/frontend/src/hooks/useTaskArtifacts.ts new file mode 100644 index 0000000..6a949a4 --- /dev/null +++ b/frontend/src/hooks/useTaskArtifacts.ts @@ -0,0 +1,11 @@ +import { useQuery } from "@tanstack/react-query"; +import { getTaskArtifacts, type TaskArtifact } from "../api"; + +export function useTaskArtifacts(taskId: string | null) { + return useQuery({ + queryKey: ["task-artifacts", taskId], + queryFn: () => getTaskArtifacts(taskId!), + enabled: !!taskId, + staleTime: 30_000, + }); +} diff --git a/frontend/src/pages/SessionsPage.tsx b/frontend/src/pages/SessionsPage.tsx index de2fbe8..5695124 100644 --- a/frontend/src/pages/SessionsPage.tsx +++ b/frontend/src/pages/SessionsPage.tsx @@ -4,11 +4,11 @@ import { useQuery, useQueryClient } from "@tanstack/react-query"; import { MessageSquare, Trash2, Search, Cpu, Hash, Loader2, MessagesSquare, ChevronDown, ChevronRight, ExternalLink, - Bot, CheckCircle2, XCircle, AlertCircle, Timer, + Bot, CheckCircle2, XCircle, AlertCircle, Timer, Square, } from "lucide-react"; import clsx from "clsx"; import { useAllSessions } from "../hooks/useSessions"; -import { deleteSession, getSessionSubTasks, type SessionSummary, type SubTaskSummary } from "../api"; +import { deleteSession, getSessionSubTasks, cancelTask, type SessionSummary, type SubTaskSummary } from "../api"; function relativeTime(iso: string): string { const diff = Date.now() - new Date(iso).getTime(); @@ -41,9 +41,10 @@ const STATUS_CONFIG = { cancelled: { icon: AlertCircle, color: "text-gray-400", bg: "bg-gray-50", label: "已取消" }, } as const; -function SubTaskItem({ task }: { task: SubTaskSummary }) { +function SubTaskItem({ task, onCancel }: { task: SubTaskSummary; onCancel?: (taskId: string) => void }) { const cfg = STATUS_CONFIG[task.status as keyof typeof STATUS_CONFIG] ?? STATUS_CONFIG.pending; const Icon = cfg.icon; + const isCancellable = task.status === "running" || task.status === "pending" || task.status === "input_required"; return ( @@ -81,11 +82,22 @@ function SubTaskItem({ task }: { task: SubTaskSummary }) { {task.task_prompt} + + {/* Cancel button */} + {isCancellable && onCancel && ( + onCancel(task.id)} + className="ml-1 p-1 rounded text-gray-400 hover:text-red-500 hover:bg-red-50 transition-colors" + title="取消任务" + > + + + )} ); } -function SubTaskList({ sessionId }: { sessionId: string }) { +function SubTaskList({ sessionId, onCancel }: { sessionId: string; onCancel?: (taskId: string) => void }) { const { data: tasks = [], isLoading } = useQuery({ queryKey: ["session-sub-tasks", sessionId], queryFn: () => getSessionSubTasks(sessionId), @@ -107,7 +119,7 @@ function SubTaskList({ sessionId }: { sessionId: string }) { return ( {tasks.map((t) => ( - + ))} ); @@ -127,9 +139,27 @@ function SessionRow({ deleting: boolean; }) { const [expanded, setExpanded] = useState(false); + const queryClient = useQueryClient(); + const [cancellingId, setCancellingId] = useState(null); const isDingtalk = session.channel === "dingtalk"; const hasSubTasks = session.sub_tasks_count > 0; + const handleCancelTask = async (taskId: string) => { + if (!confirm("确定要取消此任务吗?")) return; + setCancellingId(taskId); + try { + await cancelTask(taskId); + // Refresh session list and sub-tasks + queryClient.invalidateQueries({ queryKey: ["sessions"] }); + queryClient.invalidateQueries({ queryKey: ["session-sub-tasks", session.id] }); + } catch (err) { + console.error("Failed to cancel task:", err); + alert("取消任务失败,请重试"); + } finally { + setCancellingId(null); + } + }; + return ( - + )} diff --git a/frontend/src/pages/TasksPage.tsx b/frontend/src/pages/TasksPage.tsx index 8b1d0dc..6aa0918 100644 --- a/frontend/src/pages/TasksPage.tsx +++ b/frontend/src/pages/TasksPage.tsx @@ -8,8 +8,12 @@ import { ChevronLeft, ChevronRight, Filter, + ChevronDown, + ChevronUp, } from "lucide-react"; import { useTasks } from "../hooks/useTasks"; +import { useTaskArtifacts } from "../hooks/useTaskArtifacts"; +import { TaskArtifactViewer } from "../components/TaskArtifactViewer"; import type { TaskStatusResponse, TaskListParams } from "../api"; const STATUS_CONFIG: Record< @@ -70,6 +74,8 @@ function RetryBadge({ retryCount, maxRetries }: { retryCount: number; maxRetries } function TaskCard({ task }: { task: TaskStatusResponse }) { + const [expanded, setExpanded] = useState(false); + const { data: artifacts } = useTaskArtifacts(expanded ? task.task_id : null); const cfg = STATUS_CONFIG[task.status] ?? STATUS_CONFIG.pending; const Icon = cfg.icon; @@ -115,6 +121,45 @@ function TaskCard({ task }: { task: TaskStatusResponse }) { {task.error && ( {task.error} )} + + {/* 展开/收起按钮 */} + setExpanded(!expanded)} + className="flex items-center gap-1 mt-3 text-xs text-gray-500 hover:text-gray-700" + > + {expanded ? : } + {expanded ? "收起详情" : "查看详情"} + + + {/* 展开内容 */} + {expanded && ( + + {/* 完整 Prompt */} + + 任务描述 + {task.task_prompt} + + + {/* 错误信息(如果有) */} + {task.error && ( + + 错误信息 + + {task.error} + + + )} + + {/* 产出物列表 */} + + {artifacts ? ( + + ) : ( + 加载中... + )} + + + )} ); }
+ {artifact.content} +
{task.error}
+ {task.error} +