Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 66 additions & 49 deletions tests/integration/test_end_to_end.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from collections.abc import AsyncGenerator
from typing import NamedTuple, cast
from typing import NamedTuple

import grpc
import httpx
import pytest
import pytest_asyncio

from a2a.client.base_client import BaseClient
from a2a.client.client import Client, ClientConfig
from a2a.client.client import ClientConfig
from a2a.client.client_factory import ClientFactory
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication
Expand All @@ -26,7 +26,6 @@
Part,
Role,
SendMessageConfiguration,
SendMessageRequest,
TaskState,
a2a_pb2_grpc,
)
Expand All @@ -42,6 +41,9 @@ async def execute(self, context: RequestContext, event_queue: EventQueue):
)
await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED)
await task_updater.update_status(TaskState.TASK_STATE_WORKING)
await task_updater.add_artifact(
parts=[Part(text='artifact content')], name='test-artifact'
)
await task_updater.update_status(
TaskState.TASK_STATE_COMPLETED,
message=task_updater.new_agent_message([Part(text='done')]),
Expand Down Expand Up @@ -80,7 +82,7 @@ def agent_card() -> AgentCard:
)


class TransportSetup(NamedTuple):
class ClientSetup(NamedTuple):
"""Holds the client and task_store for a given test."""

client: BaseClient
Expand All @@ -99,7 +101,7 @@ def base_e2e_setup():


@pytest.fixture
def rest_setup(agent_card, base_e2e_setup) -> TransportSetup:
def rest_setup(agent_card, base_e2e_setup) -> ClientSetup:
task_store, handler = base_e2e_setup
app_builder = A2ARESTFastAPIApplication(agent_card, handler)
app = app_builder.build()
Expand All @@ -112,15 +114,15 @@ def rest_setup(agent_card, base_e2e_setup) -> TransportSetup:
supported_protocol_bindings=[TransportProtocol.HTTP_JSON],
)
)
client = cast(BaseClient, factory.create(agent_card))
return TransportSetup(
client = factory.create(agent_card)
return ClientSetup(
client=client,
task_store=task_store,
)


@pytest.fixture
def jsonrpc_setup(agent_card, base_e2e_setup) -> TransportSetup:
def jsonrpc_setup(agent_card, base_e2e_setup) -> ClientSetup:
task_store, handler = base_e2e_setup
app_builder = A2AFastAPIApplication(
agent_card, handler, extended_agent_card=agent_card
Expand All @@ -135,8 +137,8 @@ def jsonrpc_setup(agent_card, base_e2e_setup) -> TransportSetup:
supported_protocol_bindings=[TransportProtocol.JSONRPC],
)
)
client = cast(BaseClient, factory.create(agent_card))
return TransportSetup(
client = factory.create(agent_card)
return ClientSetup(
client=client,
task_store=task_store,
)
Expand All @@ -145,7 +147,7 @@ def jsonrpc_setup(agent_card, base_e2e_setup) -> TransportSetup:
@pytest_asyncio.fixture
async def grpc_setup(
agent_card: AgentCard, base_e2e_setup
) -> AsyncGenerator[TransportSetup, None]:
) -> AsyncGenerator[ClientSetup, None]:
task_store, handler = base_e2e_setup
server = grpc.aio.server()
port = server.add_insecure_port('[::]:0')
Expand All @@ -168,12 +170,12 @@ async def grpc_setup(

factory = ClientFactory(
config=ClientConfig(
grpc_channel_factory=lambda url: grpc.aio.insecure_channel(url),
grpc_channel_factory=grpc.aio.insecure_channel,
supported_protocol_bindings=[TransportProtocol.GRPC],
)
)
client = cast(BaseClient, factory.create(grpc_agent_card))
yield TransportSetup(
client = factory.create(grpc_agent_card)
yield ClientSetup(
client=client,
task_store=task_store,
)
Expand All @@ -189,14 +191,15 @@ async def grpc_setup(
pytest.param('grpc_setup', id='gRPC'),
]
)
def transport_setups(request) -> TransportSetup:
def transport_setups(request) -> ClientSetup:
"""Parametrized fixture that runs tests against all supported transports."""
return request.getfixturevalue(request.param)


@pytest.mark.asyncio
async def test_end_to_end_send_message_blocking(transport_setups):
client = transport_setups.client
client._config.streaming = False

message_to_send = Message(
role=Role.ROLE_USER,
Expand All @@ -211,16 +214,19 @@ async def test_end_to_end_send_message_blocking(transport_setups):
request=message_to_send, configuration=configuration
)
]
response, task = events[-1]

assert task
assert task.id
assert task.status.state == TaskState.TASK_STATE_COMPLETED
assert len(events) == 1
response, _ = events[0]
assert response.task.id
assert response.task.status.state == TaskState.TASK_STATE_COMPLETED
assert len(response.task.artifacts) == 1
assert response.task.artifacts[0].name == 'test-artifact'
assert response.task.artifacts[0].parts[0].text == 'artifact content'
Comment thread
sokoliva marked this conversation as resolved.


@pytest.mark.asyncio
async def test_end_to_end_send_message_non_blocking(transport_setups):
client = transport_setups.client
client._config.streaming = False

message_to_send = Message(
role=Role.ROLE_USER,
Expand All @@ -235,10 +241,10 @@ async def test_end_to_end_send_message_non_blocking(transport_setups):
request=message_to_send, configuration=configuration
)
]
response, task = events[-1]

assert task
assert task.id
assert len(events) == 1
response, _ = events[0]
assert response.task.id
assert response.task.status.state == TaskState.TASK_STATE_SUBMITTED


@pytest.mark.asyncio
Expand All @@ -252,20 +258,29 @@ async def test_end_to_end_send_message_streaming(transport_setups):
)

events = [
event async for event in client.send_message(request=message_to_send)
event async for event, _ in client.send_message(request=message_to_send)
]

assert len(events) > 0
stream_response, task = events[-1]
expected_events = [
('status_update', TaskState.TASK_STATE_SUBMITTED),
('status_update', TaskState.TASK_STATE_WORKING),
('artifact_update', None),
('status_update', TaskState.TASK_STATE_COMPLETED),
]

assert stream_response.HasField('status_update')
assert stream_response.status_update.task_id
assert (
stream_response.status_update.status.state
== TaskState.TASK_STATE_COMPLETED
)
assert task
assert task.status.state == TaskState.TASK_STATE_COMPLETED
assert len(events) == len(expected_events)
for event, (expected_type, expected_state) in zip(
events, expected_events, strict=True
):
assert event.HasField(expected_type)
if expected_type == 'status_update':
assert event.status_update.status.state == expected_state
elif expected_type == 'artifact_update':
assert event.artifact_update.artifact.name == 'test-artifact'
assert (
event.artifact_update.artifact.parts[0].text
== 'artifact content'
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -301,21 +316,23 @@ async def test_end_to_end_list_tasks(transport_setups):
total_items = 6
page_size = 2

expected_task_ids = []
for i in range(total_items):
# We need to await the iterator to ensure request completes
async for _ in client.send_message(
request=Message(
role=Role.ROLE_USER,
message_id=f'msg-e2e-list-{i}',
parts=[Part(text=f'Test List Tasks {i}')],
),
configuration=SendMessageConfiguration(blocking=False),
):
pass
# One event is enough to get the task ID
_, task = await anext(
client.send_message(
request=Message(
role=Role.ROLE_USER,
message_id=f'msg-e2e-list-{i}',
parts=[Part(text=f'Test List Tasks {i}')],
)
)
)
expected_task_ids.append(task.id)

list_request = ListTasksRequest(page_size=page_size)

unique_task_ids = set()
actual_task_ids = []
token = None

while token != '':
Expand All @@ -327,9 +344,9 @@ async def test_end_to_end_list_tasks(transport_setups):
assert list_response.total_size == total_items
assert list_response.page_size == page_size

for task in list_response.tasks:
unique_task_ids.add(task.id)
actual_task_ids.extend([task.id for task in list_response.tasks])

token = list_response.next_page_token

assert len(unique_task_ids) == total_items
assert len(actual_task_ids) == total_items
assert sorted(actual_task_ids) == sorted(expected_task_ids)
Loading