Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 206 additions & 2 deletions backend/agentpal/agents/sub_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,18 +366,47 @@ async def _update_status(
result: str | None = None,
error: str | None = None,
) -> None:
from agentpal.services.task_event_bus import task_event_bus

self._task.status = status
if result is not None:
self._task.result = result
if error is not None:
self._task.error = error
if status in (TaskStatus.DONE, TaskStatus.FAILED):

# 设置时间戳
if status == TaskStatus.RUNNING and not self._task.started_at:
self._task.started_at = datetime.now(timezone.utc)
if status in (TaskStatus.DONE, TaskStatus.FAILED, TaskStatus.CANCELLED):
self._task.finished_at = datetime.now(timezone.utc)
self._task.completed_at = self._task.finished_at

try:
await self._db.flush()
except Exception:
pass

# 发射任务状态事件
event_type_map = {
TaskStatus.RUNNING: "task.started",
TaskStatus.DONE: "task.completed",
TaskStatus.FAILED: "task.failed",
TaskStatus.CANCELLED: "task.cancelled",
TaskStatus.INPUT_REQUIRED: "task.input_required",
TaskStatus.PAUSED: "task.paused",
}
if status in event_type_map:
await task_event_bus.emit(
self._task.id,
event_type_map[status],
{
"status": status.value,
"result": result[:500] if result else None,
"error": error[:500] if error else None,
},
f"任务状态变更为 {status.value}",
)

# 任务终态时发布 WebSocket 通知
if status in (TaskStatus.DONE, TaskStatus.FAILED):
try:
Expand Down Expand Up @@ -407,9 +436,184 @@ async def _update_status(
pass # 通知失败不影响主流程

def _log(self, event_type: str, data: dict[str, Any]) -> None:
"""向执行日志追加一条记录。"""
"""向执行日志追加一条记录,并同步发射到 TaskEventBus。"""
import asyncio

from agentpal.services.task_event_bus import task_event_bus

self._execution_log.append({
"type": event_type,
"timestamp": datetime.now(timezone.utc).isoformat(),
**data,
})

# 同步发射到事件总线(用于实时 SSE 推送)
event_type_map = {
"tool_start": "tool.start",
"tool_done": "tool.complete",
"llm_response": "llm.message",
"user_message": "user.message",
}
if event_type in event_type_map:
message_map = {
"tool_start": f"开始执行工具 {data.get('name', '')}",
"tool_done": f"工具 {data.get('name', '')} 执行完成",
"llm_response": "LLM 响应",
"user_message": "用户消息",
}
# 不要在这里 await,避免阻塞主流程
asyncio.create_task(
task_event_bus.emit(
self._task.id,
event_type_map[event_type],
data,
message_map.get(event_type, ""),
)
)

async def _emit_progress(self, pct: int, message: str) -> None:
"""发射进度更新事件。"""
import asyncio

from agentpal.services.task_event_bus import task_event_bus

self._task.progress_pct = pct
self._task.progress_message = message

try:
await self._db.flush()
except Exception:
pass

asyncio.create_task(
task_event_bus.emit(
self._task.id,
"task.progress",
{"pct": pct, "message": message},
message,
)
)

async def produce_artifact(
self,
artifact_type: str,
content: str,
title: str | None = None,
extra: dict[str, Any] | None = None,
) -> str:
"""生成任务产出物并保存到数据库。

Args:
artifact_type: 产出物类型(如 "code", "doc", "analysis", "summary")
content: 产出物内容(文本或 JSON)
title: 人类可读标题
extra: 额外元数据

Returns:
产出物 ID
"""
from agentpal.models.session import TaskArtifact

artifact_id = f"art_{uuid.uuid4().hex[:12]}"
artifact = TaskArtifact(
id=artifact_id,
task_id=self._task.id,
artifact_type=artifact_type,
content=content,
title=title or f"{artifact_type}_{artifact_id[:8]}",
extra=extra or {},
)
self._db.add(artifact)
await self._db.flush()

# 发射事件
from agentpal.services.task_event_bus import task_event_bus

asyncio.create_task(
task_event_bus.emit(
self._task.id,
"artifact.created",
{
"artifact_id": artifact_id,
"artifact_type": artifact_type,
"title": artifact.title,
},
f"生成产出物:{artifact.title}",
)
)

return artifact_id

async def request_user_input(
self,
question: str,
context: str | None = None,
) -> str:
"""请求用户输入,将任务状态设置为 INPUT_REQUIRED 并等待。

Args:
question: 向用户提出的问题
context: 可选的上下文信息

Returns:
用户提供的输入内容
"""
# 保存当前问题到 meta
if self._task.meta is None:
self._task.meta = {}
self._task.meta["input_request"] = {
"question": question,
"context": context,
"requested_at": datetime.now(timezone.utc).isoformat(),
}

# 记录到执行日志
self._log("input_requested", {"question": question, "context": context})

# 更新任务状态为 INPUT_REQUIRED
await self._update_status(TaskStatus.INPUT_REQUIRED)

# 等待用户输入(轮询检查 meta)
while True:
await asyncio.sleep(2) # 每 2 秒检查一次
# 刷新任务数据
await self._db.refresh(self._task)
if self._task.meta and "user_input" in self._task.meta:
user_input = self._task.meta["user_input"]
self._log("input_received", {"input": user_input[:500]})
# 清除输入请求
del self._task.meta["input_request"]
await self._db.flush()
return user_input
# 检查任务是否被恢复执行
if self._task.status == TaskStatus.PENDING:
# 任务已恢复,获取用户输入
user_input = self._task.meta.get("user_input", "")
self._log("input_received", {"input": user_input[:500]})
return user_input
# 检查任务是否被取消
if self._task.status == TaskStatus.CANCELLED:
self._log("task_cancelled", {"reason": self._task.meta.get("cancel_reason", "用户取消")})
raise asyncio.CancelledError("任务已被取消")

async def cancel(self, reason: str = "用户取消") -> None:
"""取消正在运行的任务。

Args:
reason: 取消原因
"""
self._log("cancelling", {"reason": reason})
await self._update_status(TaskStatus.CANCELLED)

# 保存取消原因
if self._task.meta is None:
self._task.meta = {}
self._task.meta["cancel_reason"] = reason
self._task.meta["cancelled_at"] = datetime.now(timezone.utc).isoformat()

try:
await self._db.flush()
except Exception:
pass

logger.info(f"SubAgent task {self._task.id} cancelled: {reason}")
91 changes: 69 additions & 22 deletions backend/agentpal/api/v1/endpoints/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class DispatchRequest(BaseModel):
agent_name: str | None = None
priority: int = Field(default=5, ge=1, le=10, description="任务优先级 1-10(10 最高)")
max_retries: int = Field(default=3, ge=0, le=10, description="最大重试次数 0-10")
blocking: bool = Field(default=False, description="是否阻塞等待任务完成")
wait_seconds: int = Field(default=120, ge=0, description="阻塞模式下的最大等待时间(秒)")


class TaskStatusResponse(BaseModel):
Expand Down Expand Up @@ -149,31 +151,76 @@ async def event_stream() -> AsyncGenerator[str, None]:

@router.post("/dispatch", response_model=TaskStatusResponse)
async def dispatch_sub_agent(req: DispatchRequest, db: AsyncSession = Depends(get_db)):
"""派遣 SubAgent 异步执行任务。"""
"""派遣 SubAgent 异步执行任务。

支持两种模式:
- 非阻塞模式(blocking=False):立即返回任务 ID 和初始状态
- 阻塞模式(blocking=True):等待任务完成后返回最终结果
"""
settings = get_settings()
memory = MemoryFactory.create(settings.memory_backend, db=db)
assistant = PersonalAssistant(session_id=req.parent_session_id, memory=memory, db=db)
task = await assistant.dispatch_sub_agent(
task_prompt=req.task_prompt,
db=db,
context=req.context,
task_type=req.task_type,
agent_name=req.agent_name,
priority=req.priority,
max_retries=req.max_retries,
)
return TaskStatusResponse(
task_id=task.id,
status=task.status,
result=task.result,
error=task.error,
agent_name=task.agent_name,
task_type=task.task_type,
priority=task.priority,
retry_count=task.retry_count,
max_retries=task.max_retries,
created_at=utc_isoformat(task.created_at),
)

if req.blocking:
# 阻塞模式:使用工具层面的 dispatch_sub_agent
from agentpal.tools.builtin import dispatch_sub_agent as builtin_dispatch

result_text = await builtin_dispatch(
task_prompt=req.task_prompt,
parent_session_id=req.parent_session_id,
task_type=req.task_type or "",
agent_name=req.agent_name or "",
wait_seconds=req.wait_seconds,
blocking=True,
)

# 从数据库获取最新的 task 记录
from sqlalchemy import select
from agentpal.models.agent import SubAgentTask

result = await db.execute(
select(SubAgentTask)
.where(SubAgentTask.parent_session_id == req.parent_session_id)
.order_by(SubAgentTask.created_at.desc())
.limit(1)
)
task = result.scalars().first()

return TaskStatusResponse(
task_id=task.id if task else "unknown",
status=task.status.value if task else "unknown",
result=result_text,
error=getattr(task, "error", None),
agent_name=getattr(task, "agent_name", None),
task_type=getattr(task, "task_type", None),
priority=getattr(task, "priority", 5),
retry_count=getattr(task, "retry_count", 0),
max_retries=getattr(task, "max_retries", 3),
created_at=utc_isoformat(task.created_at) if task else None,
)
else:
# 非阻塞模式:立即返回
task = await assistant.dispatch_sub_agent(
task_prompt=req.task_prompt,
db=db,
context=req.context,
task_type=req.task_type,
agent_name=req.agent_name,
priority=req.priority,
max_retries=req.max_retries,
)
return TaskStatusResponse(
task_id=task.id,
status=task.status,
result=task.result,
error=task.error,
agent_name=task.agent_name,
task_type=task.task_type,
priority=task.priority,
retry_count=task.retry_count,
max_retries=task.max_retries,
created_at=utc_isoformat(task.created_at),
)


@router.get("/tasks/{task_id}", response_model=TaskStatusResponse)
Expand Down
Loading