diff --git a/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java similarity index 91% rename from a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java rename to a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java index 0fbeb0a72..0c12727aa 100644 --- a/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java +++ b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java @@ -1,9 +1,10 @@ -package com.google.adk.a2a; +package com.google.adk.a2a.executor; + +import static java.util.Objects.requireNonNull; import com.google.adk.a2a.converters.EventConverter; import com.google.adk.a2a.converters.PartConverter; import com.google.adk.agents.BaseAgent; -import com.google.adk.agents.RunConfig; import com.google.adk.apps.App; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; @@ -44,12 +45,10 @@ public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class); private static final String USER_ID_PREFIX = "A2A_USER_"; - private static final RunConfig DEFAULT_RUN_CONFIG = - RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build(); private final Map activeTasks = new ConcurrentHashMap<>(); private final Runner.Builder runnerBuilder; - private final RunConfig runConfig; + private final AgentExecutorConfig agentExecutorConfig; private AgentExecutor( App app, @@ -59,7 +58,10 @@ private AgentExecutor( BaseSessionService sessionService, BaseMemoryService memoryService, List plugins, - RunConfig runConfig) { + AgentExecutorConfig agentExecutorConfig) { + requireNonNull(agentExecutorConfig); + this.agentExecutorConfig = agentExecutorConfig; + this.runnerBuilder = Runner.builder() .agent(agent) @@ -73,7 +75,6 @@ private AgentExecutor( } // Check that the runner is configured correctly and can be built. var unused = runnerBuilder.build(); - this.runConfig = runConfig == null ? DEFAULT_RUN_CONFIG : runConfig; } /** Builder for {@link AgentExecutor}. */ @@ -85,7 +86,13 @@ public static class Builder { private BaseSessionService sessionService; private BaseMemoryService memoryService; private List plugins = ImmutableList.of(); - private RunConfig runConfig; + private AgentExecutorConfig agentExecutorConfig; + + @CanIgnoreReturnValue + public Builder agentExecutorConfig(AgentExecutorConfig agentExecutorConfig) { + this.agentExecutorConfig = agentExecutorConfig; + return this; + } @CanIgnoreReturnValue public Builder app(App app) { @@ -129,16 +136,17 @@ public Builder plugins(List plugins) { return this; } - @CanIgnoreReturnValue - public Builder runConfig(RunConfig runConfig) { - this.runConfig = runConfig; - return this; - } - @CanIgnoreReturnValue public AgentExecutor build() { return new AgentExecutor( - app, agent, appName, artifactService, sessionService, memoryService, plugins, runConfig); + app, + agent, + appName, + artifactService, + sessionService, + memoryService, + plugins, + agentExecutorConfig); } } @@ -178,7 +186,8 @@ public void execute(RequestContext ctx, EventQueue eventQueue) { .flatMapPublisher( session -> { updater.startWork(); - return runner.runAsync(getUserId(ctx), session.id(), content, runConfig); + return runner.runAsync( + getUserId(ctx), session.id(), content, agentExecutorConfig.runConfig()); }) .subscribe( event -> { diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java new file mode 100644 index 000000000..979a081fb --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java @@ -0,0 +1,57 @@ +package com.google.adk.a2a.executor; + +import com.google.adk.a2a.executor.Callbacks.AfterEventCallback; +import com.google.adk.a2a.executor.Callbacks.AfterExecuteCallback; +import com.google.adk.a2a.executor.Callbacks.BeforeExecuteCallback; +import com.google.adk.agents.RunConfig; +import com.google.auto.value.AutoValue; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import org.jspecify.annotations.Nullable; + +/** Configuration for the {@link AgentExecutor}. */ +@AutoValue +public abstract class AgentExecutorConfig { + + private static final RunConfig DEFAULT_RUN_CONFIG = + RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build(); + + public abstract @Nullable RunConfig runConfig(); + + public abstract @Nullable BeforeExecuteCallback beforeExecuteCallback(); + + public abstract @Nullable AfterExecuteCallback afterExecuteCallback(); + + public abstract @Nullable AfterEventCallback afterEventCallback(); + + public abstract Builder toBuilder(); + + public static Builder builder() { + return new AutoValue_AgentExecutorConfig.Builder(); + } + + /** Builder for {@link AgentExecutorConfig}. */ + @AutoValue.Builder + public abstract static class Builder { + @CanIgnoreReturnValue + public abstract Builder runConfig(RunConfig runConfig); + + @CanIgnoreReturnValue + public abstract Builder beforeExecuteCallback(BeforeExecuteCallback beforeExecuteCallback); + + @CanIgnoreReturnValue + public abstract Builder afterExecuteCallback(AfterExecuteCallback afterExecuteCallback); + + @CanIgnoreReturnValue + public abstract Builder afterEventCallback(AfterEventCallback afterEventCallback); + + abstract AgentExecutorConfig autoBuild(); + + public AgentExecutorConfig build() { + AgentExecutorConfig config = autoBuild(); + if (config.runConfig() == null) { + config = config.toBuilder().runConfig(DEFAULT_RUN_CONFIG).build(); + } + return config; + } + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java b/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java new file mode 100644 index 000000000..8683e2e14 --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java @@ -0,0 +1,67 @@ +package com.google.adk.a2a.executor; + +import com.google.adk.events.Event; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; + +/** Functional interfaces for agent executor lifecycle callbacks. */ +public final class Callbacks { + + private Callbacks() {} + + interface BeforeExecuteCallbackBase {} + + /** Async callback interface for actions to be performed before an execution is started. */ + @FunctionalInterface + public interface BeforeExecuteCallback extends BeforeExecuteCallbackBase { + /** + * Callback which will be called before an execution is started. It can be used to instrument a + * context or prevent the execution by returning an error. + * + * @param ctx the request context + * @return a {@link Completable} that completes when the callback is done + */ + Completable call(RequestContext ctx); + } + + interface AfterExecuteCallbackBase {} + + /** + * Async callback interface for actions to be performed after an execution is completed or failed. + */ + @FunctionalInterface + public interface AfterExecuteCallback extends AfterExecuteCallbackBase { + /** + * Callback which will be called after an execution resolved into a completed or failed task. + * This gives an opportunity to enrich the event with additional metadata or log it. + * + * @param ctx the request context + * @param finalUpdateEvent the final update event + * @return a {@link Maybe} that completes when the callback is done + */ + Maybe call(RequestContext ctx, TaskStatusUpdateEvent finalUpdateEvent); + } + + interface AfterEventCallbackBase {} + + /** Async callback interface for actions to be performed after an event is processed. */ + @FunctionalInterface + public interface AfterEventCallback extends AfterEventCallbackBase { + /** + * Callback which will be called after an ADK event is successfully converted to an A2A event. + * This gives an opportunity to enrich the event with additional metadata or abort the execution + * by returning an error. The callback is not invoked for errors originating from ADK or event + * processing. + * + * @param ctx the request context + * @param processedEvent the processed task artifact update event + * @param event the ADK event + * @return a {@link Maybe} that completes when the callback is done + */ + Maybe call( + RequestContext ctx, TaskArtifactUpdateEvent processedEvent, Event event); + } +} diff --git a/a2a/src/test/java/com/google/adk/a2a/AgentExecutorTest.java b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java similarity index 77% rename from a2a/src/test/java/com/google/adk/a2a/AgentExecutorTest.java rename to a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java index 44daf13d1..350bd6f16 100644 --- a/a2a/src/test/java/com/google/adk/a2a/AgentExecutorTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java @@ -1,4 +1,4 @@ -package com.google.adk.a2a; +package com.google.adk.a2a.executor; import static org.junit.Assert.assertThrows; @@ -32,6 +32,7 @@ public void createAgentExecutor_noAgent_succeeds() { .app(App.builder().name("test_app").rootAgent(testAgent).build()) .sessionService(new InMemorySessionService()) .artifactService(new InMemoryArtifactService()) + .agentExecutorConfig(AgentExecutorConfig.builder().build()) .build(); } @@ -44,6 +45,7 @@ public void createAgentExecutor_withAgentAndApp_throwsException() { .agent(testAgent) .app(App.builder().name("test_app").rootAgent(testAgent).build()) .sessionService(new InMemorySessionService()) + .agentExecutorConfig(AgentExecutorConfig.builder().build()) .artifactService(new InMemoryArtifactService()) .build(); }); @@ -55,6 +57,20 @@ public void createAgentExecutor_withEmptyAgentAndApp_throwsException() { IllegalStateException.class, () -> { new AgentExecutor.Builder() + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .agentExecutorConfig(AgentExecutorConfig.builder().build()) + .build(); + }); + } + + @Test + public void createAgentExecutor_noAgentExecutorConfig_throwsException() { + assertThrows( + NullPointerException.class, + () -> { + new AgentExecutor.Builder() + .app(App.builder().name("test_app").rootAgent(testAgent).build()) .sessionService(new InMemorySessionService()) .artifactService(new InMemoryArtifactService()) .build();