From 3b21fcbae75cae9efb10cb393b0204a669e114e9 Mon Sep 17 00:00:00 2001 From: Umer Khan <96595386+umerkhan95@users.noreply.github.com> Date: Mon, 20 Apr 2026 11:54:48 +0200 Subject: [PATCH 01/11] fix(memory): serialize concurrent writes in AsyncSQLAlchemyMemory (#1390) --- .../_working_memory/_sqlalchemy_memory.py | 388 +++++++++--------- tests/memory_test.py | 43 ++ 2 files changed, 243 insertions(+), 188 deletions(-) diff --git a/src/agentscope/memory/_working_memory/_sqlalchemy_memory.py b/src/agentscope/memory/_working_memory/_sqlalchemy_memory.py index 769abb5371..561ce90400 100644 --- a/src/agentscope/memory/_working_memory/_sqlalchemy_memory.py +++ b/src/agentscope/memory/_working_memory/_sqlalchemy_memory.py @@ -1,7 +1,9 @@ # -*- coding: utf-8 -*- """The SQLAlchemy database storage module, which supports storing messages in a SQL database using SQLAlchemy ORM (e.g., SQLite, PostgreSQL, MySQL).""" -from typing import Any +import asyncio +from contextlib import asynccontextmanager +from typing import Any, AsyncIterator from sqlalchemy import ( Column, @@ -167,6 +169,20 @@ def __init__( # Flag to track if tables and records have been initialized self._initialized = False + # Lock to serialize concurrent write operations on the shared session + self._lock = asyncio.Lock() + + @asynccontextmanager + async def _write_session(self) -> AsyncIterator[None]: + """Acquire the write lock and auto-commit/rollback the session.""" + async with self._lock: + try: + yield + await self.session.commit() + except Exception: + await self.session.rollback() + raise + def _make_message_id(self, msg_id: str) -> str: """Generate a composite primary key for a message. @@ -412,81 +428,81 @@ async def add( f"but got {type(marks)}.", ) - # Create table if not exists - await self._create_table() - - # If skip_duplicated is True, filter out existing messages - messages_to_add = memories - if skip_duplicated: - existing_msg_ids = set() - result = await self.session.execute( - select(self.MessageTable.id).filter( - self.MessageTable.id.in_( - [self._make_message_id(m.id) for m in memories], - ), - ), - ) - existing_msg_ids = {row[0] for row in result.fetchall()} + async with self._write_session(): + # Create table if not exists + await self._create_table() - messages_to_add = [ - m - for m in memories - if self._make_message_id(m.id) not in existing_msg_ids - ] - - # If all messages are duplicates, return early - if not messages_to_add: - return - - # Get the starting index once to avoid race conditions - start_index = await self._get_next_index() - - # Add messages to message table - for i, m in enumerate(messages_to_add): - message_record = self.MessageTable( - id=self._make_message_id(m.id), - msg=m.to_dict(), - session_id=self.session_id, - index=start_index + i, - ) - self.session.add(message_record) - - # Create mark records if marks are provided (use bulk insert) - if marks: - mark_records = [ - {"msg_id": self._make_message_id(msg.id), "mark": mark} - for msg in messages_to_add - for mark in marks - ] - if mark_records: - if skip_duplicated: - # Query existing mark combinations to avoid duplicates - result = await self.session.execute( - select( - self.MessageMarkTable.msg_id, - self.MessageMarkTable.mark, + # If skip_duplicated is True, filter out existing messages + messages_to_add = memories + if skip_duplicated: + existing_msg_ids = set() + result = await self.session.execute( + select(self.MessageTable.id).filter( + self.MessageTable.id.in_( + [self._make_message_id(m.id) for m in memories], ), - ) - existing_marks = { - (row[0], row[1]) for row in result.fetchall() + ), + ) + existing_msg_ids = {row[0] for row in result.fetchall()} + + messages_to_add = [ + m + for m in memories + if self._make_message_id(m.id) not in existing_msg_ids + ] + + # If all messages are duplicates, return early + if not messages_to_add: + return + + # Get the starting index once to avoid race conditions + start_index = await self._get_next_index() + + # Add messages to message table + for i, m in enumerate(messages_to_add): + message_record = self.MessageTable( + id=self._make_message_id(m.id), + msg=m.to_dict(), + session_id=self.session_id, + index=start_index + i, + ) + self.session.add(message_record) + + # Create mark records if marks are provided (bulk insert) + if marks: + mark_records = [ + { + "msg_id": self._make_message_id(msg.id), + "mark": mark, } - - # Filter out existing mark combinations - mark_records = [ - r - for r in mark_records - if (r["msg_id"], r["mark"]) not in existing_marks - ] - + for msg in messages_to_add + for mark in marks + ] if mark_records: - await self.session.run_sync( - lambda session: session.bulk_insert_mappings( - self.MessageMarkTable, - mark_records, - ), - ) - - await self.session.commit() + if skip_duplicated: + result = await self.session.execute( + select( + self.MessageMarkTable.msg_id, + self.MessageMarkTable.mark, + ), + ) + existing_marks = { + (row[0], row[1]) for row in result.fetchall() + } + + mark_records = [ + r + for r in mark_records + if (r["msg_id"], r["mark"]) not in existing_marks + ] + + if mark_records: + await self.session.run_sync( + lambda session: session.bulk_insert_mappings( + self.MessageMarkTable, + mark_records, + ), + ) async def _get_next_index(self) -> int: """Get the next index for a new message in the current session. @@ -515,25 +531,24 @@ async def size(self) -> int: async def clear(self) -> None: """Clear all messages from the storage.""" - # Delete all marks for messages in this session - await self.session.execute( - delete(self.MessageMarkTable).where( - self.MessageMarkTable.msg_id.in_( - select(self.MessageTable.id).filter( - self.MessageTable.session_id == self.session_id, + async with self._write_session(): + # Delete all marks for messages in this session + await self.session.execute( + delete(self.MessageMarkTable).where( + self.MessageMarkTable.msg_id.in_( + select(self.MessageTable.id).filter( + self.MessageTable.session_id == self.session_id, + ), ), ), - ), - ) - - # Then delete all messages - await self.session.execute( - delete(self.MessageTable).filter( - self.MessageTable.session_id == self.session_id, - ), - ) + ) - await self.session.commit() + # Then delete all messages + await self.session.execute( + delete(self.MessageTable).filter( + self.MessageTable.session_id == self.session_id, + ), + ) async def delete_by_mark( self, @@ -553,45 +568,44 @@ async def delete_by_mark( if isinstance(mark, str): mark = [mark] - # First, find message IDs that have the specified marks - query = ( - select(self.MessageTable.id) - .join( - self.MessageMarkTable, - self.MessageTable.id == self.MessageMarkTable.msg_id, - ) - .filter( - self.MessageTable.session_id == self.session_id, - self.MessageMarkTable.mark.in_(mark), + async with self._write_session(): + # First, find message IDs that have the specified marks + query = ( + select(self.MessageTable.id) + .join( + self.MessageMarkTable, + self.MessageTable.id == self.MessageMarkTable.msg_id, + ) + .filter( + self.MessageTable.session_id == self.session_id, + self.MessageMarkTable.mark.in_(mark), + ) ) - ) - result = await self.session.execute(query) - msg_ids = [row[0] for row in result.all()] + result = await self.session.execute(query) + msg_ids = [row[0] for row in result.all()] - if not msg_ids: - return 0 + if not msg_ids: + return 0 - # Store the count before deletion - deleted_count = len(msg_ids) + deleted_count = len(msg_ids) - # Delete marks first - await self.session.execute( - delete(self.MessageMarkTable).filter( - self.MessageMarkTable.msg_id.in_(msg_ids), - ), - ) + # Delete marks first + await self.session.execute( + delete(self.MessageMarkTable).filter( + self.MessageMarkTable.msg_id.in_(msg_ids), + ), + ) - # Then delete the messages - await self.session.execute( - delete(self.MessageTable).filter( - self.MessageTable.session_id == self.session_id, - self.MessageTable.id.in_(msg_ids), - ), - ) + # Then delete the messages + await self.session.execute( + delete(self.MessageTable).filter( + self.MessageTable.session_id == self.session_id, + self.MessageTable.id.in_(msg_ids), + ), + ) - await self.session.commit() - return deleted_count + return deleted_count async def delete( self, @@ -620,26 +634,25 @@ async def delete( if not composite_ids: return 0 - # Store the count before deletion - deleted_count = len(composite_ids) + async with self._write_session(): + deleted_count = len(composite_ids) - # Delete related marks first (explicit cleanup for reliability) - await self.session.execute( - delete(self.MessageMarkTable).filter( - self.MessageMarkTable.msg_id.in_(composite_ids), - ), - ) + # Delete related marks first (explicit cleanup for reliability) + await self.session.execute( + delete(self.MessageMarkTable).filter( + self.MessageMarkTable.msg_id.in_(composite_ids), + ), + ) - # Then delete the messages - await self.session.execute( - delete(self.MessageTable).filter( - self.MessageTable.session_id == self.session_id, - self.MessageTable.id.in_(composite_ids), - ), - ) + # Then delete the messages + await self.session.execute( + delete(self.MessageTable).filter( + self.MessageTable.session_id == self.session_id, + self.MessageTable.id.in_(composite_ids), + ), + ) - await self.session.commit() - return deleted_count + return deleted_count async def update_messages_mark( self, @@ -695,55 +708,52 @@ async def update_messages_mark( f"but got {type(msg_ids)}.", ) - # First obtain the message ids that belong to this session - query = select(self.MessageTable).filter( - self.MessageTable.session_id == self.session_id, - ) - - # Filter by msg_ids if provided - if msg_ids is not None: - # Convert to composite keys - composite_ids = [ - self._make_message_id(msg_id) for msg_id in msg_ids - ] - query = query.filter(self.MessageTable.id.in_(composite_ids)) - - # Filter by old_mark if provided - if old_mark is not None: - query = query.join( - self.MessageMarkTable, - self.MessageTable.id == self.MessageMarkTable.msg_id, - ).filter(self.MessageMarkTable.mark == old_mark) - - # Obtain the message records - result = await self.session.execute(query) - msg_ids = [str(_.id) for _ in result.scalars().all()] + async with self._write_session(): + # First obtain the message ids that belong to this session + query = select(self.MessageTable).filter( + self.MessageTable.session_id == self.session_id, + ) - # Return early if no messages found - if not msg_ids: - return 0 + # Filter by msg_ids if provided + if msg_ids is not None: + composite_ids = [ + self._make_message_id(msg_id) for msg_id in msg_ids + ] + query = query.filter( + self.MessageTable.id.in_(composite_ids), + ) - if new_mark: - if old_mark: - # Replace old_mark with new_mark - return await self._replace_message_mark( + # Filter by old_mark if provided + if old_mark is not None: + query = query.join( + self.MessageMarkTable, + self.MessageTable.id == self.MessageMarkTable.msg_id, + ).filter(self.MessageMarkTable.mark == old_mark) + + # Obtain the message records + result = await self.session.execute(query) + msg_ids = [str(_.id) for _ in result.scalars().all()] + + if not msg_ids: + return 0 + + if new_mark: + if old_mark: + return await self._replace_message_mark( + msg_ids=msg_ids, + old_mark=old_mark, + new_mark=new_mark, + ) + return await self._add_message_mark( msg_ids=msg_ids, - old_mark=old_mark, - new_mark=new_mark, + mark=new_mark, ) - # Add new_mark to the messages - return await self._add_message_mark( + return await self._remove_message_mark( msg_ids=msg_ids, - mark=new_mark, + old_mark=old_mark, ) - # Remove all marks from the messages - return await self._remove_message_mark( - msg_ids=msg_ids, - old_mark=old_mark, - ) - async def _replace_message_mark( self, msg_ids: list[str], @@ -753,6 +763,8 @@ async def _replace_message_mark( """Replace the old mark with the new mark for the given messages by updating records in the message_mark table. + Note: Must be called within ``_write_session()``. + Args: msg_ids (`list[str]`): The list of message IDs to be updated. @@ -765,7 +777,6 @@ async def _replace_message_mark( `int`: The number of messages updated. """ - await self.session.execute( update(self.MessageMarkTable) .filter( @@ -774,13 +785,14 @@ async def _replace_message_mark( ) .values(mark=new_mark), ) - await self.session.commit() return len(msg_ids) async def _add_message_mark(self, msg_ids: list[str], mark: str) -> int: """Mark the messages with the given mark by adding records to the message_mark table. + Note: Must be called within ``_write_session()``. + Args: msg_ids (`list[str]`): The list of message IDs to be marked. @@ -791,7 +803,6 @@ async def _add_message_mark(self, msg_ids: list[str], mark: str) -> int: `int`: The number of messages marked. """ - # Use bulk insert for better performance mark_records = [{"msg_id": msg_id, "mark": mark} for msg_id in msg_ids] if mark_records: @@ -802,7 +813,6 @@ async def _add_message_mark(self, msg_ids: list[str], mark: str) -> int: ), ) - await self.session.commit() return len(msg_ids) async def _remove_message_mark( @@ -813,6 +823,8 @@ async def _remove_message_mark( """Remove marks from the messages by deleting records from the message_mark table. + Note: Must be called within ``_write_session()``. + Args: msg_ids (`list[str]`): The list of message IDs to be unmarked. @@ -834,16 +846,16 @@ async def _remove_message_mark( ) await self.session.execute(delete_query) - await self.session.commit() return len(msg_ids) async def close(self) -> None: """Close the database session.""" - if self._db_session and self._db_session.is_active: - await self._db_session.close() - - self._db_session = None - self._initialized = False + try: + if self._db_session and self._db_session.is_active: + await self._db_session.close() + finally: + self._db_session = None + self._initialized = False async def __aenter__(self) -> "AsyncSQLAlchemyMemory": """Enter the async context manager. diff --git a/tests/memory_test.py b/tests/memory_test.py index 1eec5fe4f4..19eed14991 100644 --- a/tests/memory_test.py +++ b/tests/memory_test.py @@ -639,6 +639,49 @@ async def test_memory(self) -> None: await self._multi_session_tests() await self._test_serialization() + async def test_concurrent_add(self) -> None: + """Test that concurrent add() calls don't cause IntegrityError. + + Reproduces the bug from GitHub issue #1381: when parallel_tool_calls + is True, multiple _acting coroutines call memory.add() concurrently, + causing duplicate primary key conflicts. + """ + from sqlalchemy import select as sa_select + + messages = [ + Msg("system", f"Tool result {i}", "system") for i in range(20) + ] + + # Add all messages concurrently (simulates parallel_tool_calls) + await asyncio.gather( + *(self.memory.add(msg) for msg in messages), + ) + + # Verify all messages were added + stored = await self.memory.get_memory() + self.assertEqual(len(stored), len(messages)) + + # Verify indices are unique and contiguous (the core race condition + # in _get_next_index would cause duplicate indices without the lock) + result = await self.memory.session.execute( + sa_select(self.memory.MessageTable.index) + .filter( + self.memory.MessageTable.session_id == self.memory.session_id, + ) + .order_by(self.memory.MessageTable.index), + ) + indices = [row[0] for row in result.fetchall()] + self.assertEqual( + len(set(indices)), + len(messages), + "Indices not unique", + ) + self.assertEqual( + indices, + list(range(len(messages))), + "Indices not contiguous", + ) + async def asyncTearDown(self) -> None: """Clean up after unittests""" await super().asyncTearDown() From 69dd5a9a237dd203a350d28da111713aa4848c6b Mon Sep 17 00:00:00 2001 From: qbc Date: Mon, 20 Apr 2026 18:10:22 +0800 Subject: [PATCH 02/11] fix(agent): prevent duplicate hook execution when subclass overrides _reasoning or _acting (#1481) --- src/agentscope/agent/_agent_meta.py | 24 +- tests/hook_test.py | 378 +++++++++++++++++++++++++++- 2 files changed, 393 insertions(+), 9 deletions(-) diff --git a/src/agentscope/agent/_agent_meta.py b/src/agentscope/agent/_agent_meta.py index a9bc81b5f8..834437074f 100644 --- a/src/agentscope/agent/_agent_meta.py +++ b/src/agentscope/agent/_agent_meta.py @@ -63,6 +63,8 @@ def _wrap_with_hooks( """ func_name = original_func.__name__.replace("_", "") + hook_guard_attr = f"_hook_running_{func_name}" + @wraps(original_func) async def async_wrapper( self: AgentBase, @@ -72,6 +74,12 @@ async def async_wrapper( """The wrapped function, which call the pre- and post-hooks before and after the original function.""" + # Guard against re-entrant hook execution when multiple classes + # in the MRO define the same method (each wrapped independently + # by the metaclass). Only the outermost wrapper runs hooks. + if getattr(self, hook_guard_attr, False): + return await original_func(self, *args, **kwargs) + # Unify all positional and keyword arguments into a keyword arguments normalized_kwargs = _normalize_to_kwargs( original_func, @@ -117,12 +125,16 @@ async def async_wrapper( for k, v in current_normalized_kwargs.items() if k not in ["args", "kwargs"] } - current_output = await original_func( - self, - *args, - **others, - **kwargs, - ) + setattr(self, hook_guard_attr, True) + try: + current_output = await original_func( + self, + *args, + **others, + **kwargs, + ) + finally: + setattr(self, hook_guard_attr, False) # post_hooks post_hooks = list( diff --git a/tests/hook_test.py b/tests/hook_test.py index 90d148f2fd..43ba022dcc 100644 --- a/tests/hook_test.py +++ b/tests/hook_test.py @@ -1,10 +1,17 @@ # -*- coding: utf-8 -*- +# pylint: disable=too-many-lines """Hook related tests in agentscope.""" from typing import Any from unittest.async_case import IsolatedAsyncioTestCase -from agentscope.agent import AgentBase -from agentscope.message import Msg, TextBlock +from pydantic import BaseModel, Field + +from agentscope.agent import AgentBase, ReActAgent +from agentscope.formatter import DashScopeChatFormatter +from agentscope.memory import InMemoryMemory +from agentscope.message import Msg, TextBlock, ToolUseBlock +from agentscope.model import ChatModelBase, ChatResponse +from agentscope.tool import Toolkit class MyAgent(AgentBase): @@ -32,7 +39,7 @@ async def observe(self, msg: Msg) -> None: """Observe the message without generating a reply.""" self.memory.append(msg) - async def handle_interrupt(self, *args: Any, **kwargs: Any) -> Msg: + async def handle_interrupt(self, *_args: Any, **_kwargs: Any) -> Msg: """Handle the interrupt signal.""" # This is a placeholder for handling interrupts. return Msg("test", "Interrupt handled", "assistant") @@ -46,6 +53,32 @@ class GrandChildAgent(ChildAgent): """Grandchild agent for testing deeper inheritance.""" +class ChildAgentWithReplyOverride(MyAgent): + """Child agent that overrides reply and calls super().reply(), + triggering double wrapping by the metaclass. Used to test + that hook_guard_attr prevents duplicate hook execution.""" + + async def reply(self, msg: Msg) -> Msg: + """Override reply, delegating to parent via super().""" + return await super().reply(msg) + + +class ChildAgentWithObserveOverride(MyAgent): + """Child agent that overrides observe and calls super().observe().""" + + async def observe(self, msg: Msg) -> None: + """Override observe, delegating to parent via super().""" + await super().observe(msg) + + +class GrandChildAgentWithReplyOverride(ChildAgentWithReplyOverride): + """Three-level inheritance chain with each level overriding reply.""" + + async def reply(self, msg: Msg) -> Msg: + """Override reply again, delegating to parent via super().""" + return await super().reply(msg) + + class AgentA(MyAgent): """First parent class.""" @@ -58,6 +91,66 @@ class AgentC(AgentA, AgentB): """Multiple inheritance class.""" +class MockModel(ChatModelBase): + """Mock model that returns text-only on the first call and + text + tool_use on subsequent calls.""" + + def __init__(self) -> None: + """Initialize the mock model.""" + super().__init__("mock_model", stream=False) + self.cnt = 1 + self.fake_content_text = [ + TextBlock(type="text", text="text_response"), + ] + self.fake_content_tool = [ + TextBlock(type="text", text="tool_response"), + ToolUseBlock( + type="tool_use", + name="generate_response", + id="mock_id", + input={"result": "structured_value"}, + ), + ] + + async def __call__( + self, + _messages: list[dict], + **kwargs: Any, + ) -> ChatResponse: + """Mock model call.""" + self.cnt += 1 + if self.cnt == 2: + return ChatResponse(content=self.fake_content_text) + else: + return ChatResponse(content=self.fake_content_tool) + + +class MyReActAgent(ReActAgent): + """Subclass that overrides reply, _reasoning and _acting, each calling + super(). Used to test that hook_guard_attr prevents duplicate hook + execution when the metaclass wraps both the child's and parent's methods + independently.""" + + async def reply( + self, + msg: Msg | list[Msg] | None = None, + structured_model: Any = None, + ) -> Msg: + """Override reply, delegating to parent via super().""" + return await super().reply(msg, structured_model=structured_model) + + async def _reasoning( + self, + tool_choice: Any = None, + ) -> Msg: + """Override _reasoning, delegating to parent via super().""" + return await super()._reasoning(tool_choice=tool_choice) + + async def _acting(self, tool_call: Any) -> dict | None: + """Override _acting, delegating to parent via super().""" + return await super()._acting(tool_call) + + async def async_pre_func_w_modifying( self: MyAgent, kwargs: dict[str, Any], @@ -786,3 +879,282 @@ async def asyncTearDown(self) -> None: AgentA.clear_class_hooks() AgentB.clear_class_hooks() AgentC.clear_class_hooks() + + +class HookGuardTest(IsolatedAsyncioTestCase): + """Tests for the hook_guard_attr re-entrancy prevention mechanism. + + When a child class overrides a hook-wrapped method (reply, observe, + _reasoning, _acting, etc.) and calls super().method(), the metaclass + wraps both the child's and the parent's method independently. Without + the guard, hooks would fire once per wrapper in the call chain. The + hook_guard_attr ensures hooks only execute in the outermost wrapper. + + Covers both AgentBase-level (reply, observe) and ReActAgent-level + (reply, _reasoning, _acting) scenarios. + """ + + @property + def msg(self) -> Msg: + """Get the test message.""" + return Msg( + "user", + [TextBlock(type="text", text="0")], + "user", + ) + + def _make_react_agent(self) -> MyReActAgent: + """Create a MyReActAgent with a fresh mock model.""" + return MyReActAgent( + name="TestAgent", + sys_prompt="You are a helpful assistant.", + model=MockModel(), + formatter=DashScopeChatFormatter(), + memory=InMemoryMemory(), + toolkit=Toolkit(), + ) + + # ---- AgentBase-level tests ---- + + async def test_reply_hooks_execute_once_with_override(self) -> None: + """Pre and post reply hooks should each execute exactly once when + a child class overrides reply() and calls super().reply().""" + agent = ChildAgentWithReplyOverride() + pre_count = 0 + post_count = 0 + + async def counting_pre_hook( + _self: Any, + _kwargs: dict[str, Any], + ) -> None: + nonlocal pre_count + pre_count += 1 + + async def counting_post_hook( + _self: Any, + _kwargs: dict[str, Any], + _output: Any, + ) -> None: + nonlocal post_count + post_count += 1 + + agent.register_instance_hook( + "pre_reply", + "counter_pre", + counting_pre_hook, + ) + agent.register_instance_hook( + "post_reply", + "counter_post", + counting_post_hook, + ) + + await agent(self.msg) + self.assertEqual(pre_count, 1) + self.assertEqual(post_count, 1) + + async def test_observe_hooks_execute_once_with_override(self) -> None: + """Observe hooks should execute exactly once when a child class + overrides observe() and calls super().observe().""" + agent = ChildAgentWithObserveOverride() + pre_count = 0 + + async def counting_pre_hook( + _self: Any, + _kwargs: dict[str, Any], + ) -> None: + nonlocal pre_count + pre_count += 1 + + agent.register_instance_hook( + "pre_observe", + "counter", + counting_pre_hook, + ) + + await agent.observe(self.msg) + self.assertEqual(pre_count, 1) + + async def test_deep_inheritance_hooks_execute_once(self) -> None: + """Hooks should execute exactly once even with a 3-level override + chain (GrandChild -> Child -> MyAgent), each overriding reply and + calling super().""" + agent = GrandChildAgentWithReplyOverride() + pre_count = 0 + + async def counting_pre_hook( + _self: Any, + _kwargs: dict[str, Any], + ) -> None: + nonlocal pre_count + pre_count += 1 + + agent.register_instance_hook( + "pre_reply", + "counter", + counting_pre_hook, + ) + + await agent(self.msg) + self.assertEqual(pre_count, 1) + + async def test_hook_guard_cleared_after_exception(self) -> None: + """The guard flag should be properly cleaned up when the wrapped + method raises an exception, allowing hooks to work on retry.""" + + class FailingAgent(MyAgent): + """Agent whose reply always raises.""" + + async def reply(self, msg: Msg) -> Msg: + raise RuntimeError("intentional failure") + + class ChildOfFailing(FailingAgent): + """Child that overrides reply and calls super().""" + + async def reply(self, msg: Msg) -> Msg: + return await super().reply(msg) + + agent = ChildOfFailing() + pre_count = 0 + + async def counting_pre_hook( + _self: Any, + _kwargs: dict[str, Any], + ) -> None: + nonlocal pre_count + pre_count += 1 + + agent.register_instance_hook( + "pre_reply", + "counter", + counting_pre_hook, + ) + + with self.assertRaises(RuntimeError): + await agent(self.msg) + self.assertEqual(pre_count, 1) + self.assertFalse( + getattr(agent, "_hook_running_reply", False), + "Guard flag should be cleared after exception", + ) + + # Hooks should still work on subsequent calls + pre_count = 0 + with self.assertRaises(RuntimeError): + await agent(self.msg) + self.assertEqual(pre_count, 1) + + # ---- ReActAgent-level tests ---- + + async def test_react_reply_hooks_execute_once_with_override( + self, + ) -> None: + """ReActAgent reply hooks should execute exactly once when + a subclass overrides reply() and calls super().reply().""" + agent = self._make_react_agent() + pre_count = 0 + post_count = 0 + + async def counting_pre(_self: Any, _kwargs: Any) -> None: + nonlocal pre_count + pre_count += 1 + + async def counting_post( + _self: Any, + _kwargs: Any, + _output: Any, + ) -> None: + nonlocal post_count + post_count += 1 + + agent.register_instance_hook("pre_reply", "counter", counting_pre) + agent.register_instance_hook("post_reply", "counter", counting_post) + + await agent() + self.assertEqual(pre_count, 1) + self.assertEqual(post_count, 1) + + async def test_react_reasoning_hooks_execute_once_with_override( + self, + ) -> None: + """ReActAgent reasoning hooks should execute exactly once when + a subclass overrides _reasoning() and calls + super()._reasoning().""" + agent = self._make_react_agent() + pre_count = 0 + post_count = 0 + + async def counting_pre(_self: Any, _kwargs: Any) -> None: + nonlocal pre_count + pre_count += 1 + + async def counting_post( + _self: Any, + _kwargs: Any, + _output: Any, + ) -> None: + nonlocal post_count + post_count += 1 + + agent.register_instance_hook( + "pre_reasoning", + "counter", + counting_pre, + ) + agent.register_instance_hook( + "post_reasoning", + "counter", + counting_post, + ) + + await agent() + self.assertEqual(pre_count, 1) + self.assertEqual(post_count, 1) + + async def test_react_acting_hooks_execute_once_with_override( + self, + ) -> None: + """ReActAgent acting hooks should execute exactly once when + a subclass overrides _acting() and calls super()._acting().""" + agent = self._make_react_agent() + pre_count = 0 + post_count = 0 + + async def counting_pre(_self: Any, _kwargs: Any) -> None: + nonlocal pre_count + pre_count += 1 + + async def counting_post( + _self: Any, + _kwargs: Any, + _output: Any, + ) -> None: + nonlocal post_count + post_count += 1 + + agent.register_instance_hook( + "pre_acting", + "counter", + counting_pre, + ) + agent.register_instance_hook( + "post_acting", + "counter", + counting_post, + ) + + class TestStructuredModel(BaseModel): + """Test structured model.""" + + result: str = Field(description="Test result field.") + + await agent(structured_model=TestStructuredModel) + self.assertEqual(pre_count, 1) + self.assertEqual(post_count, 1) + + async def asyncTearDown(self) -> None: + """Tear down the test environment.""" + ChildAgentWithReplyOverride.clear_class_hooks() + ChildAgentWithObserveOverride.clear_class_hooks() + GrandChildAgentWithReplyOverride.clear_class_hooks() + MyReActAgent.clear_class_hooks() From 303bc4a66050d8070909be38d3954734f1205d6c Mon Sep 17 00:00:00 2001 From: qbc Date: Mon, 20 Apr 2026 18:13:16 +0800 Subject: [PATCH 03/11] feat(model): add fallback for structured output for openai chat model (#1430) --- src/agentscope/model/_openai_model.py | 91 ++++++++++++++++++++++++--- 1 file changed, 81 insertions(+), 10 deletions(-) diff --git a/src/agentscope/model/_openai_model.py b/src/agentscope/model/_openai_model.py index 19752c1fa4..548d639e42 100644 --- a/src/agentscope/model/_openai_model.py +++ b/src/agentscope/model/_openai_model.py @@ -23,6 +23,7 @@ from .._utils._common import ( _json_loads_with_repair, _parse_streaming_json_dict, + _create_tool_from_base_model, ) from ..message import ( ToolUseBlock, @@ -169,6 +170,7 @@ def __init__( self.reasoning_effort = reasoning_effort self.stream_tool_parsing = stream_tool_parsing self.generate_kwargs = generate_kwargs or {} + self._structured_output_fallback = False @trace_llm async def __call__( @@ -277,16 +279,47 @@ async def __call__( kwargs.pop("stream", None) kwargs.pop("tools", None) kwargs.pop("tool_choice", None) - kwargs["response_format"] = structured_model - if not self.stream: - response = await self.client.chat.completions.parse(**kwargs) - else: - response = self.client.chat.completions.stream(**kwargs) - return self._parse_openai_stream_response( - start_datetime, - response, + + if self._structured_output_fallback: + response = await self._structured_via_tool_call( + kwargs, structured_model, + start_datetime, ) + if isinstance(response, AsyncGenerator): + return response + else: + kwargs["response_format"] = structured_model + try: + if not self.stream: + response = await self.client.chat.completions.parse( + **kwargs, + ) + else: + response = self.client.chat.completions.stream( + **kwargs, + ) + return self._parse_openai_stream_response( + start_datetime, + response, + structured_model, + ) + except Exception as e: + logger.warning( + "response_format structured output failed (%s: %s), " + "falling back to tool-call based structured output. " + "Subsequent calls will use tool-call directly.", + type(e).__name__, + e, + ) + self._structured_output_fallback = True + response = await self._structured_via_tool_call( + kwargs, + structured_model, + start_datetime, + ) + if isinstance(response, AsyncGenerator): + return response else: response = await self.client.chat.completions.create(**kwargs) @@ -350,7 +383,7 @@ async def _parse_openai_stream_response( async with response as stream: async for item in stream: - if structured_model: + if structured_model and not self._structured_output_fallback: if item.type != "chunk": continue chunk = item.chunk @@ -613,7 +646,16 @@ def _parse_openai_completion_response( ) if structured_model: - metadata = choice.message.parsed.model_dump() + try: + parsed = choice.message.parsed + except AttributeError: + parsed = None + if parsed is not None: + metadata = parsed.model_dump() + elif choice.message.tool_calls: + metadata = _json_loads_with_repair( + choice.message.tool_calls[0].function.arguments, + ) usage = None if response.usage: @@ -635,6 +677,35 @@ def _parse_openai_completion_response( return ChatResponse(**resp_kwargs) + async def _structured_via_tool_call( + self, + kwargs: dict, + structured_model: Type[BaseModel], + start_datetime: datetime, + ) -> Any: + """Use tool-call approach for structured output. + + Falls back to this when the API endpoint does not support + json_schema response_format (e.g. DashScope, DeepSeek). + """ + kwargs.pop("response_format", None) + format_tool = _create_tool_from_base_model(structured_model) + kwargs["tools"] = self._format_tools_json_schemas([format_tool]) + kwargs["tool_choice"] = self._format_tool_choice( + format_tool["function"]["name"], + ) + if self.stream: + kwargs["stream"] = True + kwargs["stream_options"] = {"include_usage": True} + response = await self.client.chat.completions.create(**kwargs) + if self.stream: + return self._parse_openai_stream_response( + start_datetime, + response, + structured_model, + ) + return response + def _format_tools_json_schemas( self, schemas: list[dict[str, Any]], From 7269cf3042a6009cb93addf0f8531f7efa2617cf Mon Sep 17 00:00:00 2001 From: qbc Date: Mon, 20 Apr 2026 18:18:01 +0800 Subject: [PATCH 04/11] fix(sqlalchemy_memory): flush the session to ensure message records are written to the database (#1396) --- src/agentscope/memory/_working_memory/_sqlalchemy_memory.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/agentscope/memory/_working_memory/_sqlalchemy_memory.py b/src/agentscope/memory/_working_memory/_sqlalchemy_memory.py index 561ce90400..f75fb1ed6b 100644 --- a/src/agentscope/memory/_working_memory/_sqlalchemy_memory.py +++ b/src/agentscope/memory/_working_memory/_sqlalchemy_memory.py @@ -470,6 +470,10 @@ async def add( # Create mark records if marks are provided (bulk insert) if marks: + # Flush the session to ensure message records are written to + # the database before bulk_insert_mappings for message_mark + # records to satisfy foreign key constraints + await self.session.flush() mark_records = [ { "msg_id": self._make_message_id(msg.id), From 7e0997b4ca82e59003d0e73dc9c66c15b2500976 Mon Sep 17 00:00:00 2001 From: qbc Date: Mon, 20 Apr 2026 18:19:59 +0800 Subject: [PATCH 05/11] feat(formatter): translate local file path to base64 in image blocks in Anthropic formatter (#1361) --- .../formatter/_anthropic_formatter.py | 109 +++++++++++++++++- tests/formatter_anthropic_test.py | 84 +++++++++++++- 2 files changed, 185 insertions(+), 8 deletions(-) diff --git a/src/agentscope/formatter/_anthropic_formatter.py b/src/agentscope/formatter/_anthropic_formatter.py index 3a1818325a..7b8db2b435 100644 --- a/src/agentscope/formatter/_anthropic_formatter.py +++ b/src/agentscope/formatter/_anthropic_formatter.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- -# pylint: disable=too-many-branches +# pylint: disable=too-many-branches, too-many-nested-blocks """The Anthropic formatter module.""" +import base64 +import os from typing import Any +from urllib.parse import urlparse from ._truncated_formatter_base import TruncatedFormatterBase from .._logging import logger @@ -10,6 +13,88 @@ from ..token import TokenCounterBase +def _format_anthropic_image_block(image_block: ImageBlock) -> dict: + """Format an image block for Anthropic API. If the source is a URLSource + pointing to a local file, it will be converted to base64 format. + + Args: + image_block (`ImageBlock`): + The image block to format. + + Returns: + `dict`: + A dictionary in Anthropic image block format. + + Raises: + `ValueError`: + If the source type or image format is not supported. + """ + import filetype + + # See https://platform.openai.com/docs/guides/vision for details of + # support image extensions. + support_image_extensions = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".webp": "image/webp", + } + + source = image_block["source"] + + if source["type"] == "base64": + return {**image_block} + + url = source["url"] + raw_url = url.removeprefix("file://") + + if os.path.exists(raw_url) and os.path.isfile(raw_url): + ext = os.path.splitext(raw_url)[1].lower() + media_type = support_image_extensions.get(ext) + if media_type: + with open(raw_url, "rb") as f: + data = base64.b64encode(f.read()).decode("utf-8") + return { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": data, + }, + } + # No extension - detect file type using filetype + kind = filetype.guess(raw_url) + if kind is not None and kind.mime.startswith("image/"): + with open(raw_url, "rb") as image_file: + data = base64.b64encode(image_file.read()).decode( + "utf-8", + ) + return { + "type": "image", + "source": { + "type": "base64", + "media_type": kind.mime, + "data": data, + }, + } + + # For web urls + parsed_url = urlparse(raw_url) + if parsed_url.scheme not in ("", "file"): + return { + "type": "image", + "source": { + "type": "url", + "url": url, + }, + } + + raise ValueError( + f'Invalid image URL: "{url}". It should be a local file or a web URL.', + ) + + class AnthropicChatFormatter(TruncatedFormatterBase): """The Anthropic formatter class for chatbot scenario, where only a user and an agent are involved. We use the `role` field to identify different @@ -63,9 +148,16 @@ async def _format( for block in msg.get_content_blocks(): typ = block.get("type") - if typ in ["thinking", "text", "image"]: + if typ in ["thinking", "text"]: content_blocks.append({**block}) + elif typ == "image": + content_blocks.append( + _format_anthropic_image_block( + block, # type: ignore[arg-type] + ), + ) + elif typ == "tool_use": content_blocks.append( { @@ -81,7 +173,12 @@ async def _format( if output is None: content_value = [{"type": "text", "text": None}] elif isinstance(output, list): - content_value = output + content_value = [ + _format_anthropic_image_block(item) + if item.get("type") == "image" + else item + for item in output + ] else: content_value = [{"type": "text", "text": str(output)}] messages.append( @@ -207,7 +304,11 @@ async def _format_agent_message( ) accumulated_text.clear() - conversation_blocks.append({**block}) + conversation_blocks.append( + _format_anthropic_image_block( + block, # type: ignore[arg-type] + ), + ) if accumulated_text: conversation_blocks.append( diff --git a/tests/formatter_anthropic_test.py b/tests/formatter_anthropic_test.py index f68c0f0d11..daea341bf3 100644 --- a/tests/formatter_anthropic_test.py +++ b/tests/formatter_anthropic_test.py @@ -21,7 +21,10 @@ class TestAnthropicChatFormatterFormatter(IsolatedAsyncioTestCase): async def asyncSetUp(self) -> None: """Set up the test environment.""" - self.image_url = "www.example_image.png" + self.image_url = "https://www.example.com/image.png" + self.image_path = "./image.png" + with open(self.image_path, "wb") as f: + f.write(b"fake image content") self.msgs_system = [ Msg( @@ -153,7 +156,7 @@ async def asyncSetUp(self) -> None: "type": "image", "source": { "type": "url", - "url": "www.example_image.png", + "url": "https://www.example.com/image.png", }, }, ], @@ -239,7 +242,7 @@ async def asyncSetUp(self) -> None: "type": "image", "source": { "type": "url", - "url": "www.example_image.png", + "url": "https://www.example.com/image.png", }, }, { @@ -366,7 +369,7 @@ async def asyncSetUp(self) -> None: "type": "image", "source": { "type": "url", - "url": "www.example_image.png", + "url": "https://www.example.com/image.png", }, }, { @@ -599,3 +602,76 @@ async def test_multiagent_formater(self) -> None: res, self.ground_truth_multiagent_without_first_conversation[1:], ) + + async def test_chat_local_image_to_base64(self) -> None: + """Local image URL in user message and tool result should be + converted to base64.""" + formatter = AnthropicChatFormatter() + msgs = [ + Msg( + "user", + [ + TextBlock(type="text", text="Describe the image."), + ImageBlock( + type="image", + source=URLSource(type="url", url=self.image_path), + ), + ], + "user", + ), + Msg( + "assistant", + [ + ToolUseBlock( + type="tool_use", + id="tool_1", + name="view_image", + input={"path": "/tmp/img.png"}, + ), + ], + "assistant", + ), + Msg( + "system", + [ + ToolResultBlock( + type="tool_result", + id="tool_1", + name="view_image", + output=[ + ImageBlock( + type="image", + source=URLSource( + type="url", + url=self.image_path, + ), + ), + TextBlock(type="text", text="Image loaded"), + ], + ), + ], + "system", + ), + ] + res = await formatter.format(msgs) + + # User message: local image should be base64 + user_img = res[0]["content"][1] + self.assertEqual(user_img["source"]["type"], "base64") + self.assertEqual(user_img["source"]["media_type"], "image/png") + self.assertEqual( + user_img["source"]["data"], + "ZmFrZSBpbWFnZSBjb250ZW50", + ) + + # Tool result: local image should also be base64 + tool_result_content = res[2]["content"][0]["content"] + img_items = [ + b for b in tool_result_content if b.get("type") == "image" + ] + self.assertEqual(len(img_items), 1) + self.assertEqual(img_items[0]["source"]["type"], "base64") + self.assertEqual( + img_items[0]["source"]["data"], + "ZmFrZSBpbWFnZSBjb250ZW50", + ) From eecf8f7f92afae88ebed66a8b57909e663c49d3f Mon Sep 17 00:00:00 2001 From: Octopus Date: Mon, 20 Apr 2026 18:20:19 +0800 Subject: [PATCH 06/11] fix(deepresearchagent): use underscore naming for Tavily tool functions in DeepResearch agent (#1400) --- examples/agent/deep_research_agent/deep_research_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/agent/deep_research_agent/deep_research_agent.py b/examples/agent/deep_research_agent/deep_research_agent.py index c9f3d4bf9c..55cbf10bec 100644 --- a/examples/agent/deep_research_agent/deep_research_agent.py +++ b/examples/agent/deep_research_agent/deep_research_agent.py @@ -166,8 +166,8 @@ def __init__( self._search_mcp_client = search_mcp_client self._mcp_initialized = False - self.search_function = "tavily-search" - self.extract_function = "tavily-extract" + self.search_function = "tavily_search" + self.extract_function = "tavily_extract" self.read_file_function = "view_text_file" self.write_file_function = "write_text_file" self.summarize_function = "summarize_intermediate_results" From 77b306aa70e39f61d23e8ee2576fb3b0f0e6445d Mon Sep 17 00:00:00 2001 From: qbc Date: Mon, 20 Apr 2026 18:20:43 +0800 Subject: [PATCH 07/11] feat(model): refine counting input tokens for anthropic chat model (#1319) --- src/agentscope/model/_anthropic_model.py | 11 +++++++++++ tests/model_anthropic_test.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/agentscope/model/_anthropic_model.py b/src/agentscope/model/_anthropic_model.py index 1c62fbe0c2..8c9e9d1deb 100644 --- a/src/agentscope/model/_anthropic_model.py +++ b/src/agentscope/model/_anthropic_model.py @@ -457,6 +457,17 @@ async def _parse_anthropic_stream_completion_response( elif event.type == "message_delta": if event.usage and usage: + # For some providers compatible with Anthropic's API + # (e.g., DashScope), the input tokens are contained in + # events where event.type == "message_delta", so we add + # this step to extract the value. + final_input_tokens = getattr( + event.usage, + "input_tokens", + 0, + ) + if final_input_tokens > usage.input_tokens: + usage.input_tokens = final_input_tokens usage.output_tokens = event.usage.output_tokens if (thinking_changed or content_changed) and usage: diff --git a/tests/model_anthropic_test.py b/tests/model_anthropic_test.py index 158bf3448c..4719a6c2d9 100644 --- a/tests/model_anthropic_test.py +++ b/tests/model_anthropic_test.py @@ -304,7 +304,7 @@ async def test_streaming_response_processing(self) -> None: ), AnthropicEventMock( "message_delta", - usage=Mock(output_tokens=5), + usage=Mock(spec=["output_tokens"], output_tokens=5), ), ] From 6a13b26ee937266febbd8358c3c6cc3002be66fa Mon Sep 17 00:00:00 2001 From: qbc Date: Mon, 20 Apr 2026 18:24:00 +0800 Subject: [PATCH 08/11] fix(formatter): convert local files to base64 data in dashscope formatter (#1253) --- examples/functionality/rag/multimodal_rag.py | 2 +- .../formatter/_dashscope_formatter.py | 70 +++++++++---------- tests/formatter_dashscope_test.py | 22 +++--- 3 files changed, 47 insertions(+), 47 deletions(-) diff --git a/examples/functionality/rag/multimodal_rag.py b/examples/functionality/rag/multimodal_rag.py index 3a53e1995b..04720a04c0 100644 --- a/examples/functionality/rag/multimodal_rag.py +++ b/examples/functionality/rag/multimodal_rag.py @@ -65,7 +65,7 @@ async def example_multimodal_rag() -> None: # Let's see if the agent has stored the retrieved document in its memory print("\nThe retrieved document stored in the agent's memory:") - content = (await agent.memory.get_memory())[-4].content + content = (await agent.memory.get_memory())[-2].content print(json.dumps(content, indent=2, ensure_ascii=False)) diff --git a/src/agentscope/formatter/_dashscope_formatter.py b/src/agentscope/formatter/_dashscope_formatter.py index 766876b99f..242f73f103 100644 --- a/src/agentscope/formatter/_dashscope_formatter.py +++ b/src/agentscope/formatter/_dashscope_formatter.py @@ -2,9 +2,11 @@ # pylint: disable=too-many-branches """The dashscope formatter module.""" +import base64 import json +import mimetypes import os.path -from typing import Any +from typing import Any, cast from ._truncated_formatter_base import TruncatedFormatterBase from .._logging import logger @@ -23,17 +25,17 @@ def _format_dashscope_media_block( - block: ImageBlock | AudioBlock, + block: ImageBlock | AudioBlock | VideoBlock, ) -> dict[str, str]: - """Format an image or audio block for DashScope API. + """Format an image, audio, or video block for DashScope API. Args: - block (`ImageBlock` | `AudioBlock`): - The image or audio block to format. + block (`ImageBlock | AudioBlock | VideoBlock`): + The media block to format. Returns: `dict[str, str]`: - A dictionary with "image" or "audio" key and the formatted URL or + A dictionary with the media type key and the formatted URL or data URI as value. Raises: @@ -43,9 +45,19 @@ def _format_dashscope_media_block( typ = block["type"] source = block["source"] if source["type"] == "url": - url = source["url"] + url = source["url"].removeprefix("file://") if _is_accessible_local_file(url): - return {typ: "file://" + os.path.abspath(url)} + abs_path = os.path.abspath(url) + media_type = mimetypes.guess_type(abs_path)[0] + if not media_type: + raise ValueError( + f"Cannot determine the media type of '{abs_path}'. " + "Please use a file with a recognized extension " + "(e.g., .png, .jpg, .mp3, .mp4).", + ) + with open(abs_path, "rb") as f: + base64_data = base64.b64encode(f.read()).decode("utf-8") + return {typ: f"data:{media_type};base64,{base64_data}"} else: # treat as web url return {typ: url} @@ -266,7 +278,10 @@ async def _format( elif typ in ["image", "audio", "video"]: content_blocks.append( _format_dashscope_media_block( - block, # type: ignore[arg-type] + cast( + ImageBlock | AudioBlock | VideoBlock, + block, + ), ), ) @@ -564,35 +579,14 @@ async def _format_agent_message( ) accumulated_text.clear() - if block["source"]["type"] == "url": - url = block["source"]["url"] - if _is_accessible_local_file(url): - conversation_blocks.append( - { - block["type"]: "file://" - + os.path.abspath(url), - }, - ) - else: - conversation_blocks.append({block["type"]: url}) - - elif block["source"]["type"] == "base64": - media_type = block["source"]["media_type"] - base64_data = block["source"]["data"] - conversation_blocks.append( - { - block[ - "type" - ]: f"data:{media_type};base64,{base64_data}", - }, - ) - - else: - logger.warning( - "Unsupported block type %s in the message, " - "skipped.", - block["type"], - ) + conversation_blocks.append( + _format_dashscope_media_block( + cast( + ImageBlock | AudioBlock | VideoBlock, + block, + ), + ), + ) if accumulated_text: conversation_blocks.append({"text": "\n".join(accumulated_text)}) diff --git a/tests/formatter_dashscope_test.py b/tests/formatter_dashscope_test.py index e52c8b57cb..fd3156afa0 100644 --- a/tests/formatter_dashscope_test.py +++ b/tests/formatter_dashscope_test.py @@ -27,8 +27,14 @@ class TestDashScopeFormatter(IsolatedAsyncioTestCase): async def asyncSetUp(self) -> None: """Set up the test environment.""" self.image_path = "./image.png" + self.image_content = b"fake image content" with open(self.image_path, "wb") as f: - f.write(b"fake image content") + f.write(self.image_content) + + import base64 + + b64 = base64.b64encode(self.image_content).decode("utf-8") + self.image_data_uri = f"data:image/png;base64,{b64}" self.mock_audio_path = ( "/var/folders/gf/krg8x_ws409cpw_46b2s6rjc0000gn/T/tmpfymnv2w9.wav" @@ -228,7 +234,7 @@ async def asyncSetUp(self) -> None: "text": "What is the capital of France?", }, { - "image": f"file://{os.path.abspath(self.image_path)}", + "image": self.image_data_uri, }, ], }, @@ -301,7 +307,7 @@ async def asyncSetUp(self) -> None: " is the capital of France?", }, { - "image": f"file://{os.path.abspath(self.image_path)}", + "image": self.image_data_uri, }, { "text": "assistant: The capital of France is Paris." @@ -403,7 +409,7 @@ async def asyncSetUp(self) -> None: "France?", }, { - "image": f"file://{os.path.abspath(self.image_path)}", + "image": self.image_data_uri, }, { "text": "assistant: The capital of France is Paris." @@ -583,7 +589,7 @@ async def test_chat_formatter_with_extract_media_blocks( "text": "What is the capital of France?", }, { - "image": f"file://{os.path.abspath(self.image_path)}", + "image": self.image_data_uri, }, ], }, @@ -647,7 +653,7 @@ async def test_chat_formatter_with_extract_media_blocks( "text": "\n- The image from './image.png': ", }, { - "image": f"file://{os.path.abspath(self.image_path)}", + "image": self.image_data_uri, }, { "text": "\n- The audio from " @@ -818,7 +824,7 @@ async def test_multiagent_formatter_with_promote_media_tool_result( " is the capital of France?", }, { - "image": f"file://{os.path.abspath(self.image_path)}", + "image": self.image_data_uri, }, { "text": "assistant: The capital of France is Paris." @@ -871,7 +877,7 @@ async def test_multiagent_formatter_with_promote_media_tool_result( "text": "\n- The image from './image.png': ", }, { - "image": f"file://{os.path.abspath(self.image_path)}", + "image": self.image_data_uri, }, { "text": "\n- The audio from " From 5d09848e07d95033c68dacc745e2fadbacf5804a Mon Sep 17 00:00:00 2001 From: YingchaoX Date: Mon, 20 Apr 2026 19:13:07 +0800 Subject: [PATCH 09/11] feat(deepresearchagent): guard DeepResearchAgent tool metadata access (#1489) --- .../deep_research_agent/deep_research_agent.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/agent/deep_research_agent/deep_research_agent.py b/examples/agent/deep_research_agent/deep_research_agent.py index 55cbf10bec..363e58a635 100644 --- a/examples/agent/deep_research_agent/deep_research_agent.py +++ b/examples/agent/deep_research_agent/deep_research_agent.py @@ -348,6 +348,10 @@ async def _acting(self, tool_call: ToolUseBlock) -> Msg | None: # Async generator handling async for chunk in tool_res: + chunk_metadata = ( + chunk.metadata if isinstance(chunk.metadata, dict) else {} + ) + # Turn into a tool result block tool_res_msg.content[0][ # type: ignore[index] "output" @@ -357,19 +361,19 @@ async def _acting(self, tool_call: ToolUseBlock) -> Msg | None: if ( tool_call["name"] != self.finish_function_name or tool_call["name"] == self.finish_function_name - and not chunk.metadata.get("success") + and not chunk_metadata.get("success") ): await self.print(tool_res_msg, chunk.is_last) # Return message if generate_response is called successfully if tool_call[ "name" - ] == self.finish_function_name and chunk.metadata.get( + ] == self.finish_function_name and chunk_metadata.get( "success", True, ): if len(self.current_subtask) == 0: - return chunk.metadata.get("response_msg") + return chunk_metadata.get("response_msg") # Summarize intermediate results into a draft report elif tool_call["name"] == self.summarize_function: @@ -398,11 +402,11 @@ async def _acting(self, tool_call: ToolUseBlock) -> Msg | None: ) # Update memory when an intermediate report is generated - if isinstance(chunk.metadata, dict) and chunk.metadata.get( + if chunk_metadata.get( "update_memory", ): update_memory = True - intermediate_report = chunk.metadata.get( + intermediate_report = chunk_metadata.get( "intermediate_report", ) return None From 2a92996cb4f52e765ec017aa7a9dfe8e177ed7a3 Mon Sep 17 00:00:00 2001 From: qbc Date: Mon, 20 Apr 2026 19:37:59 +0800 Subject: [PATCH 10/11] fix(unittest): hotfix for unittest (#1507) --- tests/model_anthropic_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/model_anthropic_test.py b/tests/model_anthropic_test.py index 4719a6c2d9..855135f18f 100644 --- a/tests/model_anthropic_test.py +++ b/tests/model_anthropic_test.py @@ -371,7 +371,7 @@ async def test_streaming_tool_input_prefers_valid_final_json(self) -> None: ), AnthropicEventMock( "message_delta", - usage=Mock(output_tokens=5), + usage=Mock(spec=["output_tokens"], output_tokens=5), ), ] From eb7678e1cb958125c852027c1a91480a03d7f870 Mon Sep 17 00:00:00 2001 From: qbc Date: Mon, 20 Apr 2026 20:44:54 +0800 Subject: [PATCH 11/11] chore(version): bumping verstion to 1.0.19 (#1506) --- src/agentscope/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentscope/_version.py b/src/agentscope/_version.py index ece5932720..6deda02d50 100644 --- a/src/agentscope/_version.py +++ b/src/agentscope/_version.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- """The version of agentscope.""" -__version__ = "1.0.19dev" +__version__ = "1.0.19"