Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions docs/architecture/a2a-subagents.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# A2A Subagents

Kagent allows users to add subagents (other agents running on Kagent or remotely) as tools to a main agent, connected via the A2A protocol. This feature is enabled by `KAgentRemoteA2ATool` (`python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py`), kagent's custom replacement for the upstream `AgentTool(RemoteA2aAgent(...))` pairing.

It directly manages the A2A conversation with a remote subagent and adds three things the upstream lacks: HITL propagation, live activity viewing, and user ID forwarding.

See [human-in-the-loop.md](human-in-the-loop.md) for HITL details.

---

## How it works

Each parent A2A request creates a fresh `Runner` and fresh tool instances. `KAgentRemoteA2ATool.__init__` generates a UUID (`_last_context_id`) that is used as the A2A `context_id` for every message sent to the subagent. On the subagent side, this `context_id` becomes the session ID.

`run_async` has two phases:

- **Phase 1** (normal call): sends the request to the subagent and handles the response — returning the result, pausing for HITL if the subagent returns `input_required`, or returning an error string.
- **Phase 2** (HITL resume): reads the stored `task_id`/`context_id` from `tool_context.tool_confirmation.payload` and forwards the user's decision (approve / reject / batch / ask-user answers) to the subagent's pending task.

On success, `run_async` returns:
```python
{"result": str, "subagent_session_id": str} # normal
{"result": str, "subagent_session_id": str,
"kagent_usage_metadata": dict} # with usage
{"status": "pending", "waiting_for": "subagent_approval", ...} # HITL pause
```

`KAgentRemoteA2AToolset` is a thin `BaseToolset` wrapper whose only job is ensuring the owned `httpx.AsyncClient` is closed when the runner shuts down — ADK's cleanup path only discovers `BaseToolset` instances, not bare `BaseTool` instances.

---

## User ID and session tagging

`_SubagentInterceptor` is registered on the A2A client at construction time and injects two headers on every outgoing request:

| Header | Value | Purpose |
|---|---|---|
| `x-user-id` | parent session's user ID | Scopes the subagent DB session to the same user |
| `x-kagent-source` | `"subagent"` | Hides the session from the agent's session history sidebar |

> Interceptors must be passed to `ClientFactory.create(interceptors=[...])` — `A2AClient.add_request_middleware()` appends to a list that the transport never reads.

On the subagent side, `KAgentRequestContextBuilder` reads these headers and passes them through to `_prepare_session`, which calls `KAgentSessionService.create_session()` with `source="subagent"`. The Go layer stores this in a `Source` column and excludes such sessions from `ListSessionsForAgent`.

---

## Live activity viewing

The UI can show what a subagent is doing in a live panel before it finishes. This works because the session ID is known before the tool runs:

Before the run loop, `A2aAgentExecutor` builds a `{tool_name → session_id}` map from all tools implementing the `SubagentSessionProvider` protocol (`subagent_session_id` property). The event converter stamps this as `kagent_subagent_session_id` metadata on each `function_call` DataPart as soon as the LLM emits the call. The UI reads it immediately and begins polling `/api/sessions/{id}` every 2 seconds, rendering the subagent's events as a nested chat thread. Nesting is capped at depth 3.

The map is keyed by tool name because within one parent request, all calls to the same subagent tool intentionally share one `context_id` — giving the subagent conversation continuity across sequential invocations. A fresh `context_id` is generated on the next parent request when the runner rebuilds.

When sending session requests to Go backend, take note that:

| Session query | Includes subagent sessions? |
|---|---|
| `GET /api/sessions/agent/{ns}/{name}` | No — filtered by `source != 'subagent'` |
| `GET /api/sessions/{id}` | Yes |
| `GET /api/sessions/{id}/tasks` | Yes |
3 changes: 3 additions & 0 deletions go/api/database/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ type Session struct {
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"`

AgentID *string `gorm:"index" json:"agent_id"`
// Source indicates how this session was created.
// nil or empty = user-initiated, "subagent" = created by a parent agent's A2A call.
Source *string `gorm:"index" json:"source,omitempty"`
}

type Task struct {
Expand Down
2 changes: 2 additions & 0 deletions go/api/httpapi/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ type SessionRequest struct {
AgentRef *string `json:"agent_ref,omitempty"`
Name *string `json:"name,omitempty"`
ID *string `json:"id,omitempty"`
// Source indicates how this session was created (e.g. "subagent").
Source *string `json:"source,omitempty"`
}

// Run types
Expand Down
13 changes: 10 additions & 3 deletions go/core/internal/database/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,16 @@ func (c *clientImpl) ListTasksForSession(ctx context.Context, sessionID string)
}

func (c *clientImpl) ListSessionsForAgent(ctx context.Context, agentID string, userID string) ([]dbpkg.Session, error) {
return list[dbpkg.Session](c.db.WithContext(ctx),
Clause{Key: "agent_id", Value: agentID},
Clause{Key: "user_id", Value: userID})
var sessions []dbpkg.Session
err := c.db.WithContext(ctx).
Where("agent_id = ? AND user_id = ?", agentID, userID).
Where("source IS NULL OR source != ?", "subagent").
Order("created_at ASC").
Find(&sessions).Error
if err != nil {
return nil, fmt.Errorf("failed to list sessions for agent: %w", err)
}
return sessions, nil
}

// ListSessions lists all sessions for a user
Expand Down
6 changes: 5 additions & 1 deletion go/core/internal/database/fake/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,14 +336,18 @@ func (c *InMemoryFakeClient) ListSessions(_ context.Context, userID string) ([]d
return result, nil
}

// ListSessionsForAgent lists all sessions for an agent
// ListSessionsForAgent lists all sessions for an agent, excluding subagent sessions.
func (c *InMemoryFakeClient) ListSessionsForAgent(_ context.Context, agentID string, userID string) ([]database.Session, error) {
c.mu.RLock()
defer c.mu.RUnlock()

var result []database.Session
for _, session := range c.sessions {
if session.AgentID != nil && *session.AgentID == agentID && session.UserID == userID {
// Exclude subagent sessions from the listing
if session.Source != nil && *session.Source == "subagent" {
continue
}
result = append(result, *session)
}
}
Expand Down
1 change: 1 addition & 0 deletions go/core/internal/httpserver/handlers/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ func (h *SessionsHandler) HandleCreateSession(w ErrorResponseWriter, r *http.Req
Name: sessionRequest.Name,
UserID: userID,
AgentID: &agent.ID,
Source: sessionRequest.Source,
}

log.V(1).Info("Creating session in database",
Expand Down
24 changes: 22 additions & 2 deletions python/packages/kagent-adk/src/kagent/adk/_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
)

from ._mcp_toolset import is_anyio_cross_task_cancel_scope_error
from ._remote_a2a_tool import SubagentSessionProvider
from .converters.event_converter import convert_event_to_a2a_events, serialize_metadata_value
from .converters.part_converter import convert_a2a_part_to_genai_part, convert_genai_part_to_a2a_part
from .converters.request_converter import convert_a2a_request_to_adk_run_args
Expand Down Expand Up @@ -587,6 +588,13 @@ async def _handle_request(
real_invocation_id: str | None = None
last_usage_metadata = None

# Build a mapping of tool name -> subagent session ID once so the
# event converter can stamp it onto function_call DataParts.
subagent_session_ids: dict[str, str] = {}
for tool in getattr(runner.agent, "tools", None) or []:
if isinstance(tool, SubagentSessionProvider) and tool.subagent_session_id:
subagent_session_ids[tool.name] = tool.subagent_session_id

task_result_aggregator = TaskResultAggregator()
async with Aclosing(runner.run_async(**run_args)) as agen:
async for adk_event in agen:
Expand All @@ -603,7 +611,11 @@ async def _handle_request(
last_usage_metadata = adk_event.usage_metadata

for a2a_event in convert_event_to_a2a_events(
adk_event, invocation_context, context.task_id, context.context_id
adk_event,
invocation_context,
context.task_id,
context.context_id,
subagent_session_ids=subagent_session_ids or None,
):
Comment on lines 591 to 619
Copy link
Contributor Author

@supreme-gg-gg supreme-gg-gg Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds like an edge case, but when combined with the behaviour of the above comment, it's actually an intended feature. So the behaviour overall is:

User -> Main agent -> Subagent (multiple / parallel / sequential) -> they go into the same session for that subagent, the UI shows the same session -> each subagent return a response (existing behaviour) -> main agent

Now when user does another invocation of main agent:

User -> Main agent -> Runner gets recreated (existing behaviour) -> KagentRemoteA2ATool gets reset + session ID reset -> subagent (new session) -> ...

Because within each invocation the session ID for the same subagent is the same, therefore stamping the data part with the same ID mapped by name is not a problem. 😄

This is intended because when the main agent invokes the subagent sequentially, it's because it want to continue the conversation and the flow here is designed to support this.

# Only aggregate non-partial events to avoid duplicates from streaming chunks
# Partial events are sent to frontend for display but not accumulated
Expand Down Expand Up @@ -691,10 +703,18 @@ async def _prepare_session(self, context: RequestContext, run_args: dict[str, An
session_name = text[:20] + ("..." if len(text) > 20 else "")
break

state: dict[str, Any] = {"session_name": session_name}
# Propagate source (e.g. "subagent") so the session is tagged in the DB.
source = None
if context.call_context and context.call_context.state:
source = context.call_context.state.get("kagent_source")
if source:
state["source"] = source

session = await runner.session_service.create_session(
app_name=runner.app_name,
user_id=user_id,
state={"session_name": session_name},
state=state,
session_id=session_id,
)

Expand Down
89 changes: 71 additions & 18 deletions python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import logging
import uuid
from typing import Any, Optional
from typing import Any, Optional, Protocol, runtime_checkable
from urllib.parse import urlparse

import httpx
Expand All @@ -22,6 +22,7 @@
from a2a.client.client import ClientConfig as A2AClientConfig
from a2a.client.client_factory import ClientFactory as A2AClientFactory
from a2a.client.errors import A2AClientHTTPError
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.types import (
AgentCard,
DataPart,
Expand Down Expand Up @@ -55,6 +56,27 @@

logger = logging.getLogger("kagent_adk." + __name__)

_USER_ID_CONTEXT_KEY = "x-user-id"
_SOURCE_HEADER = "x-kagent-source"
_SOURCE_SUBAGENT = "subagent"


class _SubagentInterceptor(ClientCallInterceptor):
"""
Injects the authenticated user's ID as an ``x-user-id`` HTTP header and
marks the request as originating from a subagent call via
``x-kagent-source: subagent`` on every outgoing A2A request.
"""

async def intercept(self, method_name, request_payload, http_kwargs, agent_card, context):
headers = dict(http_kwargs.get("headers", {}))
# Always mark requests from a parent agent tool as subagent-originated
headers[_SOURCE_HEADER] = _SOURCE_SUBAGENT
if context and _USER_ID_CONTEXT_KEY in context.state:
headers["x-user-id"] = context.state[_USER_ID_CONTEXT_KEY]
http_kwargs["headers"] = headers
return request_payload, http_kwargs


def _extract_text_from_task(task: Task) -> str:
"""Extract text content from a completed task's artifacts or status message."""
Expand Down Expand Up @@ -98,6 +120,17 @@ def _extract_usage_from_task(task: Task) -> Optional[dict]:
return None


@runtime_checkable
class SubagentSessionProvider(Protocol):
"""Protocol for tools that delegate to a subagent and can expose
the subagent's session ID for live activity polling."""

name: str

@property
def subagent_session_id(self) -> str | None: ...


class KAgentRemoteA2ATool(BaseTool):
"""A tool that calls a remote A2A agent and propagates HITL state."""

Expand All @@ -114,8 +147,13 @@ def __init__(
self._httpx_client = httpx_client
self._a2a_client: Optional[A2AClient] = None
self._agent_card: Optional[AgentCard] = None
# Track the context_id from the remote agent for session continuity
self._last_context_id: Optional[str] = None
# Pre-generate context_id for UI session polling
self._last_context_id: str = str(uuid.uuid4())

@property
def subagent_session_id(self) -> str | None:
"""The subagent's session ID (== context_id sent in the A2A message)."""
return self._last_context_id
Comment on lines +150 to +156
Copy link
Contributor Author

@supreme-gg-gg supreme-gg-gg Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tool instances are created fresh per request and destroyed after, so this _last_context_id UUID is effectively per-request already. For more info see the following comment, these three comments are all related to each other.


async def _ensure_client(self) -> A2AClient:
"""Lazily resolve the agent card and initialize the A2A client."""
Expand All @@ -141,15 +179,18 @@ async def _ensure_client(self) -> A2AClient:
if not self.description and self._agent_card.description:
self.description = self._agent_card.description

# Create the A2A client
# Create the A2A client.
config = A2AClientConfig(
httpx_client=self._httpx_client,
streaming=False,
polling=False,
supported_transports=[A2ATransport.jsonrpc],
)
factory = A2AClientFactory(config=config)
self._a2a_client = factory.create(self._agent_card)
self._a2a_client = factory.create(
self._agent_card,
interceptors=[_SubagentInterceptor()],
)
return self._a2a_client

def _get_declaration(self) -> genai_types.FunctionDeclaration:
Expand Down Expand Up @@ -197,16 +238,17 @@ async def _handle_first_call(self, args: dict[str, Any], tool_context: ToolConte
context_id=self._last_context_id,
)

# Forward the authenticated user ID so the subagent session is scoped
# to the same user as the parent agent session.
call_context = ClientCallContext(state={_USER_ID_CONTEXT_KEY: tool_context.session.user_id})

task: Optional[Task] = None
try:
async for response in client.send_message(request=message):
async for response in client.send_message(request=message, context=call_context):
if isinstance(response, tuple):
# ClientEvent: (Task, UpdateEvent | None)
task = response[0]
elif isinstance(response, A2AMessage):
# Direct message response (no task management)
if response.context_id:
self._last_context_id = response.context_id
return self._extract_text_from_message(response)
except A2AClientHTTPError as e:
return f"Remote agent '{self.name}' request failed: {e}"
Expand All @@ -217,10 +259,6 @@ async def _handle_first_call(self, args: dict[str, Any], tool_context: ToolConte
if task is None:
return f"Remote agent '{self.name}' returned no result."

# Track context_id for future requests to the same remote agent
if task.context_id:
self._last_context_id = task.context_id

state = task.status.state if task.status else None

if state == TaskState.input_required:
Expand All @@ -235,8 +273,8 @@ async def _handle_first_call(self, args: dict[str, Any], tool_context: ToolConte
result_text = _extract_text_from_task(task)
usage = _extract_usage_from_task(task)
if usage:
return {"result": result_text, "kagent_usage_metadata": usage}
return result_text or ""
return {"result": result_text, "kagent_usage_metadata": usage, "subagent_session_id": self._last_context_id}
return {"result": result_text or "", "subagent_session_id": self._last_context_id}

def _handle_input_required(self, task: Task, tool_context: ToolContext) -> dict[str, Any]:
"""Handle a subagent that returned input_required (HITL).
Expand Down Expand Up @@ -344,9 +382,10 @@ async def _handle_resume(self, tool_context: ToolContext) -> Any:
)

client = await self._ensure_client()
call_context = ClientCallContext(state={_USER_ID_CONTEXT_KEY: tool_context.session.user_id})
task: Optional[Task] = None
try:
async for response in client.send_message(request=decision_message):
async for response in client.send_message(request=decision_message, context=call_context):
if isinstance(response, tuple):
task = response[0]
elif isinstance(response, A2AMessage):
Expand Down Expand Up @@ -374,8 +413,13 @@ async def _handle_resume(self, tool_context: ToolContext) -> Any:
result_text = _extract_text_from_task(task)
usage = _extract_usage_from_task(task)
if usage:
return {"result": result_text, "kagent_usage_metadata": usage}
return result_text or ""
return {
"result": result_text,
"kagent_usage_metadata": usage,
"subagent_session_id": context_id or self._last_context_id,
}
# context_id from the confirmation payload is the original subagent session ID in case of interrupts
return {"result": result_text, "subagent_session_id": context_id or self._last_context_id}

@staticmethod
def _extract_text_from_message(message: A2AMessage) -> str:
Expand Down Expand Up @@ -416,6 +460,15 @@ def __init__(
httpx_client=httpx_client,
)

@property
def name(self) -> str:
return self._tool.name

@property
def subagent_session_id(self) -> str | None:
"""The subagent's session ID (== context_id sent in the A2A message)."""
return self._tool.subagent_session_id

async def get_tools(self, readonly_context: Optional[ReadonlyContext] = None) -> list[BaseTool]:
return [self._tool]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ async def create_session(
request_data["id"] = session_id
if state and state.get("session_name"):
request_data["name"] = state.get("session_name", "")
if state and state.get("source"):
request_data["source"] = state.get("source", "")

# Make API call to create session
response = await self.client.post(
Expand Down
Loading
Loading