feat: Add InterruptService for human-in-the-loop graph workflows #4585
drahnreb wants to merge 4 commits into google:main
Conversation
Summary of Changes
Hello @drahnreb, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances GraphAgent's capabilities by introducing a robust interrupt system, enabling human-in-the-loop workflows. It addresses the need for structured pausing, human decision-making, and dynamic resumption within complex agentic graphs. The changes provide greater control and flexibility for managing long-running or critical AI workflows, ensuring human oversight where necessary.
Highlights
Activity
Code Review
The pull request introduces a robust InterruptService for human-in-the-loop workflows within GraphAgent. The implementation is comprehensive, covering per-session isolation, bounded queues, and LLM-based reasoning for interrupt actions. The integration with the existing execution loop and telemetry is well-handled. My feedback focuses on improving the efficiency of queue peeking, enhancing the robustness of LLM response parsing, and refining the state clearing logic during 'go back' actions.
def list_queued_messages(
    self, session_id: str, page: int = 1, page_size: int = 50
) -> List[InterruptMessage]:
  """List queued messages with ADK-style pagination.

  Provides non-destructive peek at queue for observability and debugging.
  Messages remain in queue for processing by GraphAgent.

  Safe in asyncio's single-threaded cooperative context: drain then requeue
  is atomic from the perspective of all other coroutines (no await between
  drain and requeue means no other coroutine can run between them).

  Args:
    session_id: Session identifier
    page: Page number (1-indexed, default: 1)
    page_size: Messages per page (default: 50, max: 1000)

  Returns:
    List of InterruptMessage objects (may be empty)
  """
  # Validate pagination bounds (ADK pattern)
  if page < 1:
    page = 1
  if page_size < 1 or page_size > 1000:
    page_size = 50

  if session_id not in self._message_queues:
    return []

  # Drain and requeue: safe in asyncio single-threaded context because
  # there is no await between the drain and the requeue, so no other
  # coroutine can interleave and observe the empty state.
  queue = self._message_queues[session_id]
  messages: List[InterruptMessage] = []
  while not queue.empty():
    try:
      messages.append(queue.get_nowait())
    except asyncio.QueueEmpty:
      break
  for msg in messages:
    try:
      queue.put_nowait(msg)
    except asyncio.QueueFull:
      break

  # Paginate
  offset = (page - 1) * page_size
  return messages[offset : offset + page_size]
The current implementation of list_queued_messages peeks at the queue by draining and then requeueing all messages. While this is task-safe in a single-threaded asyncio loop, it is inefficient for larger queues. A more direct way to peek at an asyncio.Queue without modifying it is to access its internal _queue attribute, which is a collections.deque.
def list_queued_messages(
self, session_id: str, page: int = 1, page_size: int = 50
) -> List[InterruptMessage]:
if page < 1:
page = 1
if page_size < 1 or page_size > 1000:
page_size = 50
if session_id not in self._message_queues:
return []
# Access internal deque for non-destructive peek
queue = self._message_queues[session_id]
messages = list(queue._queue)
offset = (page - 1) * page_size
    return messages[offset : offset + page_size]
Fixed. list_queued_messages now uses direct access to asyncio.Queue's internal _queue deque for non-destructive reads. No drain/requeue needed — avoids QueueFull risk entirely. See interrupt_service.py docstring update.
| if "```" in text: | ||
| parts = text.split("```") | ||
| if len(parts) >= 3: | ||
| inner = parts[1] | ||
| if inner.startswith("json"): | ||
| inner = inner[4:] | ||
| text = inner.strip() | ||
| elif len(parts) == 2: | ||
| inner = parts[1] | ||
| if inner.startswith("json"): | ||
| inner = inner[4:] | ||
| text = inner.strip() | ||
|
|
The markdown stripping logic assumes that the JSON block is always the first code block in the response. If the LLM provides multiple blocks or preamble text with backticks, this might fail. It's safer to use a regex to find the first JSON-like block or iterate through all code blocks.
| if "```" in text: | |
| parts = text.split("```") | |
| if len(parts) >= 3: | |
| inner = parts[1] | |
| if inner.startswith("json"): | |
| inner = inner[4:] | |
| text = inner.strip() | |
| elif len(parts) == 2: | |
| inner = parts[1] | |
| if inner.startswith("json"): | |
| inner = inner[4:] | |
| text = inner.strip() | |
| if "```" in text: | |
| import re | |
| # Find all blocks wrapped in triple backticks | |
| matches = re.findall(r"```(?:json)?\s*([\s\S]*?)\s*```", text) | |
| for match in matches: | |
| try: | |
| decision = json.loads(match) | |
| return self._validate_decision(decision) | |
| except json.JSONDecodeError: | |
| continue |
Fixed: using output_schema=InterruptDecision (a Pydantic model) ensures structured JSON output, so no markdown stripping is needed.
for node_name in nodes_to_clear:
  state.data.pop(node_name, None)
The go_back logic assumes that the key in state.data is exactly the node name. While this is true for default output mappers, users might define custom mappers that write to different keys. This logic will fail to clear state for those nodes. Consider documenting this limitation or providing a way for nodes to specify which state keys they 'own'.
Documented: go_back clears state keys matching exact node names (the default output_mapper convention). Custom output_mappers that write to different keys are the user's responsibility.
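A tiny sketch of the documented limitation (all names here are hypothetical, not the ADK API): clearing by node name removes the default-mapped key but leaves a custom-mapped key behind.

```python
# State written by two mappers for the same node "summarize":
state = {
    "summarize": "draft v1",      # default mapper: key == node name
    "final_report": "report v1",  # custom mapper writing a different key
}


def go_back_clear(state: dict, nodes_to_clear: list[str]) -> None:
    """Mirror of the PR's go_back logic: pop keys matching node names."""
    for node_name in nodes_to_clear:
        state.pop(node_name, None)


go_back_clear(state, ["summarize"])
print(state)  # {'final_report': 'report v1'} -- custom-mapped key survives
```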
def _validate_condition_ast(node: ast.AST) -> None:
  """Walk AST and reject any unsafe node types.

  Only allows: comparisons, boolean ops, unary not, attribute access,
  safe method calls (.get, .get_parsed, .get_str, .get_dict),
  constants, and whitelisted names.

  Raises:
    ValueError: If an unsafe AST node is encountered.
  """
  if isinstance(node, ast.Expression):
    _validate_condition_ast(node.body)
  elif isinstance(node, ast.BoolOp):
    for value in node.values:
      _validate_condition_ast(value)
  elif isinstance(node, ast.UnaryOp):
    if not isinstance(node.op, ast.Not):
      raise ValueError(f"Unsafe unary operator: {type(node.op).__name__}")
    _validate_condition_ast(node.operand)
  elif isinstance(node, ast.Compare):
    _validate_condition_ast(node.left)
    for comparator in node.comparators:
      _validate_condition_ast(comparator)
  elif isinstance(node, ast.Call):
    if isinstance(node.func, ast.Attribute):
      if node.func.attr not in _SAFE_METHODS:
        raise ValueError(f"Unsafe method call: .{node.func.attr}()")
      _validate_condition_ast(node.func.value)
    else:
      raise ValueError(f"Unsafe call: {ast.dump(node.func)}")
    for arg in node.args:
      _validate_condition_ast(arg)
    for kw in node.keywords:
      _validate_condition_ast(kw.value)
  elif isinstance(node, ast.Attribute):
    # Allow attribute access on safe names only
    _validate_condition_ast(node.value)
  elif isinstance(node, ast.Subscript):
    _validate_condition_ast(node.value)
    _validate_condition_ast(node.slice)
  elif isinstance(node, ast.Name):
    if node.id not in _SAFE_NAMES:
      raise ValueError(f"Unsafe name: '{node.id}'")
  elif isinstance(node, ast.Constant):
    pass  # string, int, float, bool, None literals are safe
  else:
    raise ValueError(f"Unsafe expression node: {type(node).__name__}")
The AST validation for conditions is a good security measure. However, it currently blocks common utility functions like len(). If a workflow needs to route based on the size of a list in state (e.g., len(data.get('items')) > 0), it will fail. Consider adding len to a whitelist of safe built-in functions.
Already done: len() is in _SAFE_BUILTINS whitelist.
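For illustration, here is a much-simplified, self-contained sketch of the whitelist idea (not the PR's actual validator, which does a stricter recursive walk): reject any name outside a whitelist and any private-attribute access, while allowing calls to a small set of builtins such as len.

```python
import ast

_SAFE_BUILTINS = {"len"}
_SAFE_NAMES = {"state"} | _SAFE_BUILTINS


def is_safe_condition(expr: str) -> bool:
    """Return True only if every name in the expression is whitelisted."""
    try:
        tree = ast.parse(expr, mode="eval")
    except SyntaxError:
        return False
    for node in ast.walk(tree):
        if isinstance(node, ast.Name) and node.id not in _SAFE_NAMES:
            return False
        if isinstance(node, ast.Attribute) and node.attr.startswith("_"):
            return False  # block dunder/private access like __class__
    return True


print(is_safe_condition("len(state.get('items')) > 0"))    # True
print(is_safe_condition("__import__('os').system('ls')"))  # False
```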
Force-pushed 75732b6 to d95fe68.
Code Review
This is an extensive and well-architected feature addition that introduces powerful graph-based workflow orchestration capabilities to the ADK. The new GraphAgent and its related components like InterruptService, GraphNode, and patterns for dynamic and nested execution are thoughtfully designed. The code is well-structured, breaking down complex logic into manageable mixins and helper modules. The inclusion of comprehensive samples, tests, and telemetry instrumentation is commendable and demonstrates a high level of quality. My review focuses on a recurring subtle bug in a few of the new sample files related to session state handling, which could be misleading for users. Overall, this is an excellent contribution.
fresh_session = await session_service.get_session(
    app_name="dynamic_queue_demo", user_id="demo_user", session_id=session.id
)
final_session = fresh_session or session
final_data = final_session.state.get("graph_data", {})
final_state = GraphState(data=final_data) if final_data else GraphState()
The fallback to the original session (fresh_session or session) is problematic. The session object holds a stale, deep-copied state from when it was created via create_session. If get_session fails for some reason and returns None, this code will silently proceed to print incorrect final statistics from the old state, which is misleading for anyone running this sample. It's better to handle the None case explicitly to avoid this subtle bug.
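A minimal, self-contained illustration of the staleness described above: a deep copy taken at create time never reflects later mutations of the live state.

```python
from copy import deepcopy

live_state = {"graph_data": {"counter": 0}}
session_snapshot = deepcopy(live_state)  # what create_session hands back

live_state["graph_data"]["counter"] = 42  # graph execution updates state

print(session_snapshot["graph_data"]["counter"])  # 0 -- stale copy
print(live_state["graph_data"]["counter"])        # 42 -- live state
```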
final_session = await session_service.get_session(
    app_name="dynamic_queue_demo", user_id="demo_user", session_id=session.id
)
if final_session:
  final_data = final_session.state.get("graph_data", {})
  final_state = GraphState(data=final_data)
else:
  print("Warning: Could not retrieve final session state.")
  final_state = GraphState()
Fixed: Explicit None check with warning print instead of silent stale fallback.
Fixed — changed from warn+fallback to early return when session cannot be retrieved.
fresh_session = await session_service.get_session(
    app_name="research_graph", user_id="user1", session_id=session.id
)
final_data = (fresh_session or session).state.get("graph_data", {})
final_state = GraphState(data=final_data)
Similar to another sample file in this PR, the fallback to the original session (fresh_session or session) can lead to misleading results. The session object is a deepcopy from when the session was created and does not contain the final state from the graph execution. If get_session were to fail and return None, this code would fall back to the stale session object and print incorrect results. A more robust implementation would handle the None case explicitly.
final_session = await session_service.get_session(
    app_name="research_graph", user_id="user1", session_id=session.id
)
if final_session:
  final_data = final_session.state.get("graph_data", {})
  final_state = GraphState(data=final_data)
else:
  print("Warning: Could not retrieve final session state.")
  final_state = GraphState()
Fixed: Explicit None check with warning print instead of silent stale fallback.
fresh_session = await session_service.get_session(
    app_name="react_agent", user_id="user1", session_id=session.id
)
final_data = (fresh_session or session).state.get("graph_data", {})
final_state = GraphState(data=final_data)
This pattern of re-fetching the session and falling back to the original session object is problematic. The original session is a deepcopy from create_session and holds stale state. If get_session returns None, this code will silently use the old state and print incorrect final results, which is misleading. It's safer to handle the None case explicitly.
final_session = await session_service.get_session(
    app_name="react_agent", user_id="user1", session_id=session.id
)
if final_session:
  final_data = final_session.state.get("graph_data", {})
  final_state = GraphState(data=final_data)
else:
  print("Warning: Could not retrieve final session state.")
  final_state = GraphState()
Fixed: Explicit None check with warning print instead of silent stale fallback.
Force-pushed c6f4902 to 8bbeedd.
Addressing review feedback
Force-pushed with the following fixes:
High:
Medium:
@gemini-code-assist please re-review
Code Review
This pull request introduces a major new feature: the GraphAgent and its associated components for building complex, stateful workflows with conditional routing, loops, and human-in-the-loop capabilities. This is a significant and well-executed addition to the ADK. The code is well-structured, with clear separation of concerns into modules for graph logic, interrupts, telemetry, and patterns. The inclusion of numerous samples and comprehensive tests is excellent.
My main feedback is a high-severity issue regarding the state merging logic in parallel execution, which could lead to unexpected behavior. Addressing this would make the parallel execution feature more robust and predictable.
# Merge branch states back into main state with conflict detection
conflicts_detected = []
keys_merged = set()

for node_name in group.nodes:
  if node_name not in results:
    continue
  result = results[node_name]
  branch_state = result["state"]

  # Merge data keys with conflict detection
  for key, value in branch_state.data.items():
    if (
        key in state.data
        and key in keys_merged
        and state.data[key] != value
    ):
      # Conflict: Multiple branches set the same key to different values
      conflicts_detected.append({
          "key": key,
          "node": node_name,
          "existing_value": state.data[key],
          "new_value": value,
      })
      logger.warning(
          f"State merge conflict detected: key '{key}' modified by multiple"
          f" parallel branches. Last write wins (node: {node_name})."
      )

    state.data[key] = value
    keys_merged.add(key)
The current state merging logic uses a "last write wins" strategy based on the iteration order of nodes in the ParallelNodeGroup. This can lead to subtle bugs and non-deterministic behavior if two parallel branches modify the same state key. The result will depend on which branch is processed last in the loop, not which one finished execution last.
For example, if branch A modifies state.data['x'] and branch B does not, but B is processed second, B's unchanged value for x (from the initial state copy) will overwrite A's modification.
Consider a more explicit merge strategy that only merges keys that were actually changed by each branch. For conflicting writes to the same key, you could either raise an error or allow the user to provide a custom merge function to resolve the conflict. This would make the state updates more predictable and robust.
# Merge branch states back into main state with conflict detection
conflicts_detected = []
keys_merged: set[str] = set()
original_data = deepcopy(state.data)
for node_name in group.nodes:
if node_name not in results:
continue
result = results[node_name]
branch_data = result["state"].data
# Diff branch state against original to find changed keys
for key, value in branch_data.items():
original_value = original_data.get(key, object()) # Sentinel for new keys
if value == original_value:
continue # Skip unchanged keys
# This key was added or changed by the branch
if key in keys_merged and state.data.get(key) != value:
conflicts_detected.append({
"key": key,
"node": node_name,
"existing_value": state.data.get(key),
"new_value": value,
})
logger.warning(
"State merge conflict: key '%s' written by multiple parallel"
" branches. Last write from node '%s' wins.",
key,
node_name,
)
state.data[key] = value
keys_merged.add(key)
Fixed — implemented diff-based state merge. Captures original snapshot before merge loop and only merges keys that actually changed from the pre-branch state.
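The review's idea of a user-provided merge function could look roughly like this (a hypothetical sketch, not code from the PR): only keys that changed relative to the pre-branch snapshot are merged, and conflicting writes go through a resolver callback instead of last-write-wins.

```python
from copy import deepcopy
from typing import Callable


def merge_branches(
    original: dict,
    branch_results: dict[str, dict],
    resolve: Callable[[str, object, object], object],
) -> dict:
    """Diff-based merge of parallel branch states with a conflict hook."""
    merged = deepcopy(original)
    written: set[str] = set()
    for node_name, branch_data in branch_results.items():
        for key, value in branch_data.items():
            if key in original and original[key] == value:
                continue  # unchanged relative to the pre-branch snapshot
            if key in written and merged[key] != value:
                # Two branches wrote different values: ask the resolver.
                value = resolve(key, merged[key], value)
            merged[key] = value
            written.add(key)
    return merged


# Example resolver: concatenate conflicting list writes.
result = merge_branches(
    original={"items": []},
    branch_results={"a": {"items": [1]}, "b": {"items": [2]}},
    resolve=lambda key, old, new: old + new,
)
print(result)  # {'items': [1, 2]}
```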
Add GraphAgent for building directed-graph workflows with conditional routing, cyclic execution, state management with reducers, typed events, streaming, callbacks, rewind, resumability, telemetry with OpenTelemetry tracing, evaluation metrics, and CLI graph visualization for GraphAgent topologies. Includes samples and design documentation.
Add DynamicNode (runtime agent selection), NestedGraphNode (hierarchical workflow composition), and DynamicParallelGroup (variable-count concurrent execution). Extends CLI visualization with pattern-aware rendering (diamond, parallelogram, sub-cluster shapes). Includes pattern samples, node type reference, and design documentation.
Force-pushed 8bbeedd to 218d2ec.
@gemini-code-assist please re-review
Code Review
This is an impressive and substantial pull request that introduces the GraphAgent for complex workflow orchestration. The feature set is comprehensive, including conditional routing, cyclic execution, parallel processing, and a robust human-in-the-loop InterruptService. The code is well-structured, with good separation of concerns using mixins for telemetry and interrupt handling. The inclusion of numerous focused examples and advanced patterns is excellent and will greatly benefit users. Overall, the implementation is of high quality. My review includes a few suggestions to improve performance and robustness in the sample code and the interrupt service.
# asyncio.Queue stores items in _queue (a collections.deque).
# This avoids drain/requeue and eliminates QueueFull risk.
queue = self._message_queues[session_id]
messages = list(queue._queue)  # type: ignore[attr-defined]
Accessing the private member _queue of asyncio.Queue is risky because it relies on an internal implementation detail of the standard library that is not guaranteed to be stable across Python versions. If asyncio.Queue's internal structure changes in a future release, this code will break.
While the current approach avoids a potential QueueFull error that could occur with a drain-and-requeue strategy, a more robust long-term solution should be considered. This might involve implementing a custom, non-blocking peek method or using a different queue implementation that officially supports this operation.
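One possible shape for such a custom peek method (an assumption, not ADK code): subclass asyncio.Queue and confine the private-attribute access to a single, testable method, so the rest of the codebase never touches _queue directly.

```python
import asyncio


class PeekableQueue(asyncio.Queue):
    """asyncio.Queue with a public, non-destructive snapshot method."""

    def snapshot(self) -> list:
        """Return queued items oldest-first without consuming them."""
        # self._queue is still internal to CPython, but the dependency on
        # it is now confined to one method that tests can cover directly.
        return list(self._queue)  # type: ignore[attr-defined]


queue = PeekableQueue(maxsize=10)
queue.put_nowait("approve-request")
queue.put_nowait("cancel-request")
print(queue.snapshot())  # ['approve-request', 'cancel-request']
print(queue.qsize())     # 2
```

If the deque layout ever changes in a future Python release, only snapshot() needs updating.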
| return {"all_complete": True, "tasks_remaining": 0} | ||
|
|
||
| # Pop next task | ||
| next_task = task_queue.pop(0) |
final_counter = session.state.get("counter")
if final_counter is None:
  graph_data_raw = session.state.get("graph_data")
  if graph_data_raw:
    try:
      data = (
          json.loads(graph_data_raw)
          if isinstance(graph_data_raw, str)
          else graph_data_raw
      )
      final_counter = data.get("counter", 0)
    except (json.JSONDecodeError, TypeError):
      final_counter = 0

if final_counter is None:
  final_counter = 0
The logic to retrieve the final_counter is more complex than necessary. The GraphAgent's final event consistently places the final state in session.state['graph_data']. You can simplify this by directly accessing the counter from the final graph data with a simple fallback, which will improve the readability and maintainability of this example.
session = await session_service.get_session(
app_name="cyclic_demo", user_id="user1", session_id="session1"
)
graph_data = session.state.get("graph_data", {})
final_counter = graph_data.get("counter", 0)
Please ensure you have read the contribution guide before creating a pull request.
Link to Issue or Description of Change
1. Link to an existing issue (if applicable):
2. Or, if no issue exists, describe the change:
Problem:
Many agent workflows need human oversight — approval gates, review checkpoints, error correction. There is no structured way to pause execution, present context to a human, and resume with their decision.
Solution:
Add interrupt system enabling pause/resume, human approval gates, and cancellation during graph execution. Includes InterruptService with priority queues and per-session metrics, InterruptReasoner for LLM-based interrupt decisions, GraphInterruptMixin for execution loop integration, and checkpoint tracing infrastructure.
What's included:
- src/google/adk/agents/graph/interrupt.py — InterruptConfig, InterruptMode, InterruptAction
- src/google/adk/agents/graph/interrupt_service.py — InterruptService, InterruptMessage, InterruptServiceConfig, QueueStatus, SessionMetrics
- src/google/adk/agents/graph/interrupt_reasoner.py — InterruptReasoner, InterruptDecision, InterruptReasonerConfig
- src/google/adk/agents/graph/graph_interrupt_handler.py — GraphInterruptMixin
- src/google/adk/telemetry/checkpoint_tracing.py
- graph_agent.py with interrupt handling
- graph/__init__.py with interrupt exports
- Tests: test_interrupt_cancellation.py, test_interrupt_integration.py, test_interrupt_reasoner.py, test_interrupt_service.py, test_checkpoint_tracing.py
- Samples (graph_agent_hitl_review, graph_agent_hitl_orchestrated, examples 05-06/13)
- interrupt_service_architecture.md
Part 4 of 5 — see tracking issue #4581. Stacked on #4584.
Testing Plan
Unit Tests:
Manual End-to-End (E2E) Tests:
4 interrupt sample agents import and instantiate successfully.
Checklist
Additional context
Part 4 of 5. Depends on #4582 (Core), #4583 (Patterns), #4584 (Parallel).