diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java index 94a54aa67..3d66a4e07 100644 --- a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java +++ b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java @@ -20,11 +20,19 @@ import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; import io.a2a.server.tasks.TaskUpdater; +import io.a2a.spec.Artifact; import io.a2a.spec.InvalidAgentResponseError; import io.a2a.spec.Message; import io.a2a.spec.Part; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.disposables.CompositeDisposable; import io.reactivex.rxjava3.disposables.Disposable; import java.util.HashMap; @@ -43,10 +51,8 @@ * use in production code. */ public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor { - private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class); private static final String USER_ID_PREFIX = "A2A_USER_"; - private final Map activeTasks = new ConcurrentHashMap<>(); private final Runner.Builder runnerBuilder; private final AgentExecutorConfig agentExecutorConfig; @@ -137,7 +143,6 @@ public Builder plugins(List plugins) { return this; } - @CanIgnoreReturnValue public AgentExecutor build() { return new AgentExecutor( app, @@ -165,46 +170,88 @@ public void execute(RequestContext ctx, EventQueue eventQueue) { if (message == null) { throw new IllegalArgumentException("Message cannot be null"); } - // Submits a new task if there is no active task. if (ctx.getTask() == null) { updater.submit(); } - // Group all reactive work for this task into one container CompositeDisposable taskDisposables = new CompositeDisposable(); // Check if the task with the task id is already running, put if absent. if (activeTasks.putIfAbsent(ctx.getTaskId(), taskDisposables) != null) { throw new IllegalStateException(String.format("Task %s already running", ctx.getTaskId())); } - EventProcessor p = new EventProcessor(agentExecutorConfig.outputMode()); Content content = PartConverter.messageToContent(message); - Runner runner = runnerBuilder.build(); + Single skipExecution = + agentExecutorConfig.beforeExecuteCallback() != null + ? agentExecutorConfig.beforeExecuteCallback().call(ctx) + : Single.just(false); + Runner runner = runnerBuilder.build(); taskDisposables.add( - prepareSession(ctx, runner.appName(), runner.sessionService()) + skipExecution .flatMapPublisher( - session -> { - updater.startWork(); - return runner.runAsync( - getUserId(ctx), session.id(), content, agentExecutorConfig.runConfig()); + skip -> { + if (skip) { + cancel(ctx, eventQueue); + return Flowable.empty(); + } + return Maybe.defer( + () -> { + return prepareSession(ctx, runner.appName(), runner.sessionService()); + }) + .flatMapPublisher( + session -> { + updater.startWork(); + return runner.runAsync( + getUserId(ctx), + session.id(), + content, + agentExecutorConfig.runConfig()); + }); }) - .subscribe( + .concatMap( event -> { - p.process(event, updater); - }, + return p.process(event, ctx, agentExecutorConfig.afterEventCallback(), eventQueue) + .toFlowable(); + }) + // Ignore all events from the runner, since they are already processed. + .ignoreElements() + .materialize() + .flatMapCompletable( + notification -> { + Throwable error = notification.getError(); + if (error != null) { + logger.error("Runner failed to execute", error); + } + return handleExecutionEnd(ctx, error, eventQueue); + }) + .doFinally(() -> cleanupTask(ctx.getTaskId())) + .subscribe( + () -> {}, error -> { - logger.error("Runner failed with {}", error); - updater.fail(failedMessage(ctx, error)); - cleanupTask(ctx.getTaskId()); - }, - () -> { - updater.complete(); - cleanupTask(ctx.getTaskId()); + logger.error("Failed to handle execution end", error); })); } + private Completable handleExecutionEnd( + RequestContext ctx, Throwable error, EventQueue eventQueue) { + TaskState state = error != null ? TaskState.FAILED : TaskState.COMPLETED; + Message message = error != null ? failedMessage(ctx, error) : null; + TaskStatusUpdateEvent initialEvent = + new TaskStatusUpdateEvent.Builder() + .taskId(ctx.getTaskId()) + .contextId(ctx.getContextId()) + .isFinal(true) + .status(new TaskStatus(state, message, null)) + .build(); + Maybe afterExecute = + agentExecutorConfig.afterExecuteCallback() != null + ? agentExecutorConfig.afterExecuteCallback().call(ctx, initialEvent) + : Maybe.just(initialEvent); + return afterExecute.doOnSuccess(event -> eventQueue.enqueueEvent(event)).ignoreElement(); + } + private void cleanupTask(String taskId) { Disposable d = activeTasks.remove(taskId); if (d != null) { @@ -249,16 +296,19 @@ private EventProcessor(AgentExecutorConfig.OutputMode outputMode) { this.outputMode = outputMode; } - private void process(Event event, TaskUpdater updater) { + private Maybe process( + Event event, + RequestContext ctx, + Callbacks.AfterEventCallback callback, + EventQueue eventQueue) { if (event.errorCode().isPresent()) { - throw new InvalidAgentResponseError( - null, // Uses default code -32006 - "Agent returned an error: " + event.errorCode().get(), - null); + return Maybe.error( + new InvalidAgentResponseError( + null, // Uses default code -32006 + "Agent returned an error: " + event.errorCode().get(), + null)); } - ImmutableList> parts = EventConverter.contentToParts(event.content()); - // Mark all parts as partial if the event is partial. if (event.partial().orElse(false)) { parts.forEach( @@ -302,7 +352,26 @@ private void process(Event event, TaskUpdater updater) { } } - updater.addArtifact(parts, artifactId, null, metadata, append, lastChunk); + TaskArtifactUpdateEvent initialEvent = + new TaskArtifactUpdateEvent.Builder() + .taskId(ctx.getTaskId()) + .contextId(ctx.getContextId()) + .lastChunk(lastChunk) + .append(append) + .artifact( + new Artifact.Builder() + .artifactId(artifactId) + .parts(parts) + .metadata(metadata) + .build()) + .build(); + + Maybe afterEvent = + callback != null ? callback.call(ctx, initialEvent, event) : Maybe.just(initialEvent); + return afterEvent.doOnSuccess( + finalEvent -> { + eventQueue.enqueueEvent(finalEvent); + }); } } } diff --git a/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java index d9c7c25ab..5570f40d0 100644 --- a/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java @@ -3,7 +3,9 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -15,15 +17,25 @@ import com.google.adk.events.Event; import com.google.adk.sessions.InMemorySessionService; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; import com.google.genai.types.Content; import com.google.genai.types.Part; import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; import io.a2a.spec.Message; import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatus; +import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; +import java.util.UUID; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -33,10 +45,21 @@ @RunWith(JUnit4.class) public final class AgentExecutorTest { + private EventQueue eventQueue; + private List enqueuedEvents; private TestAgent testAgent; @Before public void setUp() { + enqueuedEvents = new ArrayList<>(); + eventQueue = mock(EventQueue.class); + doAnswer( + invocation -> { + enqueuedEvents.add(invocation.getArgument(0)); + return null; + }) + .when(eventQueue) + .enqueueEvent(any()); testAgent = new TestAgent(); } @@ -92,6 +115,248 @@ public void createAgentExecutor_noAgentExecutorConfig_throwsException() { }); } + @Test + public void execute_withBeforeExecuteCallback_cancelsExecutionOnError() { + // If callback returns error, execution should stop/fail. + Callbacks.BeforeExecuteCallback callback = + ctx -> Single.error(new RuntimeException("Cancelled")); + + AgentExecutorConfig config = + AgentExecutorConfig.builder().beforeExecuteCallback(callback).build(); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(config) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Verify error handling triggered cleanup and fail event + // The executor catches the error and emits failed event. + assertThat(enqueuedEvents).isNotEmpty(); + Object lastEvent = Iterables.getLast(enqueuedEvents); + assertThat(lastEvent).isInstanceOf(TaskStatusUpdateEvent.class); + TaskStatusUpdateEvent statusEvent = (TaskStatusUpdateEvent) lastEvent; + assertThat(statusEvent.getStatus().state().toString()).isEqualTo("FAILED"); + assertThat(statusEvent.getStatus().message().getParts().get(0)).isInstanceOf(TextPart.class); + TextPart textPart = (TextPart) statusEvent.getStatus().message().getParts().get(0); + assertThat(textPart.getText()).contains("Cancelled"); + } + + @Test + public void execute_withBeforeExecuteCallback_skipsExecutionIfTrue() { + Callbacks.BeforeExecuteCallback callback = ctx -> Single.just(true); + + AgentExecutorConfig config = + AgentExecutorConfig.builder().beforeExecuteCallback(callback).build(); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(config) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Filter for artifact events + Optional artifactEvent = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskArtifactUpdateEvent) + .map(e -> (TaskArtifactUpdateEvent) e) + .findFirst(); + + assertThat(artifactEvent).isEmpty(); + } + + @Test + public void execute_withAfterEventCallback_modifiesEvent() { + // Agent emits an event. Callback intercepts and modifies it. + Part textPart = Part.builder().text("Hello world").build(); + Event agentEvent = + Event.builder() + .id("event-1") + .author("agent") + .content(Content.builder().role("model").parts(ImmutableList.of(textPart)).build()) + .build(); + testAgent.setEventsToEmit(Flowable.just(agentEvent)); + + Callbacks.AfterEventCallback callback = + (ctx, event, sourceEvent) -> { + // Modify event by adding metadata + return Maybe.just( + new TaskArtifactUpdateEvent.Builder(event) + .metadata(ImmutableMap.of("modified", true)) + .build()); + }; + + AgentExecutorConfig config = AgentExecutorConfig.builder().afterEventCallback(callback).build(); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(config) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Filter for artifact events + Optional artifactEvent = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskArtifactUpdateEvent) + .map(e -> (TaskArtifactUpdateEvent) e) + .findFirst(); + + assertThat(artifactEvent).isPresent(); + assertThat(artifactEvent.get().getMetadata()).containsEntry("modified", true); + } + + @Test + public void execute_withAfterExecuteCallback_modifiesStatus() { + testAgent.setEventsToEmit(Flowable.empty()); // Just complete + + Callbacks.AfterExecuteCallback callback = + (ctx, event) -> { + // Modify status to have different message + Message newMessage = + new Message.Builder() + .messageId(UUID.randomUUID().toString()) + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Modified completion"))) + .build(); + + return Maybe.just( + new TaskStatusUpdateEvent.Builder(event) + .status(new TaskStatus(event.getStatus().state(), newMessage, null)) + .build()); + }; + + AgentExecutorConfig config = + AgentExecutorConfig.builder().afterExecuteCallback(callback).build(); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(config) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Verify status event + Optional statusEvent = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskStatusUpdateEvent) + .map(e -> (TaskStatusUpdateEvent) e) + .filter(TaskStatusUpdateEvent::isFinal) + .findFirst(); + + assertThat(statusEvent).isPresent(); + assertThat(statusEvent.get().getStatus().message().getParts().get(0)) + .isInstanceOf(TextPart.class); + TextPart textPart = (TextPart) statusEvent.get().getStatus().message().getParts().get(0); + assertThat(textPart.getText()).isEqualTo("Modified completion"); + } + + @Test + public void execute_runnerFails_registersFailedEvent() { + testAgent.setEventsToEmit(Flowable.error(new RuntimeException("Runner error"))); + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(AgentExecutorConfig.builder().build()) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + ImmutableList finalEvents = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskStatusUpdateEvent) + .map(e -> (TaskStatusUpdateEvent) e) + // final events could be COMPLETED, FAILED, CANCELED, REJECTED or UNKNOWN + // as per io.a2a.spec.TaskState + .filter(TaskStatusUpdateEvent::isFinal) + .collect(toImmutableList()); + + assertThat(finalEvents).hasSize(1); + + TaskStatusUpdateEvent statusEvent = finalEvents.get(0); + assertThat(statusEvent.getStatus().state()).isEqualTo(TaskState.FAILED); + assertThat(statusEvent.getStatus().message().getParts().get(0)).isInstanceOf(TextPart.class); + TextPart textPart = (TextPart) statusEvent.getStatus().message().getParts().get(0); + assertThat(textPart.getText()).isEqualTo("Runner error"); + } + + @Test + public void execute_runnerSucceeds_registerCompletedTaskFails_noFailedTaskRegistered() { + testAgent.setEventsToEmit(Flowable.empty()); + + // Configure eventQueue to throw exception when TaskStatusUpdateEvent is enqueued + doAnswer( + invocation -> { + Object event = invocation.getArgument(0); + if (event instanceof TaskStatusUpdateEvent statusUpdate) { + if (statusUpdate.getStatus().state() == TaskState.COMPLETED) { + throw new RuntimeException("Enqueue failed"); + } + } + return null; + }) + .when(eventQueue) + .enqueueEvent(any()); + + AgentExecutor executor = + new AgentExecutor.Builder() + .agentExecutorConfig(AgentExecutorConfig.builder().build()) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + + RequestContext ctx = createRequestContext(); + executor.execute(ctx, eventQueue); + + // Verify status events in the tracked enqueuedEvents + ImmutableList statusEvents = + enqueuedEvents.stream() + .filter(e -> e instanceof TaskStatusUpdateEvent) + .map(e -> (TaskStatusUpdateEvent) e) + .filter(TaskStatusUpdateEvent::isFinal) + .collect(toImmutableList()); + + // There should be no final status events. + assertThat(statusEvents).isEmpty(); + } + + private RequestContext createRequestContext() { + Message message = + new Message.Builder() + .messageId("msg-1") + .role(Message.Role.USER) + .parts(ImmutableList.of(new TextPart("trigger"))) + .build(); + + RequestContext ctx = mock(RequestContext.class); + when(ctx.getMessage()).thenReturn(message); + when(ctx.getTaskId()).thenReturn("task-" + UUID.randomUUID()); + when(ctx.getContextId()).thenReturn("ctx-" + UUID.randomUUID()); + return ctx; + } + @Test public void process_statefulAggregation_tracksArtifactIdAndAppendForAuthor() { Event partial1 = @@ -175,7 +440,7 @@ public void process_statefulAggregation_tracksArtifactIdAndAppendForAuthor() { } private static final class TestAgent extends BaseAgent { - private final Flowable eventsToEmit; + private Flowable eventsToEmit; TestAgent() { this(Flowable.empty()); @@ -187,6 +452,10 @@ private static final class TestAgent extends BaseAgent { this.eventsToEmit = eventsToEmit; } + void setEventsToEmit(Flowable events) { + this.eventsToEmit = events; + } + @Override protected Flowable runAsyncImpl(InvocationContext invocationContext) { return eventsToEmit;