diff --git a/src/api_type/anthropic.rs b/src/api_type/anthropic.rs index 6856f8b..37736ee 100644 --- a/src/api_type/anthropic.rs +++ b/src/api_type/anthropic.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use async_trait::async_trait; use axum::body::Body; use serde::Deserialize; @@ -9,10 +11,53 @@ use crate::request_metadata::RequestInspectionMetadata; use super::{ApiTypeHandler, Inspector, ResponseMetadata, ResponseMetadataInspector}; +#[derive(Debug, Deserialize)] +struct CacheCreation { + ephemeral_5m_input_tokens: Option, + ephemeral_1h_input_tokens: Option, +} + #[derive(Debug, Deserialize)] struct Usage { input_tokens: Option, output_tokens: Option, + cache_creation_input_tokens: Option, + cache_read_input_tokens: Option, + cache_creation: Option, +} + +fn build_cache_creation_map( + usage: &Usage, + cache_ttl_hint: Option, +) -> Option> { + if let Some(cc) = &usage.cache_creation { + let mut map = HashMap::new(); + if let Some(tokens) = cc.ephemeral_5m_input_tokens + && tokens > 0 + { + map.insert("5m".to_owned(), tokens); + } + if let Some(tokens) = cc.ephemeral_1h_input_tokens + && tokens > 0 + { + map.insert("1h".to_owned(), tokens); + } + if map.is_empty() { + return None; + } else { + return Some(map); + } + } else if let Some(tokens) = usage.cache_creation_input_tokens + && tokens > 0 + { + let duration = match cache_ttl_hint { + Some(300) => "5m", + Some(3600) => "1h", + _ => "unknown", + }; + return Some(HashMap::from([(duration.to_owned(), tokens)])); + } + None } // Payload for both non-streaming and SSE. @@ -27,9 +72,60 @@ struct MessageStartData { message: AnthropicDataWithUsage, } +#[derive(Debug, Deserialize)] +struct Message { + content: Option, +} + #[derive(Debug, Deserialize)] struct AnthropicRequestBody { model: Option, + system: Option, + messages: Option>, +} + +/// Collect cache_control TTL values from a JSON value that may be a single +/// content block or an array of content blocks. +fn collect_cache_ttls(value: &serde_json::Value, ttls: &mut Vec) { + let blocks: Vec<&serde_json::Value> = if let Some(arr) = value.as_array() { + arr.iter().collect() + } else if value.is_object() { + vec![value] + } else { + return; + }; + for block in blocks { + if let Some(cc) = block.get("cache_control") { + // cache_control block present → default TTL is 300s when ttl field absent + let ttl = cc.get("ttl").and_then(|v| v.as_u64()).unwrap_or(300); + ttls.push(ttl); + } + } +} + +/// Extract a uniform cache TTL from the request body. +/// Returns Some(ttl) if all cache_control blocks share the same TTL, None if mixed or absent. +fn extract_cache_ttl(body: &AnthropicRequestBody) -> Option { + let mut ttls = Vec::new(); + if let Some(system) = &body.system { + collect_cache_ttls(system, &mut ttls); + } + if let Some(messages) = &body.messages { + for msg in messages { + if let Some(content) = &msg.content { + collect_cache_ttls(content, &mut ttls); + } + } + } + if ttls.is_empty() { + return None; + } + let first = ttls[0]; + if ttls.iter().all(|&t| t == first) { + Some(first) + } else { + None + } } pub struct AnthropicMessagesHandler; @@ -59,7 +155,13 @@ impl ApiTypeHandler for AnthropicMessagesHandler { } }; let metadata = match serde_json::from_slice::(&bytes) { - Ok(body) => RequestInspectionMetadata { model: body.model }, + Ok(body) => { + let cache_ttl_secs = extract_cache_ttl(&body); + RequestInspectionMetadata { + model: body.model, + cache_ttl_secs, + } + } Err(e) => { tracing::error!("Failed to parse Anthropic request body: {e}"); RequestInspectionMetadata::default() @@ -73,6 +175,7 @@ impl ApiTypeHandler for AnthropicMessagesHandler { &self, _status: u16, headers: &http::HeaderMap, + request_metadata: &RequestInspectionMetadata, ) -> ResponseMetadataInspector { if is_event_stream(headers) { Box::new(ProtocolInspector::new( @@ -80,6 +183,9 @@ impl ApiTypeHandler for AnthropicMessagesHandler { AnthropicSseInspector { input_tokens: None, output_tokens: None, + cache_creation_tokens: None, + cache_read_input_tokens: None, + cache_ttl_hint: request_metadata.cache_ttl_secs, }, )) } else { @@ -98,10 +204,12 @@ fn is_event_stream(headers: &http::HeaderMap) -> bool { .is_some_and(|ct| ct.starts_with("text/event-stream")) } -#[derive(Default)] pub(crate) struct AnthropicSseInspector { pub(crate) input_tokens: Option, pub(crate) output_tokens: Option, + pub(crate) cache_creation_tokens: Option>, + pub(crate) cache_read_input_tokens: Option, + pub(crate) cache_ttl_hint: Option, } #[derive(Debug, Deserialize)] @@ -123,6 +231,9 @@ impl AnthropicSseInspector { "message_start" => { if let Ok(msg) = serde_json::from_str::(data) { self.input_tokens = msg.message.usage.input_tokens; + self.cache_creation_tokens = + build_cache_creation_map(&msg.message.usage, self.cache_ttl_hint); + self.cache_read_input_tokens = msg.message.usage.cache_read_input_tokens; } } "message_delta" => { @@ -146,11 +257,12 @@ impl Inspector for AnthropicSseInspector { if self.input_tokens.is_none() && self.output_tokens.is_none() { return Err(anyhow::anyhow!("no token usage found in SSE stream")); } - let response_metadata = ResponseMetadata { + Ok(ResponseMetadata { input_tokens: self.input_tokens, output_tokens: self.output_tokens, - }; - Ok(response_metadata) + cache_creation_tokens: self.cache_creation_tokens, + cache_read_input_tokens: self.cache_read_input_tokens, + }) } } @@ -173,9 +285,12 @@ impl Inspector for AnthropicJsonInspector { fn parse_anthropic_json(data: &[u8]) -> Result { let parsed = serde_json::from_slice::(data)?; + let cache_creation_tokens = build_cache_creation_map(&parsed.usage, None); Ok(ResponseMetadata { input_tokens: parsed.usage.input_tokens, output_tokens: parsed.usage.output_tokens, + cache_creation_tokens, + cache_read_input_tokens: parsed.usage.cache_read_input_tokens, }) } @@ -191,16 +306,31 @@ mod tests { } fn make_json_inspector() -> ResponseMetadataInspector { - AnthropicMessagesHandler.response_inspector(200, &http::HeaderMap::new()) + AnthropicMessagesHandler.response_inspector( + 200, + &http::HeaderMap::new(), + &RequestInspectionMetadata::default(), + ) } fn make_sse_inspector() -> ResponseMetadataInspector { + make_sse_inspector_with_hint(None) + } + + fn make_sse_inspector_with_hint(cache_ttl_secs: Option) -> ResponseMetadataInspector { let mut headers = http::HeaderMap::new(); headers.insert( http::header::CONTENT_TYPE, "text/event-stream".parse().unwrap(), ); - AnthropicMessagesHandler.response_inspector(200, &headers) + AnthropicMessagesHandler.response_inspector( + 200, + &headers, + &RequestInspectionMetadata { + cache_ttl_secs, + ..Default::default() + }, + ) } #[tokio::test] @@ -311,4 +441,145 @@ data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"outpu assert_eq!(metadata.input_tokens, Some(25)); assert_eq!(metadata.output_tokens, Some(150)); } + + // --- Cache token tests --- + + #[tokio::test] + async fn inspect_request_uniform_ttl() { + let body = br#"{ + "model": "claude-sonnet-4-20250514", + "system": [{"type": "text", "text": "You are helpful.", "cache_control": {"type": "ephemeral", "ttl": 3600}}], + "messages": [{"role": "user", "content": [{"type": "text", "text": "Hi", "cache_control": {"type": "ephemeral", "ttl": 3600}}]}] + }"#; + let request = make_request(body); + let (result, _) = AnthropicMessagesHandler.inspect_request(request).await; + let metadata = result.unwrap(); + assert_eq!(metadata.cache_ttl_secs, Some(3600)); + } + + #[tokio::test] + async fn inspect_request_mixed_ttls() { + let body = br#"{ + "model": "claude-sonnet-4-20250514", + "system": [{"type": "text", "text": "You are helpful.", "cache_control": {"type": "ephemeral", "ttl": 300}}], + "messages": [{"role": "user", "content": [{"type": "text", "text": "Hi", "cache_control": {"type": "ephemeral", "ttl": 3600}}]}] + }"#; + let request = make_request(body); + let (result, _) = AnthropicMessagesHandler.inspect_request(request).await; + let metadata = result.unwrap(); + assert_eq!(metadata.cache_ttl_secs, None); + } + + #[tokio::test] + async fn inspect_request_default_ttl() { + let body = br#"{ + "model": "claude-sonnet-4-20250514", + "system": [{"type": "text", "text": "You are helpful.", "cache_control": {"type": "ephemeral"}}] + }"#; + let request = make_request(body); + let (result, _) = AnthropicMessagesHandler.inspect_request(request).await; + let metadata = result.unwrap(); + assert_eq!(metadata.cache_ttl_secs, Some(300)); + } + + #[tokio::test] + async fn inspect_request_no_cache_control() { + let body = br#"{ + "model": "claude-sonnet-4-20250514", + "messages": [{"role": "user", "content": "Hi"}] + }"#; + let request = make_request(body); + let (result, _) = AnthropicMessagesHandler.inspect_request(request).await; + let metadata = result.unwrap(); + assert_eq!(metadata.cache_ttl_secs, None); + } + + #[test] + fn inspect_json_cache_creation_breakdown() { + let body = br#"{ + "id": "msg_123", + "type": "message", + "usage": { + "input_tokens": 100, + "output_tokens": 50, + "cache_creation_input_tokens": 348, + "cache_read_input_tokens": 1800, + "cache_creation": { + "ephemeral_5m_input_tokens": 248, + "ephemeral_1h_input_tokens": 100 + } + } + }"#; + let mut inspector = make_json_inspector(); + inspector.feed(body); + let metadata = inspector.finish().unwrap(); + let map = metadata.cache_creation_tokens.unwrap(); + assert_eq!(map.get("5m"), Some(&248)); + assert_eq!(map.get("1h"), Some(&100)); + assert_eq!(metadata.cache_read_input_tokens, Some(1800)); + } + + #[test] + fn inspect_json_no_cache_fields() { + let body = br#"{ + "id": "msg_123", + "type": "message", + "usage": {"input_tokens": 25, "output_tokens": 150} + }"#; + let mut inspector = make_json_inspector(); + inspector.feed(body); + let metadata = inspector.finish().unwrap(); + assert_eq!(metadata.cache_creation_tokens, None); + assert_eq!(metadata.cache_read_input_tokens, None); + } + + #[test] + fn inspect_sse_cache_with_ttl_hint_5m() { + let body = br#"event: message_start +data: {"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","usage":{"input_tokens":25,"output_tokens":1,"cache_creation_input_tokens":348,"cache_read_input_tokens":0}}} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":150}} + +"#; + let mut inspector = make_sse_inspector_with_hint(Some(300)); + inspector.feed(body); + let metadata = inspector.finish().unwrap(); + let map = metadata.cache_creation_tokens.unwrap(); + assert_eq!(map.get("5m"), Some(&348)); + assert_eq!(metadata.cache_read_input_tokens, Some(0)); + } + + #[test] + fn inspect_sse_cache_with_ttl_hint_1h() { + let body = br#"event: message_start +data: {"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","usage":{"input_tokens":25,"output_tokens":1,"cache_creation_input_tokens":500,"cache_read_input_tokens":200}}} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":10}} + +"#; + let mut inspector = make_sse_inspector_with_hint(Some(3600)); + inspector.feed(body); + let metadata = inspector.finish().unwrap(); + let map = metadata.cache_creation_tokens.unwrap(); + assert_eq!(map.get("1h"), Some(&500)); + assert_eq!(metadata.cache_read_input_tokens, Some(200)); + } + + #[test] + fn inspect_sse_cache_no_ttl_hint() { + let body = br#"event: message_start +data: {"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","usage":{"input_tokens":25,"output_tokens":1,"cache_creation_input_tokens":348,"cache_read_input_tokens":0}}} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":150}} + +"#; + let mut inspector = make_sse_inspector_with_hint(None); + inspector.feed(body); + let metadata = inspector.finish().unwrap(); + let map = metadata.cache_creation_tokens.unwrap(); + assert_eq!(map.get("unknown"), Some(&348)); + } } diff --git a/src/api_type/bedrock/eventstream_inspector.rs b/src/api_type/bedrock/eventstream_inspector.rs index 8096aae..d0fe495 100644 --- a/src/api_type/bedrock/eventstream_inspector.rs +++ b/src/api_type/bedrock/eventstream_inspector.rs @@ -27,11 +27,12 @@ impl Inspector for AnthropicSseInspector { if self.input_tokens.is_none() && self.output_tokens.is_none() { return Err(anyhow::anyhow!("no token usage found in SSE stream")); } - let response_metadata = ResponseMetadata { + Ok(ResponseMetadata { input_tokens: self.input_tokens, output_tokens: self.output_tokens, - }; - Ok(response_metadata) + cache_creation_tokens: self.cache_creation_tokens, + cache_read_input_tokens: self.cache_read_input_tokens, + }) } } @@ -56,6 +57,9 @@ mod tests { AnthropicSseInspector { input_tokens: None, output_tokens: None, + cache_creation_tokens: None, + cache_read_input_tokens: None, + cache_ttl_hint: None, }, )); diff --git a/src/api_type/bedrock/headers_inspector.rs b/src/api_type/bedrock/headers_inspector.rs index ea35cd4..4222091 100644 --- a/src/api_type/bedrock/headers_inspector.rs +++ b/src/api_type/bedrock/headers_inspector.rs @@ -11,6 +11,7 @@ impl ApiTypeHandler for BedrockModelInvokeJsonHandler { &self, _status: u16, headers: &http::HeaderMap, + _request_metadata: &crate::request_metadata::RequestInspectionMetadata, ) -> ResponseMetadataInspector { let input_tokens = parse_token_header(headers, "x-amzn-bedrock-input-token-count").unwrap_or(None); @@ -19,6 +20,7 @@ impl ApiTypeHandler for BedrockModelInvokeJsonHandler { Box::new(StaticInspector::new(ResponseMetadata { input_tokens, output_tokens, + ..Default::default() })) } } @@ -54,7 +56,11 @@ mod tests { ("x-amzn-bedrock-input-token-count", "25"), ("x-amzn-bedrock-output-token-count", "150"), ]); - let inspector = BedrockModelInvokeJsonHandler.response_inspector(200, &headers); + let inspector = BedrockModelInvokeJsonHandler.response_inspector( + 200, + &headers, + &crate::request_metadata::RequestInspectionMetadata::default(), + ); let metadata = inspector.finish().unwrap(); assert_eq!(metadata.input_tokens, Some(25)); assert_eq!(metadata.output_tokens, Some(150)); @@ -63,7 +69,11 @@ mod tests { #[test] fn inspect_response_missing_headers() { let headers = http::HeaderMap::new(); - let inspector = BedrockModelInvokeJsonHandler.response_inspector(200, &headers); + let inspector = BedrockModelInvokeJsonHandler.response_inspector( + 200, + &headers, + &crate::request_metadata::RequestInspectionMetadata::default(), + ); let metadata = inspector.finish().unwrap(); assert_eq!(metadata.input_tokens, None); assert_eq!(metadata.output_tokens, None); @@ -72,7 +82,11 @@ mod tests { #[test] fn inspect_response_invalid_header() { let headers = headers_to_map(&[("x-amzn-bedrock-input-token-count", "not_a_number")]); - let inspector = BedrockModelInvokeJsonHandler.response_inspector(200, &headers); + let inspector = BedrockModelInvokeJsonHandler.response_inspector( + 200, + &headers, + &crate::request_metadata::RequestInspectionMetadata::default(), + ); let metadata = inspector.finish().unwrap(); assert_eq!(metadata.input_tokens, None); assert_eq!(metadata.output_tokens, None); diff --git a/src/api_type/bedrock/mod.rs b/src/api_type/bedrock/mod.rs index c93de84..27ad794 100644 --- a/src/api_type/bedrock/mod.rs +++ b/src/api_type/bedrock/mod.rs @@ -44,21 +44,38 @@ impl ApiTypeHandler for BedrockModelInvokeHandler { axum::extract::Request, ) { let model = extract_model_from_path(request.uri().path()); - (Ok(RequestInspectionMetadata { model }), request) + ( + Ok(RequestInspectionMetadata { + model, + ..Default::default() + }), + request, + ) } fn response_inspector( &self, status: u16, headers: &http::HeaderMap, + request_metadata: &RequestInspectionMetadata, ) -> ResponseMetadataInspector { if is_amazon_event_stream(headers) { Box::new(ProtocolInspector::new( AmazonEventstreamProtocol::default(), - AnthropicSseInspector::default(), + AnthropicSseInspector { + input_tokens: None, + output_tokens: None, + cache_creation_tokens: None, + cache_read_input_tokens: None, + cache_ttl_hint: request_metadata.cache_ttl_secs, + }, )) } else { - headers_inspector::BedrockModelInvokeJsonHandler.response_inspector(status, headers) + headers_inspector::BedrockModelInvokeJsonHandler.response_inspector( + status, + headers, + request_metadata, + ) } } } diff --git a/src/api_type/mod.rs b/src/api_type/mod.rs index 05ebd26..93f2581 100644 --- a/src/api_type/mod.rs +++ b/src/api_type/mod.rs @@ -35,6 +35,7 @@ pub trait ApiTypeHandler { &self, _status: u16, _headers: &http::HeaderMap, + _request_metadata: &RequestInspectionMetadata, ) -> ResponseMetadataInspector { Box::new(StaticInspector::default()) } diff --git a/src/api_type/openai/chat_completion.rs b/src/api_type/openai/chat_completion.rs index c6d40c4..cbc1f98 100644 --- a/src/api_type/openai/chat_completion.rs +++ b/src/api_type/openai/chat_completion.rs @@ -4,10 +4,16 @@ use crate::inspector::protocol::text::{TextBody, TextProtocol}; use serde::Deserialize; // https://platform.openai.com/docs/api-reference/chat/object +#[derive(Debug, Deserialize)] +struct PromptTokensDetails { + cached_tokens: Option, +} + #[derive(Debug, Deserialize)] struct OpenAiChatCompletionUsage { prompt_tokens: u64, completion_tokens: u64, + prompt_tokens_details: Option, } #[derive(Debug, Deserialize)] @@ -26,6 +32,7 @@ impl ApiTypeHandler for OpenAiChatCompletionHandler { &self, _status: u16, _headers: &http::HeaderMap, + _request_metadata: &crate::request_metadata::RequestInspectionMetadata, ) -> ResponseMetadataInspector { Box::new(ProtocolInspector::new( TextProtocol::new(), @@ -53,9 +60,15 @@ impl Inspector for OpenAiChatInspector { fn parse_chat_completion(data: &[u8]) -> Result { let parsed = serde_json::from_slice::(data)?; + let cache_read_input_tokens = parsed + .usage + .prompt_tokens_details + .and_then(|d| d.cached_tokens); Ok(ResponseMetadata { input_tokens: Some(parsed.usage.prompt_tokens), output_tokens: Some(parsed.usage.completion_tokens), + cache_read_input_tokens, + ..Default::default() }) } @@ -64,7 +77,11 @@ mod tests { use super::*; fn make_inspector() -> ResponseMetadataInspector { - OpenAiChatCompletionHandler.response_inspector(200, &http::HeaderMap::new()) + OpenAiChatCompletionHandler.response_inspector( + 200, + &http::HeaderMap::new(), + &crate::request_metadata::RequestInspectionMetadata::default(), + ) } #[test] @@ -96,4 +113,37 @@ mod tests { inspector.feed(b"not json"); assert!(inspector.finish().is_err()); } + + #[test] + fn inspect_response_cached_tokens() { + let body = br#"{ + "id": "chatcmpl-abc123", + "object": "chat.completion", + "model": "gpt-4o", + "usage": { + "prompt_tokens": 100, + "completion_tokens": 42, + "prompt_tokens_details": {"cached_tokens": 80} + } + }"#; + let mut inspector = make_inspector(); + inspector.feed(body); + let metadata = inspector.finish().unwrap(); + assert_eq!(metadata.input_tokens, Some(100)); + assert_eq!(metadata.cache_read_input_tokens, Some(80)); + } + + #[test] + fn inspect_response_no_cached_tokens() { + let body = br#"{ + "id": "chatcmpl-abc123", + "object": "chat.completion", + "model": "gpt-4o", + "usage": {"prompt_tokens": 15, "completion_tokens": 42} + }"#; + let mut inspector = make_inspector(); + inspector.feed(body); + let metadata = inspector.finish().unwrap(); + assert_eq!(metadata.cache_read_input_tokens, None); + } } diff --git a/src/api_type/openai/responses.rs b/src/api_type/openai/responses.rs index ced829f..7c6ef4a 100644 --- a/src/api_type/openai/responses.rs +++ b/src/api_type/openai/responses.rs @@ -4,10 +4,16 @@ use crate::inspector::protocol::text::{TextBody, TextProtocol}; use serde::Deserialize; // https://platform.openai.com/docs/api-reference/responses/object +#[derive(Debug, Deserialize)] +struct InputTokensDetails { + cached_tokens: Option, +} + #[derive(Debug, Deserialize)] struct OpenAiResponsesUsage { input_tokens: u64, output_tokens: u64, + input_tokens_details: Option, } #[derive(Debug, Deserialize)] @@ -26,6 +32,7 @@ impl ApiTypeHandler for OpenAiResponsesHandler { &self, _status: u16, _headers: &http::HeaderMap, + _request_metadata: &crate::request_metadata::RequestInspectionMetadata, ) -> ResponseMetadataInspector { Box::new(ProtocolInspector::new( TextProtocol::new(), @@ -53,9 +60,15 @@ impl Inspector for OpenAiResponsesInspector { fn parse_responses(data: &[u8]) -> Result { let parsed = serde_json::from_slice::(data)?; + let cache_read_input_tokens = parsed + .usage + .input_tokens_details + .and_then(|d| d.cached_tokens); Ok(ResponseMetadata { input_tokens: Some(parsed.usage.input_tokens), output_tokens: Some(parsed.usage.output_tokens), + cache_read_input_tokens, + ..Default::default() }) } @@ -64,7 +77,11 @@ mod tests { use super::*; fn make_inspector() -> ResponseMetadataInspector { - OpenAiResponsesHandler.response_inspector(200, &http::HeaderMap::new()) + OpenAiResponsesHandler.response_inspector( + 200, + &http::HeaderMap::new(), + &crate::request_metadata::RequestInspectionMetadata::default(), + ) } #[test] @@ -95,4 +112,37 @@ mod tests { inspector.feed(b"not json"); assert!(inspector.finish().is_err()); } + + #[test] + fn inspect_response_cached_tokens() { + let body = br#"{ + "id": "resp_abc123", + "object": "response", + "model": "gpt-4o", + "usage": { + "input_tokens": 100, + "output_tokens": 30, + "input_tokens_details": {"cached_tokens": 60} + } + }"#; + let mut inspector = make_inspector(); + inspector.feed(body); + let metadata = inspector.finish().unwrap(); + assert_eq!(metadata.input_tokens, Some(100)); + assert_eq!(metadata.cache_read_input_tokens, Some(60)); + } + + #[test] + fn inspect_response_no_cached_tokens() { + let body = br#"{ + "id": "resp_abc123", + "object": "response", + "model": "gpt-4o", + "usage": {"input_tokens": 10, "output_tokens": 30} + }"#; + let mut inspector = make_inspector(); + inspector.feed(body); + let metadata = inspector.finish().unwrap(); + assert_eq!(metadata.cache_read_input_tokens, None); + } } diff --git a/src/authorization.rs b/src/authorization.rs index ff144d4..05df22c 100644 --- a/src/authorization.rs +++ b/src/authorization.rs @@ -73,6 +73,7 @@ mod tests { user_agent: None, inspected: RequestInspectionMetadata { model: model.map(|m| m.to_string()), + ..Default::default() }, labels: std::collections::HashMap::new(), } diff --git a/src/http_handlers/proxy.rs b/src/http_handlers/proxy.rs index aa06e9d..c852f87 100644 --- a/src/http_handlers/proxy.rs +++ b/src/http_handlers/proxy.rs @@ -105,11 +105,16 @@ async fn try_forward_to_provider( info!(%method, %path, %status, ?model, "upstream_resp"); metrics::record_request(&request_metadata); - let upstream_res = - wrap_upstream_response(api_type_handler, upstream_res, move |result| match result { + let request_inspection_metadata = request_metadata.inspected.clone(); + let upstream_res = wrap_upstream_response( + api_type_handler, + upstream_res, + &request_inspection_metadata, + move |result| match result { Ok(metadata) => metrics::record_response(&request_metadata, metadata), Err(e) => warn!("Failed to inspect response: {e}"), - }); + }, + ); Ok(upstream_res) } @@ -141,13 +146,15 @@ pub async fn proxy_handler( fn wrap_upstream_response( api_type_handler: Option>, upstream_res: reqwest::Response, + request_inspection_metadata: &crate::request_metadata::RequestInspectionMetadata, on_response: impl FnOnce(&Result) + Send + 'static, ) -> Response { let status_code = upstream_res.status().as_u16(); let headers = upstream_res.headers().clone(); let body = if let Some(api_type_handler) = api_type_handler { - let inspector = api_type_handler.response_inspector(status_code, &headers); + let inspector = + api_type_handler.response_inspector(status_code, &headers, request_inspection_metadata); let inspector = if let Some(encoding) = headers .get("content-encoding") .and_then(|s| s.to_str().ok()) diff --git a/src/metrics.rs b/src/metrics.rs index e20dcff..2664821 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -34,4 +34,14 @@ pub fn record_response(request_metadata: &RequestMetadata, response_metadata: &R if total_tokens > 0 { metrics::counter!("lacuna_provider_tokens_total", &labels).increment(total_tokens); } + if let Some(map) = &response_metadata.cache_creation_tokens { + for (duration, tokens) in map { + let mut l = labels.clone(); + l.push(("cache_duration".to_owned(), duration.clone())); + metrics::counter!("lacuna_provider_tokens_cache_creation_total", &l).increment(*tokens); + } + } + if let Some(tokens) = response_metadata.cache_read_input_tokens { + metrics::counter!("lacuna_provider_tokens_cache_read_total", &labels).increment(tokens); + } } diff --git a/src/request_metadata.rs b/src/request_metadata.rs index 6db7d27..b2ce1b5 100644 --- a/src/request_metadata.rs +++ b/src/request_metadata.rs @@ -6,11 +6,14 @@ use std::collections::HashMap; pub struct ResponseMetadata { pub input_tokens: Option, pub output_tokens: Option, + pub cache_creation_tokens: Option>, + pub cache_read_input_tokens: Option, } #[derive(Debug, Clone, Default, PartialEq)] pub struct RequestInspectionMetadata { pub model: Option, + pub cache_ttl_secs: Option, } #[derive(Debug)] @@ -62,6 +65,7 @@ mod tests { user_agent: None, inspected: RequestInspectionMetadata { model: Some("gpt-4o".to_owned()), + ..Default::default() }, labels, }