diff --git a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/A2aAgent.java b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/A2aAgent.java index 0c27f81a2..2fb2a1a7e 100644 --- a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/A2aAgent.java +++ b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/A2aAgent.java @@ -120,6 +120,7 @@ protected Mono doCall(List msgs) { LoggerUtil.debug(log, "[{}] A2aAgent call with input messages: ", currentRequestId); LoggerUtil.logTextMsgDetail(log, memory.getMessages()); clientEventContext.setHooks(getSortedHooks()); + clientEventContext.setInputMessages(memory.getMessages()); return Mono.defer( () -> { Message message = diff --git a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/ClientEventContext.java b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/ClientEventContext.java index acdf504d4..9c297b8a0 100644 --- a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/ClientEventContext.java +++ b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/ClientEventContext.java @@ -19,8 +19,13 @@ import io.a2a.spec.Task; import io.agentscope.core.a2a.agent.A2aAgent; import io.agentscope.core.hook.Hook; +import io.agentscope.core.hook.PostReasoningEvent; +import io.agentscope.core.hook.PreReasoningEvent; +import io.agentscope.core.hook.ReasoningChunkEvent; import io.agentscope.core.message.Msg; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; /** @@ -40,6 +45,16 @@ public class ClientEventContext { private Task task; + /** + * Temporarily store the complete historical dialogue context at the time of this call, + * specifically for use in constructing PreReasoning Events using the {@link #publishPreReasoning()} method. + */ + private List inputMessages; + + // Ensure that lifecycle events are triggered only once + private final AtomicBoolean preReasoningFired = new AtomicBoolean(false); + private final AtomicBoolean postReasoningFired = new AtomicBoolean(false); + public ClientEventContext(String currentRequestId, A2aAgent agent) { this.currentRequestId = currentRequestId; this.agent = agent; @@ -76,4 +91,70 @@ public Task getTask() { public void setTask(Task task) { this.task = task; } + + public void setInputMessages(List inputMessages) { + this.inputMessages = inputMessages; + } + + // ========================================== + // Unified Event Publishing API + // ========================================== + + /** + * Trigger PreReasoningEvent (triggered only once) + */ + void publishPreReasoning() { + if (hooks != null && !hooks.isEmpty() && preReasoningFired.compareAndSet(false, true)) { + List msgs = inputMessages == null ? List.of() : inputMessages; + PreReasoningEvent preEvent = new PreReasoningEvent(agent, "A2A", null, msgs); + + Mono eventMono = Mono.just(preEvent); + for (Hook hook : hooks) { + eventMono = eventMono.flatMap(hook::onEvent); + } + eventMono.block(); + } + } + + /** + * Trigger ReasoningChunkEvent (streaming process) + */ + void publishReasoningChunk(Msg chunkMsg) { + if (hooks != null && !hooks.isEmpty()) { + publishPreReasoning(); // If not sent Pre before, send Pre first + ReasoningChunkEvent chunkEvent = + new ReasoningChunkEvent(agent, "A2A", null, chunkMsg, chunkMsg); + + Mono eventMono = Mono.just(chunkEvent); + for (Hook hook : hooks) { + eventMono = eventMono.flatMap(hook::onEvent); + } + eventMono.block(); + } + } + + /** + * Trigger PostReasoningEvent (triggered only once) and return the final reasoning message + * after hooks have had a chance to modify it. + * + * @param finalMsg the original final reasoning message + * @return the hook-modified reasoning message, or {@code finalMsg} if no hooks ran or no + * modification was applied + */ + Msg publishPostReasoning(Msg finalMsg) { + if (hooks != null && !hooks.isEmpty() && postReasoningFired.compareAndSet(false, true)) { + publishPreReasoning(); + PostReasoningEvent postEvent = new PostReasoningEvent(agent, "A2A", null, finalMsg); + + Mono eventMono = Mono.just(postEvent); + for (Hook hook : hooks) { + eventMono = eventMono.flatMap(hook::onEvent); + } + postEvent = eventMono.block(); + + Msg modifiedMsg = postEvent.getReasoningMessage(); + return modifiedMsg != null ? modifiedMsg : finalMsg; + } + return finalMsg; + } } diff --git a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/MessageEventHandler.java b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/MessageEventHandler.java index c578f2c5b..09e212532 100644 --- a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/MessageEventHandler.java +++ b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/MessageEventHandler.java @@ -42,6 +42,10 @@ public void handle(MessageEvent event, ClientEventContext context) { Msg msg = MessageConvertUtil.convertFromMessage( event.getMessage(), context.getAgent().getName()); + + // Automatically trigger PreReasoningEvent and PostReasoningEvent + msg = context.publishPostReasoning(msg); + context.getSink().success(msg); LoggerUtil.info(log, "[{}] A2aAgent complete call.", currentRequestId); LoggerUtil.debug(log, "[{}] A2aAgent complete with artifact messages: ", currentRequestId); diff --git a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/TaskEventHandler.java b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/TaskEventHandler.java index 4f468d25b..98e88d64f 100644 --- a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/TaskEventHandler.java +++ b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/TaskEventHandler.java @@ -44,5 +44,7 @@ public void handle(TaskEvent event, ClientEventContext context) { context.getCurrentRequestId(), task.getId(), task.getStatus()); + + context.publishPreReasoning(); } } diff --git a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/TaskUpdateEventHandler.java b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/TaskUpdateEventHandler.java index 73f9b9164..5c0662376 100644 --- a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/TaskUpdateEventHandler.java +++ b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/main/java/io/agentscope/core/a2a/agent/event/TaskUpdateEventHandler.java @@ -23,7 +23,6 @@ import io.a2a.spec.UpdateEvent; import io.agentscope.core.a2a.agent.utils.LoggerUtil; import io.agentscope.core.a2a.agent.utils.MessageConvertUtil; -import io.agentscope.core.hook.ReasoningChunkEvent; import io.agentscope.core.message.Msg; import java.util.HashMap; import java.util.List; @@ -93,6 +92,9 @@ public void handle(TaskStatusUpdateEvent event, ClientEventContext context) { Msg msg = MessageConvertUtil.convertFromArtifact( context.getTask().getArtifacts(), context.getAgent().getName()); + + msg = context.publishPostReasoning(msg); + context.getSink().success(msg); LoggerUtil.info(log, "[{}] A2aAgent complete call.", currentRequestId); LoggerUtil.debug( @@ -114,9 +116,8 @@ public void handle(TaskStatusUpdateEvent event, ClientEventContext context) { LoggerUtil.debug( log, "[{}] A2aAgent task status updated with messages: ", currentRequestId); LoggerUtil.logTextMsgDetail(log, List.of(msg)); - ReasoningChunkEvent chunkEvent = - new ReasoningChunkEvent(context.getAgent(), "A2A", null, msg, msg); - context.getHooks().forEach(hook -> hook.onEvent(chunkEvent).block()); + + context.publishReasoningChunk(msg); } } } @@ -136,9 +137,8 @@ public void handle(TaskArtifactUpdateEvent event, ClientEventContext context) { LoggerUtil.debug( log, "[{}] A2aAgent artifact append with messages: ", currentRequestTaskId); LoggerUtil.logTextMsgDetail(log, List.of(msg)); - ReasoningChunkEvent chunkEvent = - new ReasoningChunkEvent(context.getAgent(), "A2A", null, msg, msg); - context.getHooks().forEach(hook -> hook.onEvent(chunkEvent).block()); + + context.publishReasoningChunk(msg); } } } diff --git a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/test/java/io/agentscope/core/a2a/agent/A2aAgentTest.java b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/test/java/io/agentscope/core/a2a/agent/A2aAgentTest.java index ac3e71897..21317af56 100644 --- a/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/test/java/io/agentscope/core/a2a/agent/A2aAgentTest.java +++ b/agentscope-extensions/agentscope-extensions-a2a/agentscope-extensions-a2a-client/src/test/java/io/agentscope/core/a2a/agent/A2aAgentTest.java @@ -58,7 +58,10 @@ import io.agentscope.core.agent.Event; import io.agentscope.core.hook.Hook; import io.agentscope.core.hook.HookEvent; +import io.agentscope.core.hook.PostReasoningEvent; import io.agentscope.core.hook.PreCallEvent; +import io.agentscope.core.hook.PreReasoningEvent; +import io.agentscope.core.hook.ReasoningChunkEvent; import io.agentscope.core.message.Msg; import java.lang.reflect.Field; import java.util.HashMap; @@ -69,6 +72,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; @@ -282,9 +286,10 @@ void testStreamAgentWithDefaultTransport() { List streamResults = agent.stream(Msg.builder().textContent("test").build()).collectList().block(); assertNotNull(streamResults); - assertEquals(2, streamResults.size()); - assertFalse(streamResults.get(0).isLast()); - assertTrue(streamResults.get(1).isLast()); + assertEquals(3, streamResults.size()); + assertFalse(streamResults.get(0).isLast()); // ReasoningChunkEvent + assertTrue(streamResults.get(1).isLast()); // PostReasoningEvent + assertTrue(streamResults.get(2).isLast()); // AGENT_RESULT } @Test @@ -428,6 +433,113 @@ void testCallAgentWithDefaultTransportByObserve() { assertEquals(3, agent.getMemory().getMessages().size()); } + @Test + @DisplayName("Should trigger Pre, Chunk, Post reasoning events") + void testAgentLifecycleHooksTriggeredCorrectly() { + AtomicInteger preCount = new AtomicInteger(0); + AtomicInteger chunkCount = new AtomicInteger(0); + AtomicInteger postCount = new AtomicInteger(0); + + Hook lifecycleMonitorHook = + new Hook() { + @Override + public Mono onEvent(T event) { + if (event instanceof PreReasoningEvent) { + preCount.incrementAndGet(); + } else if (event instanceof ReasoningChunkEvent) { + chunkCount.incrementAndGet(); + } else if (event instanceof PostReasoningEvent) { + postCount.incrementAndGet(); + } + return Mono.just(event); + } + + @Override + public int priority() { + return 1; + } + }; + + A2aAgent agent = + A2aAgent.builder() + .name("test-lifecycle-agent") + .agentCard(agentCard) + .hook(new ReplaceA2aClientHook()) + .hook(lifecycleMonitorHook) + .build(); + + Answer mockTaskResponse = + invocation -> { + @SuppressWarnings("unchecked") + List> a2aEventConsumer = + invocation.getArgument(1, List.class); + + // Task creation + Task initialTask = + new Task.Builder() + .id("t1") + .contextId("c1") + .status(new TaskStatus(TaskState.WORKING)) + .build(); + a2aEventConsumer.forEach(c -> c.accept(new TaskEvent(initialTask), agentCard)); + + // Stream output a piece of text (Artifact Update) + TaskArtifactUpdateEvent chunkEvent = + new TaskArtifactUpdateEvent.Builder() + .taskId("t1") + .contextId("c1") + .artifact( + new Artifact.Builder() + .artifactId("a1") + .name("mockArtifact") + .parts(new TextPart("Hello A2A")) + .build()) + .build(); + Task workingTask = + new Task.Builder() + .id("t1") + .contextId("c1") + .status(new TaskStatus(TaskState.WORKING)) + .artifacts(List.of(chunkEvent.getArtifact())) + .build(); + a2aEventConsumer.forEach( + c -> c.accept(new TaskUpdateEvent(workingTask, chunkEvent), agentCard)); + + // Task complete (Status Update - COMPLETED) + Task completedTask = + new Task.Builder() + .id("t1") + .contextId("c1") + .status(new TaskStatus(TaskState.COMPLETED)) + .artifacts(List.of(chunkEvent.getArtifact())) + .build(); + TaskStatusUpdateEvent completeEvent = + new TaskStatusUpdateEvent( + "t1", + new TaskStatus(TaskState.COMPLETED), + "c1", + true, + Map.of()); + a2aEventConsumer.forEach( + c -> + c.accept( + new TaskUpdateEvent(completedTask, completeEvent), + agentCard)); + + return null; + }; + + doAnswer(mockTaskResponse) + .when(a2aClient) + .sendMessage(any(Message.class), anyList(), any()); + + agent.stream(Msg.builder().textContent("测试触发").build()).collectList().block(); + + assertEquals(1, preCount.get()); + assertEquals(1, chunkCount.get()); + assertEquals(1, postCount.get()); + } + private Answer mockSuccessMessage() { return invocationOnMock -> { @SuppressWarnings("unchecked")