diff --git a/vertexai/_genai/_agent_engines_utils.py b/vertexai/_genai/_agent_engines_utils.py index 92d91addbb..76de09640e 100644 --- a/vertexai/_genai/_agent_engines_utils.py +++ b/vertexai/_genai/_agent_engines_utils.py @@ -110,6 +110,26 @@ try: + from a2a.utils.constants import TransportProtocol as _A2aVersionTest + + _A2A_SDK_VERSION: Optional[str] = "1.0" +except ImportError: + try: + from a2a.types import TransportProtocol as _A2aVersionTest + + _A2A_SDK_VERSION = "0.3" + except ImportError: + _A2A_SDK_VERSION = None + +if _A2A_SDK_VERSION == "1.0": + from a2a.types import ( + AgentCard, + Message, + ) + from a2a.client import ClientConfig, ClientFactory + from a2a.utils.constants import TransportProtocol + from a2a.compat.v0_3.types import TaskIdParams, TaskQueryParams +elif _A2A_SDK_VERSION == "0.3": from a2a.types import ( AgentCard, TransportProtocol, @@ -118,15 +138,7 @@ TaskQueryParams, ) from a2a.client import ClientConfig, ClientFactory - - AgentCard = AgentCard - TransportProtocol = TransportProtocol - Message = Message - ClientConfig = ClientConfig - ClientFactory = ClientFactory - TaskIdParams = TaskIdParams - TaskQueryParams = TaskQueryParams -except (ImportError, AttributeError): +else: AgentCard = None TransportProtocol = None Message = None @@ -134,6 +146,10 @@ ClientFactory = None TaskIdParams = None TaskQueryParams = None + SendMessageRequest = None + GetTaskRequest = None + CancelTaskRequest = None + GetExtendedAgentCardRequest = None _ACTIONS_KEY = "actions" _ACTION_APPEND = "append" @@ -1737,7 +1753,9 @@ async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any] return _method -def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]: +def _wrap_a2a_operation_v03( + method_name: str, agent_card: str +) -> Callable[..., list[Any]]: """Wraps an Agent Engine method, creating a callable for A2A API. Args: @@ -1854,6 +1872,134 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def] return _method # type: ignore[return-value] +def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]: + """Wraps an Agent Engine method, creating a callable for A2A API (v1.0.0+). + + Args: + method_name: The name of the Agent Engine method to call. + agent_card: The agent card JSON string to use for the A2A API call. + Example: + { + 'name': 'Sample Agent', + 'description': ( + 'A helpful assistant agent that can answer questions.' + ), + 'supportedInterfaces': [{ + 'url': 'http://localhost:8080/a2a/rest/', + 'protocolBinding': 'HTTP+JSON', + 'protocolVersion': '1.0', + }], + 'version': '1.0.0', + 'capabilities': { + 'streaming': True, + 'pushNotifications': False, + 'extendedAgentCard': True, + }, + 'defaultInputModes': ['text'], + 'defaultOutputModes': ['text'], + 'skills': [{ + 'id': 'question_answer', + 'name': 'Q&A Agent', + 'description': ( + 'A helpful assistant agent that can answer questions.' + ), + 'tags': ['Question-Answer'], + 'examples': [ + 'Who is leading 2025 F1 Standings?', + 'Where can i find an active volcano?', + ], + 'inputModes': ['text'], + 'outputModes': ['text'], + }], + } + + Returns: + A callable object that executes the method on the Agent Engine via + the A2A API. + """ + + async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def] + if not self.api_client: + raise ValueError("api_client is not initialized.") + if not self.api_resource: + raise ValueError("api_resource is not initialized.") + + a2a_agent_card = AgentCard() + json_format.ParseDict( + json.loads(agent_card), a2a_agent_card, ignore_unknown_fields=True + ) + + if a2a_agent_card.supported_interfaces: + interface = a2a_agent_card.supported_interfaces[0] + if interface.protocol_binding != TransportProtocol.HTTP_JSON: + raise ValueError( + "Only HTTP+JSON is supported for preferred transport on agent card" + ) + else: + raise ValueError("Agent card does not define any supported interfaces.") + + base_url = self.api_client._api_client._http_options.base_url.rstrip("/") + api_version = self.api_client._api_client._http_options.api_version + a2a_agent_card.supported_interfaces[0].url = ( + f"{base_url}/{api_version}/{self.api_resource.name}/a2a" + ) + + config = ClientConfig( + supported_protocol_bindings=[ + TransportProtocol.HTTP_JSON, + ], + use_client_preference=True, + httpx_client=httpx.AsyncClient( + headers={ + "Authorization": ( + f"Bearer {self.api_client._api_client._credentials.token}" + ) + }, + timeout=( + self.api_client._api_client._http_options.timeout / 1000.0 + if self.api_client._api_client._http_options.timeout + else None + ), + ), + ) + factory = ClientFactory(config) + client = factory.create(a2a_agent_card) + + context = kwargs.pop("context", None) + if context is not None: + from a2a.client.client import ClientCallContext + + if not isinstance(context, ClientCallContext): + actual_context = ClientCallContext() + if hasattr(context, "state"): + actual_context.state = context.state + elif isinstance(context, dict): + actual_context.state = context + context = actual_context + + req = kwargs["request"] + if method_name == "on_message_send": + response = client.send_message(req, context=context) + chunks = [] + async for chunk in response: + chunks.append(chunk) + return chunks + elif method_name == "on_get_task": + return await client.get_task(req, context=context) + elif method_name == "on_cancel_task": + return await client.cancel_task(req, context=context) + elif method_name == "on_get_extended_agent_card": + return await client.get_extended_agent_card(req, context=context) + else: + raise ValueError(f"Unknown method name: {method_name}") + + return _method # type: ignore[return-value] + + +if _A2A_SDK_VERSION != "1.0": + _wrap_a2a_operation = _wrap_a2a_operation_v03 + + def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]: """Converts the body of the HTTP Response message to JSON format.