From 0b164898f66b25eca25c645aed05f71e1961b597 Mon Sep 17 00:00:00 2001 From: "alexgangxi@163.com" Date: Wed, 8 Apr 2026 23:15:33 +0800 Subject: [PATCH] fix(tool): support Flux-returning tool methods --- .../core/tool/ToolMethodInvoker.java | 102 +++++++++++++++++- .../agentscope/core/tool/AsyncToolTest.java | 53 ++++++++- .../core/tool/ToolMethodInvokerTest.java | 90 ++++++++++++++++ .../core/tool/test/SampleTools.java | 13 +++ 4 files changed, 256 insertions(+), 2 deletions(-) diff --git a/agentscope-core/src/main/java/io/agentscope/core/tool/ToolMethodInvoker.java b/agentscope-core/src/main/java/io/agentscope/core/tool/ToolMethodInvoker.java index 5da99f2d6..9b011769c 100644 --- a/agentscope-core/src/main/java/io/agentscope/core/tool/ToolMethodInvoker.java +++ b/agentscope-core/src/main/java/io/agentscope/core/tool/ToolMethodInvoker.java @@ -25,8 +25,10 @@ import java.lang.reflect.Parameter; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** @@ -105,6 +107,9 @@ r, extractGenericType(method))) .onErrorResume(this::handleError)) .onErrorResume(this::handleError); + } else if (returnType == Flux.class) { + return invokeFlux(toolObject, method, input, agent, context, emitter, converter); + } else { // Sync method: wrap in Mono.fromCallable return Mono.fromCallable( @@ -119,6 +124,101 @@ r, extractGenericType(method))) } } + private Mono invokeFlux( + Object toolObject, + Method method, + Map input, + Agent agent, + ToolExecutionContext context, + ToolEmitter emitter, + ToolResultConverter converter) { + Type itemType = extractGenericType(method); + + return Mono.fromCallable( + () -> { + method.setAccessible(true); + Object[] args = + convertParameters(method, input, agent, context, emitter); + @SuppressWarnings("unchecked") + Flux flux = (Flux) method.invoke(toolObject, args); + return flux != null ? flux : Flux.empty(); + }) + .flatMap( + flux -> + flux.doOnNext( + item -> + emitFluxChunk( + emitter, converter, item, itemType)) + .collectList() + .map( + items -> + converter.convert( + aggregateFluxItems(items, itemType), + resolveFluxAggregateType( + items, itemType))) + .onErrorResume(this::handleError)) + .onErrorResume(this::handleError); + } + + private void emitFluxChunk( + ToolEmitter emitter, ToolResultConverter converter, Object item, Type itemType) { + if (item == null) { + return; + } + emitter.emit(toStreamingChunk(item, itemType, converter)); + } + + private ToolResultBlock toStreamingChunk( + Object item, Type itemType, ToolResultConverter converter) { + if (item instanceof ToolResultBlock) { + return (ToolResultBlock) item; + } + if (item instanceof CharSequence + || item instanceof Number + || item instanceof Boolean + || item instanceof Character) { + return ToolResultBlock.text(String.valueOf(item)); + } + return converter.convert(item, itemType); + } + + private Object aggregateFluxItems(List items, Type itemType) { + if (shouldConcatenateFluxItems(items, itemType)) { + StringBuilder aggregated = new StringBuilder(); + for (Object item : items) { + if (item != null) { + aggregated.append(item); + } + } + return aggregated.toString(); + } + if (items.isEmpty()) { + return null; + } + if (items.size() == 1) { + return items.get(0); + } + return items; + } + + private Type resolveFluxAggregateType(List items, Type itemType) { + if (shouldConcatenateFluxItems(items, itemType)) { + return String.class; + } + if (items.size() == 1) { + return itemType; + } + return List.class; + } + + private boolean shouldConcatenateFluxItems(List items, Type itemType) { + if (itemType == String.class || itemType == CharSequence.class) { + return true; + } + return !items.isEmpty() + && items.stream().allMatch(item -> item == null || item instanceof CharSequence); + } + /** * Convert input parameters to method arguments with automatic injection support. * @@ -363,7 +463,7 @@ private ToolResultBlock handleInvocationError(Throwable e) { } /** - * Extract generic type from method return type (for CompletableFuture or Mono). + * Extract generic type from method return type (for CompletableFuture, Mono, or Flux). * * @param method the method * @return the generic type, or null if not found diff --git a/agentscope-core/src/test/java/io/agentscope/core/tool/AsyncToolTest.java b/agentscope-core/src/test/java/io/agentscope/core/tool/AsyncToolTest.java index 669f3ffe3..3771dc5d6 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/tool/AsyncToolTest.java +++ b/agentscope-core/src/test/java/io/agentscope/core/tool/AsyncToolTest.java @@ -26,6 +26,7 @@ import io.agentscope.core.tool.test.SampleTools; import io.agentscope.core.util.JsonUtils; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Map; import org.junit.jupiter.api.BeforeEach; @@ -34,7 +35,7 @@ import org.junit.jupiter.api.Test; /** - * Tests for async tool execution with CompletableFuture and Mono return types. + * Tests for async tool execution with CompletableFuture, Mono, and Flux return types. */ @Tag("unit") @DisplayName("Async Tool Tests") @@ -90,6 +91,56 @@ void shouldExecuteMonoAsyncTool() { assertEquals("\"HelloWorld\"", extractFirstText(response)); } + @Test + @DisplayName("Should execute Flux async tool") + void shouldExecuteFluxAsyncTool() { + Map input = Map.of("str1", "Hello", "str2", "World"); + ToolUseBlock toolCall = + ToolUseBlock.builder() + .id("call-async-flux") + .name("async_flux_concat") + .input(input) + .content(JsonUtils.getJsonCodec().toJson(input)) + .build(); + + ToolResultBlock response = + toolkit.callTool(ToolCallParam.builder().toolUseBlock(toolCall).build()) + .block(TIMEOUT); + + assertNotNull(response, "Response should not be null"); + assertEquals("\"HelloWorld\"", extractFirstText(response)); + } + + @Test + @DisplayName("Should emit Flux chunks while aggregating final tool result") + void shouldEmitFluxChunksWhileAggregatingFinalToolResult() { + List chunkToolIds = new ArrayList<>(); + List chunkTexts = new ArrayList<>(); + toolkit.setChunkCallback( + (toolUse, chunk) -> { + chunkToolIds.add(toolUse.getId()); + chunkTexts.add(extractFirstText(chunk)); + }); + + Map input = Map.of("str1", "Alpha", "str2", "Beta"); + ToolUseBlock toolCall = + ToolUseBlock.builder() + .id("call-async-flux-chunk") + .name("async_flux_concat") + .input(input) + .content(JsonUtils.getJsonCodec().toJson(input)) + .build(); + + ToolResultBlock response = + toolkit.callTool(ToolCallParam.builder().toolUseBlock(toolCall).build()) + .block(TIMEOUT); + + assertNotNull(response, "Response should not be null"); + assertEquals(List.of("call-async-flux-chunk", "call-async-flux-chunk"), chunkToolIds); + assertEquals(List.of("Alpha", "Beta"), chunkTexts); + assertEquals("\"AlphaBeta\"", extractFirstText(response)); + } + @Test @DisplayName("Should execute async tool with delay") void shouldExecuteAsyncToolWithDelay() { diff --git a/agentscope-core/src/test/java/io/agentscope/core/tool/ToolMethodInvokerTest.java b/agentscope-core/src/test/java/io/agentscope/core/tool/ToolMethodInvokerTest.java index 2fb341bef..00aa3c6c5 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/tool/ToolMethodInvokerTest.java +++ b/agentscope-core/src/test/java/io/agentscope/core/tool/ToolMethodInvokerTest.java @@ -30,6 +30,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** @@ -205,6 +206,26 @@ public Mono suspendToolMonoSync( @ToolParam(name = "reason", description = "reason") String reason) { throw new ToolSuspendException(reason); } + + public Flux fluxConcat( + @ToolParam(name = "prefix", description = "prefix") String prefix, + @ToolParam(name = "suffix", description = "suffix") String suffix) { + return Flux.just(prefix, suffix); + } + + public Flux fluxSingleNumber( + @ToolParam(name = "value", description = "value") Integer value) { + return Flux.just(value); + } + + public Flux fluxNumbers( + @ToolParam(name = "start", description = "start") Integer start) { + return Flux.just(start, start + 1, start + 2); + } + + public Flux emptyFluxString() { + return Flux.empty(); + } } /** Test POJO for generic type testing (Issue #677). */ @@ -867,6 +888,75 @@ void testGenericMap_WithCustomClassValue() throws Exception { } /** Test nested generic types like List<List<Integer>>. */ + @Test + void testFluxStringAggregationAndChunkEmission() throws Exception { + TestTools tools = new TestTools(); + Method method = TestTools.class.getMethod("fluxConcat", String.class, String.class); + + Map input = new HashMap<>(); + input.put("prefix", "Hello"); + input.put("suffix", "World"); + + List emittedChunks = new ArrayList<>(); + ToolUseBlock toolUseBlock = new ToolUseBlock("flux-id", method.getName(), input); + ToolCallParam param = + ToolCallParam.builder() + .toolUseBlock(toolUseBlock) + .input(input) + .emitter(chunk -> emittedChunks.add(ToolTestUtils.extractContent(chunk))) + .build(); + + ToolResultBlock response = + invoker.invokeAsync(tools, method, param, responseConverter).block(); + + Assertions.assertNotNull(response); + Assertions.assertFalse(ToolTestUtils.isErrorResponse(response)); + Assertions.assertEquals("\"HelloWorld\"", ToolTestUtils.extractContent(response)); + Assertions.assertEquals(List.of("Hello", "World"), emittedChunks); + } + + @Test + void testFluxSingleValueAggregation() throws Exception { + TestTools tools = new TestTools(); + Method method = TestTools.class.getMethod("fluxSingleNumber", Integer.class); + + Map input = new HashMap<>(); + input.put("value", 7); + + ToolResultBlock response = invokeWithParam(tools, method, input); + + Assertions.assertNotNull(response); + Assertions.assertFalse(ToolTestUtils.isErrorResponse(response)); + Assertions.assertEquals("7", ToolTestUtils.extractContent(response)); + } + + @Test + void testFluxMultipleValuesAggregateToJsonArray() throws Exception { + TestTools tools = new TestTools(); + Method method = TestTools.class.getMethod("fluxNumbers", Integer.class); + + Map input = new HashMap<>(); + input.put("start", 3); + + ToolResultBlock response = invokeWithParam(tools, method, input); + + Assertions.assertNotNull(response); + Assertions.assertFalse(ToolTestUtils.isErrorResponse(response)); + Assertions.assertEquals("[3,4,5]", ToolTestUtils.extractContent(response)); + } + + @Test + void testEmptyFluxStringAggregatesToEmptyString() throws Exception { + TestTools tools = new TestTools(); + Method method = TestTools.class.getMethod("emptyFluxString"); + + ToolResultBlock response = invokeWithParam(tools, method, new HashMap<>()); + + Assertions.assertNotNull(response); + Assertions.assertFalse(ToolTestUtils.isErrorResponse(response)); + Assertions.assertEquals("\"\"", ToolTestUtils.extractContent(response)); + } + @Test void testNestedGenericList() throws Exception { TestTools tools = new TestTools(); diff --git a/agentscope-core/src/test/java/io/agentscope/core/tool/test/SampleTools.java b/agentscope-core/src/test/java/io/agentscope/core/tool/test/SampleTools.java index 1747c3aef..8fd996125 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/tool/test/SampleTools.java +++ b/agentscope-core/src/test/java/io/agentscope/core/tool/test/SampleTools.java @@ -19,6 +19,7 @@ import io.agentscope.core.tool.ToolParam; import java.time.Duration; import java.util.concurrent.CompletableFuture; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** @@ -132,6 +133,18 @@ public Mono asyncConcat( return Mono.fromCallable(() -> str1 + str2); } + /** + * Async tool using Flux that streams string chunks. + */ + @Tool( + name = "async_flux_concat", + description = "Asynchronously stream and concatenate two strings") + public Flux asyncFluxConcat( + @ToolParam(name = "str1", description = "First string") String str1, + @ToolParam(name = "str2", description = "Second string") String str2) { + return Flux.just(str1, str2).delayElements(Duration.ofMillis(25)); + } + /** * Async tool using Mono that simulates delay. */