From 51db7de889e6f89a37378cb9550c864cec00022b Mon Sep 17 00:00:00 2001 From: Param Parikh Date: Mon, 6 Apr 2026 15:30:59 -0700 Subject: [PATCH] Extend McpMetricsObserver with tool call completion, error, and tools/list callbacks Add three new default methods to McpMetricsObserver: - onToolCallComplete: fires after tool execution with latency, server ID, success/failure, and proxy flag - onToolCallError: fires on tool call errors with error message - onToolsList: fires on tools/list requests with tool count Instrument McpService.handleToolsCall() with System.nanoTime() latency tracking on both local and proxy code paths, and add error tracking for failed calls and unknown tools. Instrument handleToolsList() with tool count callback. --- .../java/mcp/server/McpMetricsObserver.java | 30 +++ .../smithy/java/mcp/server/McpService.java | 203 ++++++++++---- .../smithy/java/mcp/server/McpServerTest.java | 247 ++++++++++++++++++ 3 files changed, 431 insertions(+), 49 deletions(-) diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpMetricsObserver.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpMetricsObserver.java index 51c941d92..f8af6c1de 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpMetricsObserver.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpMetricsObserver.java @@ -33,4 +33,34 @@ void onToolCall( String method, String toolName ); + + /** + * Called after a tool call completes (success or failure), with timing and server context. + */ + default void onToolCallComplete( + String method, + String toolName, + String serverId, + long latencyNanos, + boolean success, + boolean isProxy + ) {} + + /** + * Called when a tool call results in an error. + */ + default void onToolCallError( + String method, + String toolName, + String serverId, + String errorMessage + ) {} + + /** + * Called when a tools/list request is received. + */ + default void onToolsList( + String method, + int toolCount + ) {} } diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java index 8978616ce..35ea7e910 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java @@ -155,37 +155,39 @@ yield switch (method) { private JsonRpcResponse handleInitialize(JsonRpcRequest req) { if (metricsObserver != null) { - var params = req.getParams(); - var clientInfo = params.getMember("clientInfo"); - var capabilities = params.getMember("capabilities"); - - String extractedProtocolVersion = params.getMember("protocolVersion") != null - ? params.getMember("protocolVersion").asString() - : null; - - String clientName = clientInfo != null && clientInfo.getMember("name") != null - ? clientInfo.getMember("name").asString() - : null; - - String clientTitle = clientInfo != null && clientInfo.getMember("title") != null - ? clientInfo.getMember("title").asString() - : null; - - boolean rootsListChanged = capabilities != null - && capabilities.getMember("roots") != null - && capabilities.getMember("roots").getMember("listChanged") != null - && capabilities.getMember("roots").getMember("listChanged").asBoolean(); - - boolean sampling = capabilities != null && capabilities.getMember("sampling") != null; - boolean elicitation = capabilities != null && capabilities.getMember("elicitation") != null; - - metricsObserver.onInitialize("initialize", - extractedProtocolVersion, - rootsListChanged, - sampling, - elicitation, - clientName, - clientTitle); + safeObserve(() -> { + var params = req.getParams(); + var clientInfo = params.getMember("clientInfo"); + var capabilities = params.getMember("capabilities"); + + String extractedProtocolVersion = params.getMember("protocolVersion") != null + ? params.getMember("protocolVersion").asString() + : null; + + String clientName = clientInfo != null && clientInfo.getMember("name") != null + ? clientInfo.getMember("name").asString() + : null; + + String clientTitle = clientInfo != null && clientInfo.getMember("title") != null + ? clientInfo.getMember("title").asString() + : null; + + boolean rootsListChanged = capabilities != null + && capabilities.getMember("roots") != null + && capabilities.getMember("roots").getMember("listChanged") != null + && capabilities.getMember("roots").getMember("listChanged").asBoolean(); + + boolean sampling = capabilities != null && capabilities.getMember("sampling") != null; + boolean elicitation = capabilities != null && capabilities.getMember("elicitation") != null; + + metricsObserver.onInitialize("initialize", + extractedProtocolVersion, + rootsListChanged, + sampling, + elicitation, + clientName, + clientTitle); + }); } this.initializeRequest.compareAndSet(null, req); @@ -251,10 +253,17 @@ private JsonRpcResponse handlePromptsGet(JsonRpcRequest req) { private JsonRpcResponse handleToolsList(JsonRpcRequest req, ProtocolVersion protocolVersion) { var supportsOutputSchema = supportsOutputSchema(protocolVersion); + var filteredTools = tools.values() + .stream() + .filter(t -> toolFilter.allowTool(t.serverId(), t.toolInfo().getName())) + .toList(); + + if (metricsObserver != null) { + safeObserve(() -> metricsObserver.onToolsList("tools/list", filteredTools.size())); + } + var result = ListToolsResult.builder() - .tools(tools.values() - .stream() - .filter(t -> toolFilter.allowTool(t.serverId(), t.toolInfo().getName())) + .tools(filteredTools.stream() .map(tool -> extractToolInfo(tool, supportsOutputSchema)) .toList()) .build(); @@ -266,18 +275,36 @@ private JsonRpcResponse handleToolsCall( Consumer asyncResponseCallback, ProtocolVersion protocolVersion ) { + String toolName = req.getParams().getMember("name") != null + ? req.getParams().getMember("name").asString() + : null; + if (metricsObserver != null) { - String toolName = req.getParams().getMember("name") != null - ? req.getParams().getMember("name").asString() - : null; - metricsObserver.onToolCall("tools/call", toolName); + safeObserve(() -> metricsObserver.onToolCall("tools/call", toolName)); } - var operationName = req.getParams().getMember("name").asString(); - var tool = tools.get(operationName); + long startNanos = System.nanoTime(); + var tool = tools.get(toolName); if (tool == null) { - return createErrorResponse(req, "No such tool: " + operationName); + long latencyNanos = System.nanoTime() - startNanos; + if (metricsObserver != null) { + safeObserve(() -> { + metricsObserver.onToolCallComplete( + "tools/call", + toolName, + null, + latencyNanos, + false, + false); + metricsObserver.onToolCallError( + "tools/call", + toolName, + null, + "No such tool: " + toolName); + }); + } + return createErrorResponse(req, "No such tool: " + toolName); } // Check if this tool should be dispatched to a proxy @@ -291,8 +318,46 @@ private JsonRpcResponse handleToolsCall( .build(); // Get response asynchronously and invoke callback - tool.proxy().rpc(proxyRequest).thenAccept(asyncResponseCallback).exceptionally(ex -> { + tool.proxy().rpc(proxyRequest).thenAccept(response -> { + long latencyNanos = System.nanoTime() - startNanos; + boolean success = response.getError() == null; + if (metricsObserver != null) { + safeObserve(() -> { + metricsObserver.onToolCallComplete("tools/call", + toolName, + tool.serverId(), + latencyNanos, + success, + true); + if (!success) { + String errMsg = response.getError().getMessage() != null + ? response.getError().getMessage() + : "Unknown error"; + metricsObserver.onToolCallError("tools/call", + toolName, + tool.serverId(), + errMsg); + } + }); + } + asyncResponseCallback.accept(response); + }).exceptionally(ex -> { + long latencyNanos = System.nanoTime() - startNanos; LOG.error("Error from proxy RPC", ex); + if (metricsObserver != null) { + safeObserve(() -> { + metricsObserver.onToolCallComplete("tools/call", + toolName, + tool.serverId(), + latencyNanos, + false, + true); + metricsObserver.onToolCallError("tools/call", + toolName, + tool.serverId(), + safeErrorMessage(ex)); + }); + } asyncResponseCallback .accept(createErrorResponse(req, new RuntimeException("Proxy error: " + ex.getMessage(), ex))); return null; @@ -302,16 +367,56 @@ private JsonRpcResponse handleToolsCall( return null; } else { // Handle locally - var operation = tool.operation(); - var argumentsDoc = req.getParams().getMember("arguments"); - var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema()); - var input = adaptedDoc.asShape(operation.getApiOperation().inputBuilder()); - var output = operation.function().apply(input, null); - var result = formatStructuredContent(tool, (SerializableShape) output, protocolVersion); - return createSuccessResponse(req.getId(), result); + try { + var operation = tool.operation(); + var argumentsDoc = req.getParams().getMember("arguments"); + var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema()); + var input = adaptedDoc.asShape(operation.getApiOperation().inputBuilder()); + var output = operation.function().apply(input, null); + var result = formatStructuredContent(tool, (SerializableShape) output, protocolVersion); + long latencyNanos = System.nanoTime() - startNanos; + if (metricsObserver != null) { + safeObserve(() -> metricsObserver.onToolCallComplete("tools/call", + toolName, + tool.serverId(), + latencyNanos, + true, + false)); + } + return createSuccessResponse(req.getId(), result); + } catch (Exception e) { + long latencyNanos = System.nanoTime() - startNanos; + if (metricsObserver != null) { + safeObserve(() -> { + metricsObserver.onToolCallComplete("tools/call", + toolName, + tool.serverId(), + latencyNanos, + false, + false); + metricsObserver.onToolCallError("tools/call", + toolName, + tool.serverId(), + safeErrorMessage(e)); + }); + } + throw e; + } } } + private void safeObserve(Runnable observation) { + try { + observation.run(); + } catch (Exception e) { + LOG.warn("Metrics observer error", e); + } + } + + private static String safeErrorMessage(Throwable t) { + return t.getMessage() != null ? t.getMessage() : t.getClass().getName(); + } + /** * Sets the notification writer for forwarding notifications from proxies. */ diff --git a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java index db2062fc0..2349f7850 100644 --- a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java +++ b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java @@ -1524,6 +1524,253 @@ void testOtherNotificationsDoNotInvalidateCache() { assertEquals(1, callCounter.get(), "Cache should not be invalidated by other notifications"); } + // --- Metrics Observer Tests --- + + @Test + void testMetricsObserverOnInitialize() { + var observer = new TestMetricsObserver(); + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .metricsObserver(observer) + .build(); + + server.start(); + + write("initialize", + Document.of(Map.of( + "protocolVersion", + Document.of("2025-03-26"), + "clientInfo", + Document.of(Map.of( + "name", + Document.of("test-client"), + "title", + Document.of("Test Client"))), + "capabilities", + Document.of(Map.of( + "roots", + Document.of(Map.of("listChanged", Document.of(true))), + "sampling", + Document.of(Map.of())))))); + read(); + + assertEquals(1, observer.initializeCount); + assertEquals("2025-03-26", observer.lastProtocolVersion); + assertEquals("test-client", observer.lastClientName); + assertEquals("Test Client", observer.lastClientTitle); + assertTrue(observer.lastRootsListChanged); + assertTrue(observer.lastSampling); + } + + @Test + void testMetricsObserverOnToolsList() { + var observer = new TestMetricsObserver(); + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .metricsObserver(observer) + .build(); + + server.start(); + + initializeWithProtocolVersion(null); + write("tools/list", Document.of(Map.of())); + read(); + + assertEquals(1, observer.toolsListCount); + assertEquals(4, observer.lastToolCount); + + write("tools/list", Document.of(Map.of())); + read(); + + assertEquals(2, observer.toolsListCount); + } + + @Test + void testMetricsObserverOnToolCallForNonExistentTool() { + var observer = new TestMetricsObserver(); + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .metricsObserver(observer) + .build(); + + server.start(); + + initializeWithProtocolVersion(null); + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NonExistentTool"), + "arguments", + Document.of(Map.of())))); + var response = read(); + + assertNotNull(response.getError()); + assertEquals(1, observer.toolCallCount); + assertEquals("NonExistentTool", observer.lastToolCallName); + // onToolCallComplete should also fire for "tool not found" with success=false + assertEquals(1, observer.toolCallCompleteCount); + assertFalse(observer.lastCompleteSuccess); + assertFalse(observer.lastCompleteIsProxy); + assertNull(observer.lastCompleteServerId); + assertTrue(observer.lastCompleteLatencyNanos >= 0); + assertEquals(1, observer.toolCallErrorCount); + assertEquals("NonExistentTool", observer.lastErrorToolName); + assertTrue(observer.lastErrorMessage.contains("No such tool")); + } + + @Test + void testMetricsObserverOnToolCallLocal() { + var observer = new TestMetricsObserver(); + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .metricsObserver(observer) + .build(); + + server.start(); + + initializeWithProtocolVersion(null); + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NoIOOperation"), + "arguments", + Document.of(Map.of())))); + read(); + + // onToolCall fires at the start of every tool call + assertEquals(1, observer.toolCallCount); + assertEquals("NoIOOperation", observer.lastToolCallName); + // onToolCallComplete fires after execution (ProxyService proxies to localhost which + // is not running, so this is a local execution that results in an error response) + assertEquals(1, observer.toolCallCompleteCount); + assertEquals("NoIOOperation", observer.lastCompleteToolName); + assertEquals("test-mcp", observer.lastCompleteServerId); + assertFalse(observer.lastCompleteSuccess); + assertFalse(observer.lastCompleteIsProxy); + assertTrue(observer.lastCompleteLatencyNanos > 0); + } + + private static class TestMetricsObserver implements McpMetricsObserver { + int initializeCount; + String lastProtocolVersion; + boolean lastRootsListChanged; + boolean lastSampling; + boolean lastElicitation; + String lastClientName; + String lastClientTitle; + + int toolCallCount; + String lastToolCallName; + + int toolCallCompleteCount; + String lastCompleteToolName; + String lastCompleteServerId; + long lastCompleteLatencyNanos; + boolean lastCompleteSuccess; + boolean lastCompleteIsProxy; + + int toolCallErrorCount; + String lastErrorToolName; + String lastErrorServerId; + String lastErrorMessage; + + int toolsListCount; + int lastToolCount; + + @Override + public void onInitialize( + String method, + String extractedProtocolVersion, + boolean rootsListChanged, + boolean sampling, + boolean elicitation, + String clientName, + String clientTitle + ) { + initializeCount++; + lastProtocolVersion = extractedProtocolVersion; + lastRootsListChanged = rootsListChanged; + lastSampling = sampling; + lastElicitation = elicitation; + lastClientName = clientName; + lastClientTitle = clientTitle; + } + + @Override + public void onToolCall(String method, String toolName) { + toolCallCount++; + lastToolCallName = toolName; + } + + @Override + public void onToolCallComplete( + String method, + String toolName, + String serverId, + long latencyNanos, + boolean success, + boolean isProxy + ) { + toolCallCompleteCount++; + lastCompleteToolName = toolName; + lastCompleteServerId = serverId; + lastCompleteLatencyNanos = latencyNanos; + lastCompleteSuccess = success; + lastCompleteIsProxy = isProxy; + } + + @Override + public void onToolCallError( + String method, + String toolName, + String serverId, + String errorMessage + ) { + toolCallErrorCount++; + lastErrorToolName = toolName; + lastErrorServerId = serverId; + lastErrorMessage = errorMessage; + } + + @Override + public void onToolsList(String method, int toolCount) { + toolsListCount++; + lastToolCount = toolCount; + } + } + private static class CacheTestProxy extends McpServerProxy { private final AtomicInteger callCounter; private final List sentNotifications = new ArrayList<>();