Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions examples/agent/deep_research_agent/deep_research_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def __init__(
self._search_mcp_client = search_mcp_client
self._mcp_initialized = False

self.search_function = "tavily-search"
self.extract_function = "tavily-extract"
self.search_function = "tavily_search"
self.extract_function = "tavily_extract"
self.read_file_function = "view_text_file"
self.write_file_function = "write_text_file"
self.summarize_function = "summarize_intermediate_results"
Expand Down Expand Up @@ -348,6 +348,10 @@ async def _acting(self, tool_call: ToolUseBlock) -> Msg | None:

# Async generator handling
async for chunk in tool_res:
chunk_metadata = (
chunk.metadata if isinstance(chunk.metadata, dict) else {}
)

# Turn into a tool result block
tool_res_msg.content[0][ # type: ignore[index]
"output"
Expand All @@ -357,19 +361,19 @@ async def _acting(self, tool_call: ToolUseBlock) -> Msg | None:
if (
tool_call["name"] != self.finish_function_name
or tool_call["name"] == self.finish_function_name
and not chunk.metadata.get("success")
and not chunk_metadata.get("success")
):
await self.print(tool_res_msg, chunk.is_last)

# Return message if generate_response is called successfully
if tool_call[
"name"
] == self.finish_function_name and chunk.metadata.get(
] == self.finish_function_name and chunk_metadata.get(
"success",
True,
):
if len(self.current_subtask) == 0:
return chunk.metadata.get("response_msg")
return chunk_metadata.get("response_msg")

# Summarize intermediate results into a draft report
elif tool_call["name"] == self.summarize_function:
Expand Down Expand Up @@ -398,11 +402,11 @@ async def _acting(self, tool_call: ToolUseBlock) -> Msg | None:
)

# Update memory when an intermediate report is generated
if isinstance(chunk.metadata, dict) and chunk.metadata.get(
if chunk_metadata.get(
"update_memory",
):
update_memory = True
intermediate_report = chunk.metadata.get(
intermediate_report = chunk_metadata.get(
"intermediate_report",
)
return None
Expand Down
2 changes: 1 addition & 1 deletion examples/functionality/rag/multimodal_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async def example_multimodal_rag() -> None:

# Let's see if the agent has stored the retrieved document in its memory
print("\nThe retrieved document stored in the agent's memory:")
content = (await agent.memory.get_memory())[-4].content
content = (await agent.memory.get_memory())[-2].content
print(json.dumps(content, indent=2, ensure_ascii=False))


Expand Down
2 changes: 1 addition & 1 deletion src/agentscope/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""The version of agentscope."""

__version__ = "1.0.19dev"
__version__ = "1.0.19"
24 changes: 18 additions & 6 deletions src/agentscope/agent/_agent_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def _wrap_with_hooks(
"""
func_name = original_func.__name__.replace("_", "")

hook_guard_attr = f"_hook_running_{func_name}"

@wraps(original_func)
async def async_wrapper(
self: AgentBase,
Expand All @@ -72,6 +74,12 @@ async def async_wrapper(
"""The wrapped function, which call the pre- and post-hooks before and
after the original function."""

# Guard against re-entrant hook execution when multiple classes
# in the MRO define the same method (each wrapped independently
# by the metaclass). Only the outermost wrapper runs hooks.
if getattr(self, hook_guard_attr, False):
return await original_func(self, *args, **kwargs)

# Unify all positional and keyword arguments into keyword arguments
normalized_kwargs = _normalize_to_kwargs(
original_func,
Expand Down Expand Up @@ -117,12 +125,16 @@ async def async_wrapper(
for k, v in current_normalized_kwargs.items()
if k not in ["args", "kwargs"]
}
current_output = await original_func(
self,
*args,
**others,
**kwargs,
)
setattr(self, hook_guard_attr, True)
try:
current_output = await original_func(
self,
*args,
**others,
**kwargs,
)
finally:
setattr(self, hook_guard_attr, False)

# post_hooks
post_hooks = list(
Expand Down
109 changes: 105 additions & 4 deletions src/agentscope/formatter/_anthropic_formatter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,100 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
# pylint: disable=too-many-branches, too-many-nested-blocks
"""The Anthropic formatter module."""

import base64
import os
from typing import Any
from urllib.parse import urlparse

from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from ..message import Msg, TextBlock, ImageBlock, ToolUseBlock, ToolResultBlock
from ..token import TokenCounterBase


def _format_anthropic_image_block(image_block: ImageBlock) -> dict:
    """Format an image block for the Anthropic API.

    Three kinds of source are handled:

    - A base64 source is returned as a shallow copy of the input block.
    - A URL source that points to an existing local file (with or without
      a leading ``file://``) is read from disk and converted into a base64
      source.
    - Any other URL whose scheme is neither empty nor ``file`` is passed
      through unchanged as a URL source.

    Args:
        image_block (`ImageBlock`):
            The image block to format.

    Returns:
        `dict`:
            A dictionary in Anthropic image block format.

    Raises:
        `ValueError`:
            If the URL is neither a local image file nor a web URL.
    """
    # Known file extensions mapped to the media type sent alongside the
    # base64 payload (presumably the set of image types the Anthropic API
    # accepts -- see the Anthropic vision documentation).
    support_image_extensions = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }

    source = image_block["source"]

    if source["type"] == "base64":
        return {**image_block}

    url = source["url"]
    raw_url = url.removeprefix("file://")

    # `isfile` is already False for paths that do not exist.
    if os.path.isfile(raw_url):
        ext = os.path.splitext(raw_url)[1].lower()
        media_type = support_image_extensions.get(ext)

        if media_type is None:
            # Missing or unrecognised extension: sniff the file content.
            # Imported lazily so that base64 sources, web URLs and files
            # with a known extension do not need the optional dependency.
            import filetype

            kind = filetype.guess(raw_url)
            if kind is not None and kind.mime.startswith("image/"):
                media_type = kind.mime

        if media_type is not None:
            with open(raw_url, "rb") as image_file:
                data = base64.b64encode(image_file.read()).decode("utf-8")
            return {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": media_type,
                    "data": data,
                },
            }

    # For web urls (a local file that is not an image has an empty scheme
    # here and therefore falls through to the error below)
    parsed_url = urlparse(raw_url)
    if parsed_url.scheme not in ("", "file"):
        return {
            "type": "image",
            "source": {
                "type": "url",
                "url": url,
            },
        }

    raise ValueError(
        f'Invalid image URL: "{url}". It should be a local file or a web URL.',
    )


class AnthropicChatFormatter(TruncatedFormatterBase):
"""The Anthropic formatter class for chatbot scenario, where only a user
and an agent are involved. We use the `role` field to identify different
Expand Down Expand Up @@ -63,9 +148,16 @@ async def _format(

for block in msg.get_content_blocks():
typ = block.get("type")
if typ in ["thinking", "text", "image"]:
if typ in ["thinking", "text"]:
content_blocks.append({**block})

elif typ == "image":
content_blocks.append(
_format_anthropic_image_block(
block, # type: ignore[arg-type]
),
)

elif typ == "tool_use":
content_blocks.append(
{
Expand All @@ -81,7 +173,12 @@ async def _format(
if output is None:
content_value = [{"type": "text", "text": None}]
elif isinstance(output, list):
content_value = output
content_value = [
_format_anthropic_image_block(item)
if item.get("type") == "image"
else item
for item in output
]
else:
content_value = [{"type": "text", "text": str(output)}]
messages.append(
Expand Down Expand Up @@ -207,7 +304,11 @@ async def _format_agent_message(
)
accumulated_text.clear()

conversation_blocks.append({**block})
conversation_blocks.append(
_format_anthropic_image_block(
block, # type: ignore[arg-type]
),
)

if accumulated_text:
conversation_blocks.append(
Expand Down
70 changes: 32 additions & 38 deletions src/agentscope/formatter/_dashscope_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
# pylint: disable=too-many-branches
"""The dashscope formatter module."""

import base64
import json
import mimetypes
import os.path
from typing import Any
from typing import Any, cast

from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
Expand All @@ -23,17 +25,17 @@


def _format_dashscope_media_block(
block: ImageBlock | AudioBlock,
block: ImageBlock | AudioBlock | VideoBlock,
) -> dict[str, str]:
"""Format an image or audio block for DashScope API.
"""Format an image, audio, or video block for DashScope API.

Args:
block (`ImageBlock` | `AudioBlock`):
The image or audio block to format.
block (`ImageBlock | AudioBlock | VideoBlock`):
The media block to format.

Returns:
`dict[str, str]`:
A dictionary with "image" or "audio" key and the formatted URL or
A dictionary with the media type key and the formatted URL or
data URI as value.

Raises:
Expand All @@ -43,9 +45,19 @@ def _format_dashscope_media_block(
typ = block["type"]
source = block["source"]
if source["type"] == "url":
url = source["url"]
url = source["url"].removeprefix("file://")
if _is_accessible_local_file(url):
return {typ: "file://" + os.path.abspath(url)}
abs_path = os.path.abspath(url)
media_type = mimetypes.guess_type(abs_path)[0]
if not media_type:
raise ValueError(
f"Cannot determine the media type of '{abs_path}'. "
"Please use a file with a recognized extension "
"(e.g., .png, .jpg, .mp3, .mp4).",
)
with open(abs_path, "rb") as f:
base64_data = base64.b64encode(f.read()).decode("utf-8")
return {typ: f"data:{media_type};base64,{base64_data}"}
else:
# treat as web url
return {typ: url}
Expand Down Expand Up @@ -266,7 +278,10 @@ async def _format(
elif typ in ["image", "audio", "video"]:
content_blocks.append(
_format_dashscope_media_block(
block, # type: ignore[arg-type]
cast(
ImageBlock | AudioBlock | VideoBlock,
block,
),
),
)

Expand Down Expand Up @@ -564,35 +579,14 @@ async def _format_agent_message(
)
accumulated_text.clear()

if block["source"]["type"] == "url":
url = block["source"]["url"]
if _is_accessible_local_file(url):
conversation_blocks.append(
{
block["type"]: "file://"
+ os.path.abspath(url),
},
)
else:
conversation_blocks.append({block["type"]: url})

elif block["source"]["type"] == "base64":
media_type = block["source"]["media_type"]
base64_data = block["source"]["data"]
conversation_blocks.append(
{
block[
"type"
]: f"data:{media_type};base64,{base64_data}",
},
)

else:
logger.warning(
"Unsupported block type %s in the message, "
"skipped.",
block["type"],
)
conversation_blocks.append(
_format_dashscope_media_block(
cast(
ImageBlock | AudioBlock | VideoBlock,
block,
),
),
)

if accumulated_text:
conversation_blocks.append({"text": "\n".join(accumulated_text)})
Expand Down
Loading
Loading