Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package com.google.adk.a2a;
package com.google.adk.a2a.executor;

import static java.util.Objects.requireNonNull;

import com.google.adk.a2a.converters.EventConverter;
import com.google.adk.a2a.converters.PartConverter;
import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.RunConfig;
import com.google.adk.apps.App;
import com.google.adk.artifacts.BaseArtifactService;
import com.google.adk.events.Event;
Expand Down Expand Up @@ -44,12 +45,10 @@ public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor

private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class);
private static final String USER_ID_PREFIX = "A2A_USER_";
private static final RunConfig DEFAULT_RUN_CONFIG =
RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build();

private final Map<String, Disposable> activeTasks = new ConcurrentHashMap<>();
private final Runner.Builder runnerBuilder;
private final RunConfig runConfig;
private final AgentExecutorConfig agentExecutorConfig;

private AgentExecutor(
App app,
Expand All @@ -59,7 +58,10 @@ private AgentExecutor(
BaseSessionService sessionService,
BaseMemoryService memoryService,
List<? extends Plugin> plugins,
RunConfig runConfig) {
AgentExecutorConfig agentExecutorConfig) {
requireNonNull(agentExecutorConfig);
this.agentExecutorConfig = agentExecutorConfig;

this.runnerBuilder =
Runner.builder()
.agent(agent)
Expand All @@ -73,7 +75,6 @@ private AgentExecutor(
}
// Check that the runner is configured correctly and can be built.
var unused = runnerBuilder.build();
this.runConfig = runConfig == null ? DEFAULT_RUN_CONFIG : runConfig;
}

/** Builder for {@link AgentExecutor}. */
Expand All @@ -85,7 +86,13 @@ public static class Builder {
private BaseSessionService sessionService;
private BaseMemoryService memoryService;
private List<? extends Plugin> plugins = ImmutableList.of();
private RunConfig runConfig;
private AgentExecutorConfig agentExecutorConfig;

@CanIgnoreReturnValue
public Builder agentExecutorConfig(AgentExecutorConfig agentExecutorConfig) {
this.agentExecutorConfig = agentExecutorConfig;
return this;
}

@CanIgnoreReturnValue
public Builder app(App app) {
Expand Down Expand Up @@ -129,16 +136,17 @@ public Builder plugins(List<? extends Plugin> plugins) {
return this;
}

@CanIgnoreReturnValue
public Builder runConfig(RunConfig runConfig) {
this.runConfig = runConfig;
return this;
}

@CanIgnoreReturnValue
public AgentExecutor build() {
return new AgentExecutor(
app, agent, appName, artifactService, sessionService, memoryService, plugins, runConfig);
app,
agent,
appName,
artifactService,
sessionService,
memoryService,
plugins,
agentExecutorConfig);
}
}

Expand Down Expand Up @@ -178,7 +186,8 @@ public void execute(RequestContext ctx, EventQueue eventQueue) {
.flatMapPublisher(
session -> {
updater.startWork();
return runner.runAsync(getUserId(ctx), session.id(), content, runConfig);
return runner.runAsync(
getUserId(ctx), session.id(), content, agentExecutorConfig.runConfig());
})
.subscribe(
event -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package com.google.adk.a2a.executor;

import com.google.adk.a2a.executor.Callbacks.AfterEventCallback;
import com.google.adk.a2a.executor.Callbacks.AfterExecuteCallback;
import com.google.adk.a2a.executor.Callbacks.BeforeExecuteCallback;
import com.google.adk.agents.RunConfig;
import com.google.auto.value.AutoValue;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import org.jspecify.annotations.Nullable;

/** Configuration for the {@link AgentExecutor}. */
@AutoValue
public abstract class AgentExecutorConfig {

private static final RunConfig DEFAULT_RUN_CONFIG =
RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build();

public abstract @Nullable RunConfig runConfig();

public abstract @Nullable BeforeExecuteCallback beforeExecuteCallback();

public abstract @Nullable AfterExecuteCallback afterExecuteCallback();

public abstract @Nullable AfterEventCallback afterEventCallback();

public abstract Builder toBuilder();

public static Builder builder() {
return new AutoValue_AgentExecutorConfig.Builder();
}

/** Builder for {@link AgentExecutorConfig}. */
@AutoValue.Builder
public abstract static class Builder {
@CanIgnoreReturnValue
public abstract Builder runConfig(RunConfig runConfig);

@CanIgnoreReturnValue
public abstract Builder beforeExecuteCallback(BeforeExecuteCallback beforeExecuteCallback);

@CanIgnoreReturnValue
public abstract Builder afterExecuteCallback(AfterExecuteCallback afterExecuteCallback);

@CanIgnoreReturnValue
public abstract Builder afterEventCallback(AfterEventCallback afterEventCallback);

abstract AgentExecutorConfig autoBuild();

public AgentExecutorConfig build() {
AgentExecutorConfig config = autoBuild();
if (config.runConfig() == null) {
config = config.toBuilder().runConfig(DEFAULT_RUN_CONFIG).build();
}
return config;
}
}
}
67 changes: 67 additions & 0 deletions a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package com.google.adk.a2a.executor;

import com.google.adk.events.Event;
import io.a2a.server.agentexecution.RequestContext;
import io.a2a.spec.TaskArtifactUpdateEvent;
import io.a2a.spec.TaskStatusUpdateEvent;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Maybe;

/** Functional interfaces for agent executor lifecycle callbacks. */
public final class Callbacks {

private Callbacks() {}

interface BeforeExecuteCallbackBase {}

/** Async callback interface for actions to be performed before an execution is started. */
@FunctionalInterface
public interface BeforeExecuteCallback extends BeforeExecuteCallbackBase {
/**
* Callback which will be called before an execution is started. It can be used to instrument a
* context or prevent the execution by returning an error.
*
* @param ctx the request context
* @return a {@link Completable} that completes when the callback is done
*/
Completable call(RequestContext ctx);
}

interface AfterExecuteCallbackBase {}

/**
* Async callback interface for actions to be performed after an execution is completed or failed.
*/
@FunctionalInterface
public interface AfterExecuteCallback extends AfterExecuteCallbackBase {
/**
* Callback which will be called after an execution resolved into a completed or failed task.
* This gives an opportunity to enrich the event with additional metadata or log it.
*
* @param ctx the request context
* @param finalUpdateEvent the final update event
* @return a {@link Maybe} that completes when the callback is done
*/
Maybe<TaskStatusUpdateEvent> call(RequestContext ctx, TaskStatusUpdateEvent finalUpdateEvent);
}

interface AfterEventCallbackBase {}

/** Async callback interface for actions to be performed after an event is processed. */
@FunctionalInterface
public interface AfterEventCallback extends AfterEventCallbackBase {
/**
* Callback which will be called after an ADK event is successfully converted to an A2A event.
* This gives an opportunity to enrich the event with additional metadata or abort the execution
* by returning an error. The callback is not invoked for errors originating from ADK or event
* processing.
*
* @param ctx the request context
* @param processedEvent the processed task artifact update event
* @param event the ADK event
* @return a {@link Maybe} that completes when the callback is done
*/
Maybe<TaskArtifactUpdateEvent> call(
RequestContext ctx, TaskArtifactUpdateEvent processedEvent, Event event);
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.google.adk.a2a;
package com.google.adk.a2a.executor;

import static org.junit.Assert.assertThrows;

Expand Down Expand Up @@ -32,6 +32,7 @@ public void createAgentExecutor_noAgent_succeeds() {
.app(App.builder().name("test_app").rootAgent(testAgent).build())
.sessionService(new InMemorySessionService())
.artifactService(new InMemoryArtifactService())
.agentExecutorConfig(AgentExecutorConfig.builder().build())
.build();
}

Expand All @@ -44,6 +45,7 @@ public void createAgentExecutor_withAgentAndApp_throwsException() {
.agent(testAgent)
.app(App.builder().name("test_app").rootAgent(testAgent).build())
.sessionService(new InMemorySessionService())
.agentExecutorConfig(AgentExecutorConfig.builder().build())
.artifactService(new InMemoryArtifactService())
.build();
});
Expand All @@ -55,6 +57,20 @@ public void createAgentExecutor_withEmptyAgentAndApp_throwsException() {
IllegalStateException.class,
() -> {
new AgentExecutor.Builder()
.sessionService(new InMemorySessionService())
.artifactService(new InMemoryArtifactService())
.agentExecutorConfig(AgentExecutorConfig.builder().build())
.build();
});
}

@Test
public void createAgentExecutor_noAgentExecutorConfig_throwsException() {
assertThrows(
NullPointerException.class,
() -> {
new AgentExecutor.Builder()
.app(App.builder().name("test_app").rootAgent(testAgent).build())
.sessionService(new InMemorySessionService())
.artifactService(new InMemoryArtifactService())
.build();
Expand Down