Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,34 @@ void onToolCall(
String method,
String toolName
);

/**
* Called after a tool call completes (success or failure), with timing and server context.
*/
default void onToolCallComplete(
String method,
String toolName,
String serverId,
long latencyNanos,
boolean success,
boolean isProxy
) {}

/**
* Called when a tool call results in an error.
*/
default void onToolCallError(
String method,
String toolName,
String serverId,
String errorMessage
) {}

/**
* Called when a tools/list request is received.
*/
default void onToolsList(
String method,
int toolCount
) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,37 +155,39 @@ yield switch (method) {

private JsonRpcResponse handleInitialize(JsonRpcRequest req) {
if (metricsObserver != null) {
var params = req.getParams();
var clientInfo = params.getMember("clientInfo");
var capabilities = params.getMember("capabilities");

String extractedProtocolVersion = params.getMember("protocolVersion") != null
? params.getMember("protocolVersion").asString()
: null;

String clientName = clientInfo != null && clientInfo.getMember("name") != null
? clientInfo.getMember("name").asString()
: null;

String clientTitle = clientInfo != null && clientInfo.getMember("title") != null
? clientInfo.getMember("title").asString()
: null;

boolean rootsListChanged = capabilities != null
&& capabilities.getMember("roots") != null
&& capabilities.getMember("roots").getMember("listChanged") != null
&& capabilities.getMember("roots").getMember("listChanged").asBoolean();

boolean sampling = capabilities != null && capabilities.getMember("sampling") != null;
boolean elicitation = capabilities != null && capabilities.getMember("elicitation") != null;

metricsObserver.onInitialize("initialize",
extractedProtocolVersion,
rootsListChanged,
sampling,
elicitation,
clientName,
clientTitle);
safeObserve(() -> {
var params = req.getParams();
var clientInfo = params.getMember("clientInfo");
var capabilities = params.getMember("capabilities");

String extractedProtocolVersion = params.getMember("protocolVersion") != null
? params.getMember("protocolVersion").asString()
: null;

String clientName = clientInfo != null && clientInfo.getMember("name") != null
? clientInfo.getMember("name").asString()
: null;

String clientTitle = clientInfo != null && clientInfo.getMember("title") != null
? clientInfo.getMember("title").asString()
: null;

boolean rootsListChanged = capabilities != null
&& capabilities.getMember("roots") != null
&& capabilities.getMember("roots").getMember("listChanged") != null
&& capabilities.getMember("roots").getMember("listChanged").asBoolean();

boolean sampling = capabilities != null && capabilities.getMember("sampling") != null;
boolean elicitation = capabilities != null && capabilities.getMember("elicitation") != null;

metricsObserver.onInitialize("initialize",
extractedProtocolVersion,
rootsListChanged,
sampling,
elicitation,
clientName,
clientTitle);
});
}

this.initializeRequest.compareAndSet(null, req);
Expand Down Expand Up @@ -251,10 +253,17 @@ private JsonRpcResponse handlePromptsGet(JsonRpcRequest req) {

private JsonRpcResponse handleToolsList(JsonRpcRequest req, ProtocolVersion protocolVersion) {
var supportsOutputSchema = supportsOutputSchema(protocolVersion);
var filteredTools = tools.values()
.stream()
.filter(t -> toolFilter.allowTool(t.serverId(), t.toolInfo().getName()))
.toList();

if (metricsObserver != null) {
safeObserve(() -> metricsObserver.onToolsList("tools/list", filteredTools.size()));
}

var result = ListToolsResult.builder()
.tools(tools.values()
.stream()
.filter(t -> toolFilter.allowTool(t.serverId(), t.toolInfo().getName()))
.tools(filteredTools.stream()
.map(tool -> extractToolInfo(tool, supportsOutputSchema))
.toList())
.build();
Expand All @@ -266,18 +275,36 @@ private JsonRpcResponse handleToolsCall(
Consumer<JsonRpcResponse> asyncResponseCallback,
ProtocolVersion protocolVersion
) {
String toolName = req.getParams().getMember("name") != null
? req.getParams().getMember("name").asString()
: null;

if (metricsObserver != null) {
String toolName = req.getParams().getMember("name") != null
? req.getParams().getMember("name").asString()
: null;
metricsObserver.onToolCall("tools/call", toolName);
safeObserve(() -> metricsObserver.onToolCall("tools/call", toolName));
}

var operationName = req.getParams().getMember("name").asString();
var tool = tools.get(operationName);
long startNanos = System.nanoTime();
var tool = tools.get(toolName);

if (tool == null) {
return createErrorResponse(req, "No such tool: " + operationName);
long latencyNanos = System.nanoTime() - startNanos;
if (metricsObserver != null) {
safeObserve(() -> {
metricsObserver.onToolCallComplete(
"tools/call",
toolName,
null,
latencyNanos,
false,
false);
metricsObserver.onToolCallError(
"tools/call",
toolName,
null,
"No such tool: " + toolName);
});
}
return createErrorResponse(req, "No such tool: " + toolName);
}

// Check if this tool should be dispatched to a proxy
Expand All @@ -291,8 +318,46 @@ private JsonRpcResponse handleToolsCall(
.build();

// Get response asynchronously and invoke callback
tool.proxy().rpc(proxyRequest).thenAccept(asyncResponseCallback).exceptionally(ex -> {
tool.proxy().rpc(proxyRequest).thenAccept(response -> {
long latencyNanos = System.nanoTime() - startNanos;
boolean success = response.getError() == null;
if (metricsObserver != null) {
safeObserve(() -> {
metricsObserver.onToolCallComplete("tools/call",
toolName,
tool.serverId(),
latencyNanos,
success,
true);
if (!success) {
String errMsg = response.getError().getMessage() != null
? response.getError().getMessage()
: "Unknown error";
metricsObserver.onToolCallError("tools/call",
toolName,
tool.serverId(),
errMsg);
}
});
}
asyncResponseCallback.accept(response);
}).exceptionally(ex -> {
long latencyNanos = System.nanoTime() - startNanos;
LOG.error("Error from proxy RPC", ex);
if (metricsObserver != null) {
safeObserve(() -> {
metricsObserver.onToolCallComplete("tools/call",
toolName,
tool.serverId(),
latencyNanos,
false,
true);
metricsObserver.onToolCallError("tools/call",
toolName,
tool.serverId(),
safeErrorMessage(ex));
});
}
asyncResponseCallback
.accept(createErrorResponse(req, new RuntimeException("Proxy error: " + ex.getMessage(), ex)));
return null;
Expand All @@ -302,16 +367,56 @@ private JsonRpcResponse handleToolsCall(
return null;
} else {
// Handle locally
var operation = tool.operation();
var argumentsDoc = req.getParams().getMember("arguments");
var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema());
var input = adaptedDoc.asShape(operation.getApiOperation().inputBuilder());
var output = operation.function().apply(input, null);
var result = formatStructuredContent(tool, (SerializableShape) output, protocolVersion);
return createSuccessResponse(req.getId(), result);
try {
var operation = tool.operation();
var argumentsDoc = req.getParams().getMember("arguments");
var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema());
var input = adaptedDoc.asShape(operation.getApiOperation().inputBuilder());
var output = operation.function().apply(input, null);
var result = formatStructuredContent(tool, (SerializableShape) output, protocolVersion);
long latencyNanos = System.nanoTime() - startNanos;
if (metricsObserver != null) {
safeObserve(() -> metricsObserver.onToolCallComplete("tools/call",
toolName,
tool.serverId(),
latencyNanos,
true,
false));
}
return createSuccessResponse(req.getId(), result);
} catch (Exception e) {
long latencyNanos = System.nanoTime() - startNanos;
if (metricsObserver != null) {
safeObserve(() -> {
metricsObserver.onToolCallComplete("tools/call",
toolName,
tool.serverId(),
latencyNanos,
false,
false);
metricsObserver.onToolCallError("tools/call",
toolName,
tool.serverId(),
safeErrorMessage(e));
});
}
throw e;
}
}
}

private void safeObserve(Runnable observation) {
try {
observation.run();
} catch (Exception e) {
LOG.warn("Metrics observer error", e);
}
}

private static String safeErrorMessage(Throwable t) {
return t.getMessage() != null ? t.getMessage() : t.getClass().getName();
}

/**
* Sets the notification writer for forwarding notifications from proxies.
*/
Expand Down
Loading