diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpMetricsObserver.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpMetricsObserver.java index 51c941d92..f8af6c1de 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpMetricsObserver.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpMetricsObserver.java @@ -33,4 +33,34 @@ void onToolCall( String method, String toolName ); + + /** + * Called after a tool call completes (success or failure), with timing and server context. + */ + default void onToolCallComplete( + String method, + String toolName, + String serverId, + long latencyNanos, + boolean success, + boolean isProxy + ) {} + + /** + * Called when a tool call results in an error. + */ + default void onToolCallError( + String method, + String toolName, + String serverId, + String errorMessage + ) {} + + /** + * Called when a tools/list request is received. + */ + default void onToolsList( + String method, + int toolCount + ) {} } diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java index 8978616ce..35ea7e910 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java @@ -155,37 +155,39 @@ yield switch (method) { private JsonRpcResponse handleInitialize(JsonRpcRequest req) { if (metricsObserver != null) { - var params = req.getParams(); - var clientInfo = params.getMember("clientInfo"); - var capabilities = params.getMember("capabilities"); - - String extractedProtocolVersion = params.getMember("protocolVersion") != null - ? params.getMember("protocolVersion").asString() - : null; - - String clientName = clientInfo != null && clientInfo.getMember("name") != null - ? clientInfo.getMember("name").asString() - : null; - - String clientTitle = clientInfo != null && clientInfo.getMember("title") != null - ? clientInfo.getMember("title").asString() - : null; - - boolean rootsListChanged = capabilities != null - && capabilities.getMember("roots") != null - && capabilities.getMember("roots").getMember("listChanged") != null - && capabilities.getMember("roots").getMember("listChanged").asBoolean(); - - boolean sampling = capabilities != null && capabilities.getMember("sampling") != null; - boolean elicitation = capabilities != null && capabilities.getMember("elicitation") != null; - - metricsObserver.onInitialize("initialize", - extractedProtocolVersion, - rootsListChanged, - sampling, - elicitation, - clientName, - clientTitle); + safeObserve(() -> { + var params = req.getParams(); + var clientInfo = params.getMember("clientInfo"); + var capabilities = params.getMember("capabilities"); + + String extractedProtocolVersion = params.getMember("protocolVersion") != null + ? params.getMember("protocolVersion").asString() + : null; + + String clientName = clientInfo != null && clientInfo.getMember("name") != null + ? clientInfo.getMember("name").asString() + : null; + + String clientTitle = clientInfo != null && clientInfo.getMember("title") != null + ? clientInfo.getMember("title").asString() + : null; + + boolean rootsListChanged = capabilities != null + && capabilities.getMember("roots") != null + && capabilities.getMember("roots").getMember("listChanged") != null + && capabilities.getMember("roots").getMember("listChanged").asBoolean(); + + boolean sampling = capabilities != null && capabilities.getMember("sampling") != null; + boolean elicitation = capabilities != null && capabilities.getMember("elicitation") != null; + + metricsObserver.onInitialize("initialize", + extractedProtocolVersion, + rootsListChanged, + sampling, + elicitation, + clientName, + clientTitle); + }); } this.initializeRequest.compareAndSet(null, req); @@ -251,10 +253,17 @@ private JsonRpcResponse handlePromptsGet(JsonRpcRequest req) { private JsonRpcResponse handleToolsList(JsonRpcRequest req, ProtocolVersion protocolVersion) { var supportsOutputSchema = supportsOutputSchema(protocolVersion); + var filteredTools = tools.values() + .stream() + .filter(t -> toolFilter.allowTool(t.serverId(), t.toolInfo().getName())) + .toList(); + + if (metricsObserver != null) { + safeObserve(() -> metricsObserver.onToolsList("tools/list", filteredTools.size())); + } + var result = ListToolsResult.builder() - .tools(tools.values() - .stream() - .filter(t -> toolFilter.allowTool(t.serverId(), t.toolInfo().getName())) + .tools(filteredTools.stream() .map(tool -> extractToolInfo(tool, supportsOutputSchema)) .toList()) .build(); @@ -266,18 +275,36 @@ private JsonRpcResponse handleToolsCall( Consumer asyncResponseCallback, ProtocolVersion protocolVersion ) { + String toolName = req.getParams().getMember("name") != null + ? req.getParams().getMember("name").asString() + : null; + if (metricsObserver != null) { - String toolName = req.getParams().getMember("name") != null - ? req.getParams().getMember("name").asString() - : null; - metricsObserver.onToolCall("tools/call", toolName); + safeObserve(() -> metricsObserver.onToolCall("tools/call", toolName)); } - var operationName = req.getParams().getMember("name").asString(); - var tool = tools.get(operationName); + long startNanos = System.nanoTime(); + var tool = tools.get(toolName); if (tool == null) { - return createErrorResponse(req, "No such tool: " + operationName); + long latencyNanos = System.nanoTime() - startNanos; + if (metricsObserver != null) { + safeObserve(() -> { + metricsObserver.onToolCallComplete( + "tools/call", + toolName, + null, + latencyNanos, + false, + false); + metricsObserver.onToolCallError( + "tools/call", + toolName, + null, + "No such tool: " + toolName); + }); + } + return createErrorResponse(req, "No such tool: " + toolName); } // Check if this tool should be dispatched to a proxy @@ -291,8 +318,46 @@ private JsonRpcResponse handleToolsCall( .build(); // Get response asynchronously and invoke callback - tool.proxy().rpc(proxyRequest).thenAccept(asyncResponseCallback).exceptionally(ex -> { + tool.proxy().rpc(proxyRequest).thenAccept(response -> { + long latencyNanos = System.nanoTime() - startNanos; + boolean success = response.getError() == null; + if (metricsObserver != null) { + safeObserve(() -> { + metricsObserver.onToolCallComplete("tools/call", + toolName, + tool.serverId(), + latencyNanos, + success, + true); + if (!success) { + String errMsg = response.getError().getMessage() != null + ? response.getError().getMessage() + : "Unknown error"; + metricsObserver.onToolCallError("tools/call", + toolName, + tool.serverId(), + errMsg); + } + }); + } + asyncResponseCallback.accept(response); + }).exceptionally(ex -> { + long latencyNanos = System.nanoTime() - startNanos; LOG.error("Error from proxy RPC", ex); + if (metricsObserver != null) { + safeObserve(() -> { + metricsObserver.onToolCallComplete("tools/call", + toolName, + tool.serverId(), + latencyNanos, + false, + true); + metricsObserver.onToolCallError("tools/call", + toolName, + tool.serverId(), + safeErrorMessage(ex)); + }); + } asyncResponseCallback .accept(createErrorResponse(req, new RuntimeException("Proxy error: " + ex.getMessage(), ex))); return null; @@ -302,16 +367,56 @@ private JsonRpcResponse handleToolsCall( return null; } else { // Handle locally - var operation = tool.operation(); - var argumentsDoc = req.getParams().getMember("arguments"); - var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema()); - var input = adaptedDoc.asShape(operation.getApiOperation().inputBuilder()); - var output = operation.function().apply(input, null); - var result = formatStructuredContent(tool, (SerializableShape) output, protocolVersion); - return createSuccessResponse(req.getId(), result); + try { + var operation = tool.operation(); + var argumentsDoc = req.getParams().getMember("arguments"); + var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema()); + var input = adaptedDoc.asShape(operation.getApiOperation().inputBuilder()); + var output = operation.function().apply(input, null); + var result = formatStructuredContent(tool, (SerializableShape) output, protocolVersion); + long latencyNanos = System.nanoTime() - startNanos; + if (metricsObserver != null) { + safeObserve(() -> metricsObserver.onToolCallComplete("tools/call", + toolName, + tool.serverId(), + latencyNanos, + true, + false)); + } + return createSuccessResponse(req.getId(), result); + } catch (Exception e) { + long latencyNanos = System.nanoTime() - startNanos; + if (metricsObserver != null) { + safeObserve(() -> { + metricsObserver.onToolCallComplete("tools/call", + toolName, + tool.serverId(), + latencyNanos, + false, + false); + metricsObserver.onToolCallError("tools/call", + toolName, + tool.serverId(), + safeErrorMessage(e)); + }); + } + throw e; + } } } + private void safeObserve(Runnable observation) { + try { + observation.run(); + } catch (Exception e) { + LOG.warn("Metrics observer error", e); + } + } + + private static String safeErrorMessage(Throwable t) { + return t.getMessage() != null ? t.getMessage() : t.getClass().getName(); + } + /** * Sets the notification writer for forwarding notifications from proxies. */ diff --git a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java index db2062fc0..2349f7850 100644 --- a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java +++ b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java @@ -1524,6 +1524,253 @@ void testOtherNotificationsDoNotInvalidateCache() { assertEquals(1, callCounter.get(), "Cache should not be invalidated by other notifications"); } + // --- Metrics Observer Tests --- + + @Test + void testMetricsObserverOnInitialize() { + var observer = new TestMetricsObserver(); + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .metricsObserver(observer) + .build(); + + server.start(); + + write("initialize", + Document.of(Map.of( + "protocolVersion", + Document.of("2025-03-26"), + "clientInfo", + Document.of(Map.of( + "name", + Document.of("test-client"), + "title", + Document.of("Test Client"))), + "capabilities", + Document.of(Map.of( + "roots", + Document.of(Map.of("listChanged", Document.of(true))), + "sampling", + Document.of(Map.of())))))); + read(); + + assertEquals(1, observer.initializeCount); + assertEquals("2025-03-26", observer.lastProtocolVersion); + assertEquals("test-client", observer.lastClientName); + assertEquals("Test Client", observer.lastClientTitle); + assertTrue(observer.lastRootsListChanged); + assertTrue(observer.lastSampling); + } + + @Test + void testMetricsObserverOnToolsList() { + var observer = new TestMetricsObserver(); + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .metricsObserver(observer) + .build(); + + server.start(); + + initializeWithProtocolVersion(null); + write("tools/list", Document.of(Map.of())); + read(); + + assertEquals(1, observer.toolsListCount); + assertEquals(4, observer.lastToolCount); + + write("tools/list", Document.of(Map.of())); + read(); + + assertEquals(2, observer.toolsListCount); + } + + @Test + void testMetricsObserverOnToolCallForNonExistentTool() { + var observer = new TestMetricsObserver(); + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .metricsObserver(observer) + .build(); + + server.start(); + + initializeWithProtocolVersion(null); + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NonExistentTool"), + "arguments", + Document.of(Map.of())))); + var response = read(); + + assertNotNull(response.getError()); + assertEquals(1, observer.toolCallCount); + assertEquals("NonExistentTool", observer.lastToolCallName); + // onToolCallComplete should also fire for "tool not found" with success=false + assertEquals(1, observer.toolCallCompleteCount); + assertFalse(observer.lastCompleteSuccess); + assertFalse(observer.lastCompleteIsProxy); + assertNull(observer.lastCompleteServerId); + assertTrue(observer.lastCompleteLatencyNanos >= 0); + assertEquals(1, observer.toolCallErrorCount); + assertEquals("NonExistentTool", observer.lastErrorToolName); + assertTrue(observer.lastErrorMessage.contains("No such tool")); + } + + @Test + void testMetricsObserverOnToolCallLocal() { + var observer = new TestMetricsObserver(); + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .metricsObserver(observer) + .build(); + + server.start(); + + initializeWithProtocolVersion(null); + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NoIOOperation"), + "arguments", + Document.of(Map.of())))); + read(); + + // onToolCall fires at the start of every tool call + assertEquals(1, observer.toolCallCount); + assertEquals("NoIOOperation", observer.lastToolCallName); + // onToolCallComplete fires after execution (ProxyService proxies to localhost which + // is not running, so this is a local execution that results in an error response) + assertEquals(1, observer.toolCallCompleteCount); + assertEquals("NoIOOperation", observer.lastCompleteToolName); + assertEquals("test-mcp", observer.lastCompleteServerId); + assertFalse(observer.lastCompleteSuccess); + assertFalse(observer.lastCompleteIsProxy); + assertTrue(observer.lastCompleteLatencyNanos > 0); + } + + private static class TestMetricsObserver implements McpMetricsObserver { + int initializeCount; + String lastProtocolVersion; + boolean lastRootsListChanged; + boolean lastSampling; + boolean lastElicitation; + String lastClientName; + String lastClientTitle; + + int toolCallCount; + String lastToolCallName; + + int toolCallCompleteCount; + String lastCompleteToolName; + String lastCompleteServerId; + long lastCompleteLatencyNanos; + boolean lastCompleteSuccess; + boolean lastCompleteIsProxy; + + int toolCallErrorCount; + String lastErrorToolName; + String lastErrorServerId; + String lastErrorMessage; + + int toolsListCount; + int lastToolCount; + + @Override + public void onInitialize( + String method, + String extractedProtocolVersion, + boolean rootsListChanged, + boolean sampling, + boolean elicitation, + String clientName, + String clientTitle + ) { + initializeCount++; + lastProtocolVersion = extractedProtocolVersion; + lastRootsListChanged = rootsListChanged; + lastSampling = sampling; + lastElicitation = elicitation; + lastClientName = clientName; + lastClientTitle = clientTitle; + } + + @Override + public void onToolCall(String method, String toolName) { + toolCallCount++; + lastToolCallName = toolName; + } + + @Override + public void onToolCallComplete( + String method, + String toolName, + String serverId, + long latencyNanos, + boolean success, + boolean isProxy + ) { + toolCallCompleteCount++; + lastCompleteToolName = toolName; + lastCompleteServerId = serverId; + lastCompleteLatencyNanos = latencyNanos; + lastCompleteSuccess = success; + lastCompleteIsProxy = isProxy; + } + + @Override + public void onToolCallError( + String method, + String toolName, + String serverId, + String errorMessage + ) { + toolCallErrorCount++; + lastErrorToolName = toolName; + lastErrorServerId = serverId; + lastErrorMessage = errorMessage; + } + + @Override + public void onToolsList(String method, int toolCount) { + toolsListCount++; + lastToolCount = toolCount; + } + } + private static class CacheTestProxy extends McpServerProxy { private final AtomicInteger callCounter; private final List sentNotifications = new ArrayList<>();