Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,12 @@ else if (contentBlock.isWebSearchToolResult()) {
// message_delta
long inputTokens = streamingState.getInputTokens();
long outputTokens = deltaEvent.usage().outputTokens();
Usage usage = new DefaultUsage(Math.toIntExact(inputTokens), Math.toIntExact(outputTokens),
Math.toIntExact(inputTokens + outputTokens), deltaEvent.usage());
Long cacheRead = deltaEvent.usage().cacheReadInputTokens().orElse(null);
Long cacheWrite = deltaEvent.usage().cacheCreationInputTokens().orElse(null);
Usage usage = new DefaultUsage(Integer.valueOf(Math.toIntExact(inputTokens)),
Integer.valueOf(Math.toIntExact(outputTokens)),
Integer.valueOf(Math.toIntExact(inputTokens + outputTokens)), deltaEvent.usage(), cacheRead,
cacheWrite);

Usage accumulatedUsage = previousChatResponse != null
? UsageCalculator.getCumulativeUsage(usage, previousChatResponse) : usage;
Expand Down Expand Up @@ -1054,8 +1058,11 @@ private Usage getDefaultUsage(com.anthropic.models.messages.Usage usage) {
}
long inputTokens = usage.inputTokens();
long outputTokens = usage.outputTokens();
return new DefaultUsage(Math.toIntExact(inputTokens), Math.toIntExact(outputTokens),
Math.toIntExact(inputTokens + outputTokens), usage);
Long cacheRead = usage.cacheReadInputTokens().orElse(null);
Long cacheWrite = usage.cacheCreationInputTokens().orElse(null);
return new DefaultUsage(Integer.valueOf(Math.toIntExact(inputTokens)),
Integer.valueOf(Math.toIntExact(outputTokens)),
Integer.valueOf(Math.toIntExact(inputTokens + outputTokens)), usage, cacheRead, cacheWrite);
}

private @Nullable Citation convertTextCitation(TextCitation textCitation) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,18 @@ void shouldCacheSystemMessageOnly() {
cacheCreation, cacheRead)
.isTrue();

// Verify unified Usage interface reports the same cache metrics
org.springframework.ai.chat.metadata.Usage springUsage = response.getMetadata().getUsage();
assertThat(springUsage.getCacheWriteInputTokens() != null || springUsage.getCacheReadInputTokens() != null)
.withFailMessage("Expected cache metrics on Usage interface")
.isTrue();
if (cacheCreation > 0) {
assertThat(springUsage.getCacheWriteInputTokens()).isEqualTo(cacheCreation);
}
if (cacheRead > 0) {
assertThat(springUsage.getCacheReadInputTokens()).isEqualTo(cacheRead);
}

logger.info("Cache creation tokens: {}, Cache read tokens: {}", cacheCreation, cacheRead);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,9 @@ else if (previousTokenUsage.cacheWriteInputTokens() != null) {
.cacheWriteInputTokens(cacheWriteInputTokens)
.build();

DefaultUsage usage = new DefaultUsage(promptTokens, generationTokens, totalTokens, nativeTokenUsage);
DefaultUsage usage = new DefaultUsage(promptTokens, generationTokens, totalTokens, nativeTokenUsage,
cacheReadInputTokens != null ? cacheReadInputTokens.longValue() : null,
cacheWriteInputTokens != null ? cacheWriteInputTokens.longValue() : null);

Document modelResponseFields = response.additionalModelResponseFields();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,12 @@ private void emitChatResponse(Generation generation) {
}

private Usage getCurrentUsage() {
TokenUsage nativeUsage = this.tokenUsageRef.get();
Integer cacheReadInt = nativeUsage != null ? nativeUsage.cacheReadInputTokens() : null;
Integer cacheWriteInt = nativeUsage != null ? nativeUsage.cacheWriteInputTokens() : null;
return new DefaultUsage(this.promptTokens.get(), this.generationTokens.get(), this.totalTokens.get(),
this.tokenUsageRef.get());
nativeUsage, cacheReadInt != null ? cacheReadInt.longValue() : null,
cacheWriteInt != null ? cacheWriteInt.longValue() : null);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,12 @@ void testSystemOnlyPromptCaching() {
assertThat(cacheRead).as("Cache read should meet the 4096 token minimum for Claude Haiku 4.5")
.isGreaterThan(4096);
assertThat(cacheWrite).as("A cache read hit should not also write").isIn(null, 0);

// Verify unified Usage interface reports the same cache metrics
org.springframework.ai.chat.metadata.Usage springUsage = response.getMetadata().getUsage();
assertThat(springUsage.getCacheReadInputTokens())
.as("Usage interface should report same cache read tokens as metadata")
.isEqualTo(cacheRead.longValue());
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ public Integer getCachedContentTokenCount() {
return this.cachedContentTokenCount;
}

@Override
public @Nullable Long getCacheReadInputTokens() {
	// Map Gemini's cachedContentTokenCount onto the unified cache-read metric,
	// widening to Long; null signals the provider did not report a value.
	Integer cachedCount = this.cachedContentTokenCount;
	if (cachedCount == null) {
		return null;
	}
	return cachedCount.longValue();
}

/**
* Returns the number of tokens present in tool-use prompts.
* @return the tool-use prompt token count, or null if not available
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,9 @@ private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) {
}

private DefaultUsage getDefaultUsage(CompletionUsage usage) {
Long cacheRead = usage.promptTokensDetails().flatMap(details -> details.cachedTokens()).orElse(null);
return new DefaultUsage(Math.toIntExact(usage.promptTokens()), Math.toIntExact(usage.completionTokens()),
Math.toIntExact(usage.totalTokens()), usage);
Math.toIntExact(usage.totalTokens()), usage, cacheRead, null);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,44 @@ ChatResponse response = chatClient.prompt(prompt)
Usage usage = response.getMetadata().getUsage();
```

== Prompt Cache Usage Metrics

For providers that support prompt caching, the `Usage` interface provides unified access to cache metrics without requiring provider-specific casting:

```java
Usage usage = response.getMetadata().getUsage();

// Unified cache metrics — works across all providers
Long cacheReadTokens = usage.getCacheReadInputTokens();
Long cacheWriteTokens = usage.getCacheWriteInputTokens();

if (cacheReadTokens != null && cacheReadTokens > 0) {
System.out.println("Cache hit: " + cacheReadTokens + " tokens read from cache");
}
if (cacheWriteTokens != null && cacheWriteTokens > 0) {
System.out.println("Cache write: " + cacheWriteTokens + " tokens written to cache");
}
```

These methods return `null` for providers that do not support prompt caching.

The following table shows prompt cache metrics availability by provider:

[cols="1,1,1"]
|===
|Provider |Cache Read Tokens |Cache Write Tokens

|Anthropic |Yes |Yes (`cacheCreationInputTokens`)
|AWS Bedrock |Yes |Yes
|OpenAI |Yes (`cachedTokens`) |No
|Google Gemini |Yes (`cachedContentTokenCount`) |No
|DeepSeek |No |No
|Mistral |No |No
|Ollama |No |No
|===

NOTE: For detailed provider-specific cache metrics (such as per-modality cache breakdowns in Gemini), use `getNativeUsage()` to access the provider's native usage object.

== Benefits

**Standardization**: Provides a consistent way to handle usage across different AI models
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
* @author Ilayaperumal Gopinathan
* @since 1.0.0
*/
@JsonPropertyOrder({ "promptTokens", "completionTokens", "totalTokens", "nativeUsage" })
@JsonPropertyOrder({ "promptTokens", "completionTokens", "totalTokens", "cacheReadInputTokens", "cacheWriteInputTokens",
"nativeUsage" })
public class DefaultUsage implements Usage {

private final Integer promptTokens;
Expand All @@ -42,6 +43,10 @@ public class DefaultUsage implements Usage {

private final @Nullable Object nativeUsage;

private final @Nullable Long cacheReadInputTokens;

private final @Nullable Long cacheWriteInputTokens;

/**
* Create a new DefaultUsage with promptTokens, completionTokens, totalTokens and
* native {@link Usage} object.
Expand All @@ -56,11 +61,35 @@ public class DefaultUsage implements Usage {
*/
public DefaultUsage(@Nullable Integer promptTokens, @Nullable Integer completionTokens,
@Nullable Integer totalTokens, @Nullable Object nativeUsage) {
this(promptTokens, completionTokens, totalTokens, nativeUsage, null, null);
}

/**
* Create a new DefaultUsage with all fields including prompt cache metrics.
* @param promptTokens the number of tokens in the prompt, or {@code null} if not
* available
* @param completionTokens the number of tokens in the generation, or {@code null} if
* not available
* @param totalTokens the total number of tokens, or {@code null} to calculate from
* promptTokens and completionTokens
* @param nativeUsage the native usage object returned by the model provider, or
* {@code null} to return the map of prompt, completion and total tokens.
* @param cacheReadInputTokens the number of input tokens read from prompt cache, or
* {@code null} if not available
* @param cacheWriteInputTokens the number of input tokens written to prompt cache, or
* {@code null} if not available
* @since 2.0.0
*/
public DefaultUsage(@Nullable Integer promptTokens, @Nullable Integer completionTokens,
		@Nullable Integer totalTokens, @Nullable Object nativeUsage, @Nullable Long cacheReadInputTokens,
		@Nullable Long cacheWriteInputTokens) {
	// Core token counts default to 0 when absent; total falls back to the sum of
	// prompt and completion tokens when the provider did not supply it.
	this.promptTokens = (promptTokens == null) ? 0 : promptTokens;
	this.completionTokens = (completionTokens == null) ? 0 : completionTokens;
	this.totalTokens = (totalTokens == null) ? calculateTotalTokens(this.promptTokens, this.completionTokens)
			: totalTokens;
	// Cache metrics and the native usage object intentionally stay nullable:
	// null means "not reported by the provider", which is distinct from zero.
	this.nativeUsage = nativeUsage;
	this.cacheReadInputTokens = cacheReadInputTokens;
	this.cacheWriteInputTokens = cacheWriteInputTokens;
}

/**
Expand Down Expand Up @@ -100,8 +129,11 @@ public DefaultUsage(Integer promptTokens, Integer completionTokens, Integer tota
@JsonCreator
public static DefaultUsage fromJson(@JsonProperty("promptTokens") Integer promptTokens,
@JsonProperty("completionTokens") Integer completionTokens,
@JsonProperty("totalTokens") Integer totalTokens, @JsonProperty("nativeUsage") Object nativeUsage) {
return new DefaultUsage(promptTokens, completionTokens, totalTokens, nativeUsage);
@JsonProperty("totalTokens") Integer totalTokens, @JsonProperty("nativeUsage") Object nativeUsage,
@JsonProperty("cacheReadInputTokens") @Nullable Long cacheReadInputTokens,
@JsonProperty("cacheWriteInputTokens") @Nullable Long cacheWriteInputTokens) {
return new DefaultUsage(promptTokens, completionTokens, totalTokens, nativeUsage, cacheReadInputTokens,
cacheWriteInputTokens);
}

@Override
Expand Down Expand Up @@ -129,6 +161,20 @@ public Integer getTotalTokens() {
return this.nativeUsage;
}

@Override
@JsonProperty("cacheReadInputTokens")
@JsonInclude(JsonInclude.Include.NON_NULL)
public @Nullable Long getCacheReadInputTokens() {
return this.cacheReadInputTokens;
}

@Override
@JsonProperty("cacheWriteInputTokens")
@JsonInclude(JsonInclude.Include.NON_NULL)
public @Nullable Long getCacheWriteInputTokens() {
return this.cacheWriteInputTokens;
}

private Integer calculateTotalTokens(Integer promptTokens, Integer completionTokens) {
return promptTokens + completionTokens;
}
Expand All @@ -145,7 +191,9 @@ public boolean equals(Object o) {
DefaultUsage that = (DefaultUsage) o;
return this.totalTokens == that.totalTokens && Objects.equals(this.promptTokens, that.promptTokens)
&& Objects.equals(this.completionTokens, that.completionTokens)
&& Objects.equals(this.nativeUsage, that.nativeUsage);
&& Objects.equals(this.nativeUsage, that.nativeUsage)
&& Objects.equals(this.cacheReadInputTokens, that.cacheReadInputTokens)
&& Objects.equals(this.cacheWriteInputTokens, that.cacheWriteInputTokens);
}

@Override
Expand All @@ -154,13 +202,25 @@ public int hashCode() {
result = 31 * result + Objects.hashCode(this.completionTokens);
result = 31 * result + this.totalTokens;
result = 31 * result + Objects.hashCode(this.nativeUsage);
result = 31 * result + Objects.hashCode(this.cacheReadInputTokens);
result = 31 * result + Objects.hashCode(this.cacheWriteInputTokens);
return result;
}

@Override
public String toString() {
return "DefaultUsage{" + "promptTokens=" + this.promptTokens + ", completionTokens=" + this.completionTokens
+ ", totalTokens=" + this.totalTokens + '}';
StringBuilder sb = new StringBuilder("DefaultUsage{");
sb.append("promptTokens=").append(this.promptTokens);
sb.append(", completionTokens=").append(this.completionTokens);
sb.append(", totalTokens=").append(this.totalTokens);
if (this.cacheReadInputTokens != null) {
sb.append(", cacheReadInputTokens=").append(this.cacheReadInputTokens);
}
if (this.cacheWriteInputTokens != null) {
sb.append(", cacheWriteInputTokens=").append(this.cacheWriteInputTokens);
}
sb.append('}');
return sb.toString();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,28 @@ default Integer getTotalTokens() {
*/
@Nullable Object getNativeUsage();

/**
* Returns the number of input tokens read from the prompt cache, if the provider
* supports prompt caching. Cached tokens are tokens that were previously processed
* and stored by the provider, reducing cost and latency for repeated prompt prefixes.
* @return the number of cached input tokens read, or {@code null} if the provider
* does not support prompt caching or no cache hit occurred.
* @since 2.0.0
*/
default @Nullable Long getCacheReadInputTokens() {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if null is the right approach to express no-tokens.
I can see that in getTotalTokens we default to 0. Would it be acceptable to use 0 when there are no CachedReadInputTokens()? I was thinking of negative sentinel values (e.g. NO_OPTS = -123), but that could lead to incorrect token subtraction if we fail to validate.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second read, null has special and useful semantics in Json serialization (e.g. you can filter out null value fields)

return null;
}

/**
* Returns the number of input tokens written to the prompt cache, if the provider
* supports prompt caching. Cache writes occur when new prompt content is cached for
* the first time.
* @return the number of input tokens written to cache, or {@code null} if the
* provider does not support prompt caching or no cache write occurred.
* @since 2.0.0
*/
default @Nullable Long getCacheWriteInputTokens() {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same like above

return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,24 @@ public static Usage getCumulativeUsage(final Usage currentUsage,
promptTokens += usageFromPreviousChatResponse.getPromptTokens();
generationTokens += usageFromPreviousChatResponse.getCompletionTokens();
totalTokens += usageFromPreviousChatResponse.getTotalTokens();
return new DefaultUsage(promptTokens, generationTokens, totalTokens);
// Accumulate cache metrics, preserving null when neither side reports them.
Long cacheRead = null;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this might change if the null semantics change

if (currentUsage.getCacheReadInputTokens() != null
|| usageFromPreviousChatResponse.getCacheReadInputTokens() != null) {
cacheRead = (currentUsage.getCacheReadInputTokens() != null ? currentUsage.getCacheReadInputTokens()
: 0L)
+ (usageFromPreviousChatResponse.getCacheReadInputTokens() != null
? usageFromPreviousChatResponse.getCacheReadInputTokens() : 0L);
}
Long cacheWrite = null;
if (currentUsage.getCacheWriteInputTokens() != null
|| usageFromPreviousChatResponse.getCacheWriteInputTokens() != null) {
cacheWrite = (currentUsage.getCacheWriteInputTokens() != null ? currentUsage.getCacheWriteInputTokens()
: 0L)
+ (usageFromPreviousChatResponse.getCacheWriteInputTokens() != null
? usageFromPreviousChatResponse.getCacheWriteInputTokens() : 0L);
}
return new DefaultUsage(promptTokens, generationTokens, totalTokens, null, cacheRead, cacheWrite);
}
// When current usage is empty, return the usage from the previous chat response.
return usageFromPreviousChatResponse;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,52 @@ void testNegativeTokenValues() throws Exception {
assertThat(json).isEqualTo("{\"promptTokens\":-1,\"completionTokens\":-2,\"totalTokens\":-3}");
}

@Test
void testCacheFields() {
	// Cache metrics passed to the extended constructor must be exposed verbatim.
	DefaultUsage withCache = new DefaultUsage(100, 50, 150, null, 500L, 200L);
	assertThat(withCache.getCacheWriteInputTokens()).isEqualTo(200L);
	assertThat(withCache.getCacheReadInputTokens()).isEqualTo(500L);
}

@Test
void testCacheFieldsNullByDefault() {
	// The legacy three-arg constructor must leave both cache metrics unset (null,
	// not zero), preserving "provider did not report" semantics.
	DefaultUsage legacy = new DefaultUsage(100, 50, 150);
	assertThat(legacy.getCacheWriteInputTokens()).isNull();
	assertThat(legacy.getCacheReadInputTokens()).isNull();
}

@Test
void testToStringWithCacheFields() {
	// toString() should append cache metrics when they are present.
	DefaultUsage usage = new DefaultUsage(100, 50, 150, null, 500L, 200L);
	String expected = "DefaultUsage{promptTokens=100, completionTokens=50, totalTokens=150, "
			+ "cacheReadInputTokens=500, cacheWriteInputTokens=200}";
	assertThat(usage.toString()).isEqualTo(expected);
}

@Test
void testSerializationWithCacheFields() throws Exception {
	// Non-null cache metrics must appear in the serialized JSON.
	DefaultUsage usage = new DefaultUsage(100, 50, 150, null, 500L, 200L);
	String serialized = JsonMapper.shared().writeValueAsString(usage);
	assertThat(serialized).contains("\"cacheReadInputTokens\":500", "\"cacheWriteInputTokens\":200");
}

@Test
void testDeserializationWithCacheFields() throws Exception {
	// JSON carrying cache metrics must populate the corresponding fields.
	String payload = "{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,"
			+ "\"cacheReadInputTokens\":500,\"cacheWriteInputTokens\":200}";
	DefaultUsage deserialized = JsonMapper.shared().readValue(payload, DefaultUsage.class);
	assertThat(deserialized.getCacheWriteInputTokens()).isEqualTo(200L);
	assertThat(deserialized.getCacheReadInputTokens()).isEqualTo(500L);
}

@Test
void testDeserializationWithoutCacheFields() throws Exception {
	// JSON without cache fields must deserialize to null cache metrics, keeping
	// "absent" distinguishable from an explicit zero.
	String payload = "{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150}";
	DefaultUsage deserialized = JsonMapper.shared().readValue(payload, DefaultUsage.class);
	assertThat(deserialized.getCacheReadInputTokens()).isNull();
	assertThat(deserialized.getCacheWriteInputTokens()).isNull();
}

@Test
void testCalculatedTotalTokens() {
// Test when total tokens is null and should be calculated
Expand Down
Loading