diff --git a/e2e/scenarios/ai-sdk-instrumentation/assertions.ts b/e2e/scenarios/ai-sdk-instrumentation/assertions.ts index 2bcc5b64..8ddd56cd 100644 --- a/e2e/scenarios/ai-sdk-instrumentation/assertions.ts +++ b/e2e/scenarios/ai-sdk-instrumentation/assertions.ts @@ -189,6 +189,20 @@ function findStreamTrace(events: CapturedLogEvent[]) { return { child, operation, parent }; } +function findEmbedTrace(events: CapturedLogEvent[]) { + const operation = findLatestSpan(events, "ai-sdk-embed-operation"); + const parent = findParentSpan(events, "embed", operation?.span.id); + + return { operation, parent }; +} + +function findEmbedManyTrace(events: CapturedLogEvent[]) { + const operation = findLatestSpan(events, "ai-sdk-embed-many-operation"); + const parent = findParentSpan(events, "embedMany", operation?.span.id); + + return { operation, parent }; +} + function findToolTrace(events: CapturedLogEvent[]) { const operation = findLatestSpan(events, "ai-sdk-tool-operation"); const parent = findParentSpan(events, "generateText", operation?.span.id); @@ -641,6 +655,24 @@ function expectAISDKParentSpan(span: CapturedLogEvent | undefined) { ).toBe("string"); } +function expectEmbeddingTokenMetrics(span: CapturedLogEvent | undefined) { + const metrics = span?.metrics as Record | undefined; + const totalTokens = metrics?.tokens; + const promptTokens = metrics?.prompt_tokens; + + const tokenMetric = + typeof totalTokens === "number" + ? totalTokens + : typeof promptTokens === "number" + ? promptTokens + : undefined; + + expect(tokenMetric).toEqual(expect.any(Number)); + if (typeof tokenMetric === "number") { + expect(tokenMetric).toBeGreaterThan(0); + } +} + export function defineAISDKInstrumentationAssertions(options: { agentSpanName?: AgentSpanName; name: string; @@ -754,6 +786,50 @@ export function defineAISDKInstrumentationAssertions(options: { } }); + test("captures trace for embed()", testConfig, () => { + const root = findLatestSpan(events, ROOT_NAME); + const trace = findEmbedTrace(events); + + expectOperationParentedByRoot(trace.operation, root); + expectAISDKParentSpan(trace.parent); + expect(operationName(trace.operation)).toBe("embed"); + expectEmbeddingTokenMetrics(trace.parent); + const input = isRecord(trace.parent?.input) ? trace.parent.input : null; + expect(typeof input?.value).toBe("string"); + const output = extractOutputRecord(trace.parent); + expect(output).toBeDefined(); + if (output) { + expect(output.embedding).toBeUndefined(); + expect(output.embedding_length).toEqual(expect.any(Number)); + expect(output.embedding_length).toBeGreaterThan(0); + } + }); + + test("captures trace for embedMany()", testConfig, () => { + const root = findLatestSpan(events, ROOT_NAME); + const trace = findEmbedManyTrace(events); + + expectOperationParentedByRoot(trace.operation, root); + expectAISDKParentSpan(trace.parent); + expect(operationName(trace.operation)).toBe("embed-many"); + expectEmbeddingTokenMetrics(trace.parent); + const input = isRecord(trace.parent?.input) ? trace.parent.input : null; + expect(Array.isArray(input?.values)).toBe(true); + if (Array.isArray(input?.values)) { + expect(input.values.length).toBeGreaterThanOrEqual(2); + } + const output = extractOutputRecord(trace.parent); + expect(output).toBeDefined(); + if (output) { + expect(output.embeddings).toBeUndefined(); + expect(output.responses).toBeUndefined(); + expect(output.embedding_count).toEqual(expect.any(Number)); + expect(output.embedding_count).toBeGreaterThanOrEqual(2); + expect(output.embedding_length).toEqual(expect.any(Number)); + expect(output.embedding_length).toBeGreaterThan(0); + } + }); + if (options.supportsOutputObjectScenario) { test( "captures Output.object schema on generateText()", diff --git a/e2e/scenarios/ai-sdk-instrumentation/scenario.impl.mjs b/e2e/scenarios/ai-sdk-instrumentation/scenario.impl.mjs index 79c00f54..979b86e0 100644 --- a/e2e/scenarios/ai-sdk-instrumentation/scenario.impl.mjs +++ b/e2e/scenarios/ai-sdk-instrumentation/scenario.impl.mjs @@ -125,6 +125,9 @@ async function runAISDKInstrumentationScenario( ) { const instrumentedAI = decorateAI ? decorateAI(options.ai) : options.ai; const openaiModel = options.openai("gpt-4o-mini"); + const openaiEmbeddingModel = options.openai.textEmbeddingModel( + "text-embedding-3-small", + ); const sdkMajorVersion = parseMajorVersion(options.sdkVersion); const supportsRichInputScenarios = sdkMajorVersion >= 5; const outputObject = createOutputObjectIfSupported(options.ai); @@ -169,6 +172,28 @@ async function runAISDKInstrumentationScenario( } }); + await runOperation("ai-sdk-embed-operation", "embed", async () => { + await instrumentedAI.embed({ + model: openaiEmbeddingModel, + value: "Paris is the capital of France.", + }); + }); + + await runOperation( + "ai-sdk-embed-many-operation", + "embed-many", + async () => { + await instrumentedAI.embedMany({ + model: openaiEmbeddingModel, + values: [ + "Paris is in France.", + "Berlin is in Germany.", + "Vienna is in Austria.", + ], + }); + }, + ); + await runOperation("ai-sdk-tool-operation", "tool", async () => { const toolRequest = { model: openaiModel, diff --git a/js/src/auto-instrumentations/configs/ai-sdk.test.ts b/js/src/auto-instrumentations/configs/ai-sdk.test.ts new file mode 100644 index 00000000..9955bdda --- /dev/null +++ b/js/src/auto-instrumentations/configs/ai-sdk.test.ts @@ -0,0 +1,46 @@ +import { describe, expect, it } from "vitest"; +import { aiSDKChannels } from "../../instrumentation/plugins/ai-sdk-channels"; +import { aiSDKConfigs } from "./ai-sdk"; + +function findConfigsByFunctionName(functionName: string) { + return aiSDKConfigs.filter((config) => { + if (!("functionQuery" in config)) { + return false; + } + const query = config.functionQuery as { functionName?: unknown }; + return query.functionName === functionName; + }); +} + +describe("aiSDKConfigs", () => { + it("defines embed channels", () => { + expect(aiSDKChannels.embed.channelName).toBe("embed"); + expect(aiSDKChannels.embedMany.channelName).toBe("embedMany"); + }); + + it("instruments embed() in both ESM and CJS entrypoints", () => { + const embedConfigs = findConfigsByFunctionName("embed"); + + expect(embedConfigs).toHaveLength(2); + expect(embedConfigs.map((config) => config.channelName)).toEqual([ + aiSDKChannels.embed.channelName, + aiSDKChannels.embed.channelName, + ]); + expect(embedConfigs.map((config) => config.module.filePath).sort()).toEqual( + ["dist/index.js", "dist/index.mjs"], + ); + }); + + it("instruments embedMany() in both ESM and CJS entrypoints", () => { + const embedManyConfigs = findConfigsByFunctionName("embedMany"); + + expect(embedManyConfigs).toHaveLength(2); + expect(embedManyConfigs.map((config) => config.channelName)).toEqual([ + aiSDKChannels.embedMany.channelName, + aiSDKChannels.embedMany.channelName, + ]); + expect( + embedManyConfigs.map((config) => config.module.filePath).sort(), + ).toEqual(["dist/index.js", "dist/index.mjs"]); + }); +}); diff --git a/js/src/auto-instrumentations/configs/ai-sdk.ts b/js/src/auto-instrumentations/configs/ai-sdk.ts index 2c0ea914..8f94a28d 100644 --- a/js/src/auto-instrumentations/configs/ai-sdk.ts +++ b/js/src/auto-instrumentations/configs/ai-sdk.ts @@ -105,6 +105,58 @@ export const aiSDKConfigs: InstrumentationConfig[] = [ }, }, + // embed - async function + { + channelName: aiSDKChannels.embed.channelName, + module: { + name: "ai", + versionRange: ">=3.0.0", + filePath: "dist/index.mjs", + }, + functionQuery: { + functionName: "embed", + kind: "Async", + }, + }, + { + channelName: aiSDKChannels.embed.channelName, + module: { + name: "ai", + versionRange: ">=3.0.0", + filePath: "dist/index.js", + }, + functionQuery: { + functionName: "embed", + kind: "Async", + }, + }, + + // embedMany - async function + { + channelName: aiSDKChannels.embedMany.channelName, + module: { + name: "ai", + versionRange: ">=3.0.0", + filePath: "dist/index.mjs", + }, + functionQuery: { + functionName: "embedMany", + kind: "Async", + }, + }, + { + channelName: aiSDKChannels.embedMany.channelName, + module: { + name: "ai", + versionRange: ">=3.0.0", + filePath: "dist/index.js", + }, + functionQuery: { + functionName: "embedMany", + kind: "Async", + }, + }, + // streamObject - async function (v3 only, before the sync refactor in v4) { channelName: aiSDKChannels.streamObject.channelName, diff --git a/js/src/instrumentation/plugins/ai-sdk-channels.ts b/js/src/instrumentation/plugins/ai-sdk-channels.ts index 0678982d..e8ec7344 100644 --- a/js/src/instrumentation/plugins/ai-sdk-channels.ts +++ b/js/src/instrumentation/plugins/ai-sdk-channels.ts @@ -3,6 +3,8 @@ import type { ChannelSpanInfo } from "../core/types"; import type { AISDK, AISDKCallParams, + AISDKEmbedParams, + AISDKEmbeddingResult, AISDKResult, } from "../../vendor-sdk-types/ai-sdk"; @@ -69,6 +71,20 @@ export const aiSDKChannels = defineChannels("ai", { channelName: "streamObject.sync", kind: "sync-stream", }), + embed: channel<[AISDKEmbedParams], AISDKEmbeddingResult, AISDKChannelContext>( + { + channelName: "embed", + kind: "async", + }, + ), + embedMany: channel< + [AISDKEmbedParams], + AISDKEmbeddingResult, + AISDKChannelContext + >({ + channelName: "embedMany", + kind: "async", + }), agentGenerate: channel< [AISDKCallParams], AISDKStreamResult, diff --git a/js/src/instrumentation/plugins/ai-sdk-plugin.test.ts b/js/src/instrumentation/plugins/ai-sdk-plugin.test.ts index 599dd3e8..ffb3091a 100644 --- a/js/src/instrumentation/plugins/ai-sdk-plugin.test.ts +++ b/js/src/instrumentation/plugins/ai-sdk-plugin.test.ts @@ -883,6 +883,53 @@ describe("AI SDK utility functions", () => { expect(result).toBeUndefined(); }); }); + + describe("processAISDKEmbeddingOutput", () => { + it("should summarize single embedding length", () => { + const output = { + embedding: [0.1, 0.2, 0.3, 0.4], + usage: { + totalTokens: 10, + }, + }; + + const result = processAISDKEmbeddingOutput(output, []); + expect(result.embedding).toBeUndefined(); + expect(result.embedding_length).toBe(4); + expect(result.usage).toMatchObject({ + totalTokens: 10, + }); + }); + + it("should summarize embedding batches", () => { + const output = { + embeddings: [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + ], + }; + + const result = processAISDKEmbeddingOutput(output, []); + expect(result.embeddings).toBeUndefined(); + expect(result.embedding_count).toBe(2); + expect(result.embedding_length).toBe(3); + }); + + it("should omit non-whitelisted fields like responses", () => { + const output = { + embeddings: [[0.1, 0.2, 0.3]], + response: { body: "too much" }, + responses: [{ body: "way too much" }], + usage: { totalTokens: 8 }, + }; + + const result = processAISDKEmbeddingOutput(output, []); + expect(result.response).toBeUndefined(); + expect(result.responses).toBeUndefined(); + expect(result.usage).toMatchObject({ totalTokens: 8 }); + expect(result.embedding_count).toBe(1); + }); + }); }); // Helper functions exported for testing @@ -1038,12 +1085,17 @@ function extractGetterValues(obj: any): any { const getterNames = [ "text", "object", + "value", + "values", + "embedding", + "embeddings", "finishReason", "usage", "totalUsage", "toolCalls", "toolResults", "warnings", + "responses", "experimental_providerMetadata", "providerMetadata", "rawResponse", @@ -1210,3 +1262,45 @@ function processAISDKOutput(output: any, denyOutputPaths: string[]): any { return omit(merged, denyOutputPaths); } + +function processAISDKEmbeddingOutput( + output: any, + denyOutputPaths: string[], +): any { + if (!output || typeof output !== "object") { + return output; + } + + const processed: Record = {}; + const whitelistedFields = [ + "usage", + "totalUsage", + "warnings", + "providerMetadata", + "experimental_providerMetadata", + ]; + + for (const field of whitelistedFields) { + const value = output?.[field]; + if (value !== undefined && typeof value !== "function") { + processed[field] = value; + } + } + + if (Array.isArray(output?.embedding)) { + processed.embedding_length = output.embedding.length; + } + + if (Array.isArray(output?.embeddings)) { + processed.embedding_count = output.embeddings.length; + + const firstEmbedding = output.embeddings.find((item: unknown) => + Array.isArray(item), + ); + if (Array.isArray(firstEmbedding)) { + processed.embedding_length = firstEmbedding.length; + } + } + + return processed; +} diff --git a/js/src/instrumentation/plugins/ai-sdk-plugin.ts b/js/src/instrumentation/plugins/ai-sdk-plugin.ts index da1c3745..8d902298 100644 --- a/js/src/instrumentation/plugins/ai-sdk-plugin.ts +++ b/js/src/instrumentation/plugins/ai-sdk-plugin.ts @@ -1,5 +1,6 @@ import { BasePlugin } from "../core"; import { + traceAsyncChannel, traceStreamingChannel, traceSyncStreamChannel, unsubscribeAll, @@ -18,6 +19,8 @@ import { aiSDKChannels } from "./ai-sdk-channels"; import type { AISDK, AISDKCallParams, + AISDKEmbedParams, + AISDKEmbeddingResult, AISDKLanguageModel, AISDKModel, AISDKModelStreamChunk, @@ -71,6 +74,8 @@ const RUNTIME_DENY_OUTPUT_PATHS = Symbol.for( * - streamText (function returning stream) * - generateObject (async function) * - streamObject (function returning stream) + * - embed (async function) + * - embedMany (async function) * - Agent.generate (async method) * - Agent.stream (async method returning stream) * - ToolLoopAgent.generate (async method) @@ -108,7 +113,7 @@ export class AISDKPlugin extends BasePlugin { name: "generateText", type: SpanTypeAttribute.LLM, extractInput: ([params], event, span) => - prepareAISDKInput(params, event, span, denyOutputPaths), + prepareAISDKCallInput(params, event, span, denyOutputPaths), extractOutput: (result, endEvent) => { finalizeAISDKChildTracing(endEvent as { [key: string]: unknown }); return processAISDKOutput( @@ -128,7 +133,7 @@ export class AISDKPlugin extends BasePlugin { name: "streamText", type: SpanTypeAttribute.LLM, extractInput: ([params], event, span) => - prepareAISDKInput(params, event, span, denyOutputPaths), + prepareAISDKCallInput(params, event, span, denyOutputPaths), extractOutput: (result, endEvent) => processAISDKOutput( result, @@ -154,7 +159,7 @@ export class AISDKPlugin extends BasePlugin { name: "streamText", type: SpanTypeAttribute.LLM, extractInput: ([params], event, span) => - prepareAISDKInput(params, event, span, denyOutputPaths), + prepareAISDKCallInput(params, event, span, denyOutputPaths), patchResult: ({ endEvent, result, span, startTime }) => patchAISDKStreamingResult({ defaultDenyOutputPaths: denyOutputPaths, @@ -172,7 +177,7 @@ export class AISDKPlugin extends BasePlugin { name: "generateObject", type: SpanTypeAttribute.LLM, extractInput: ([params], event, span) => - prepareAISDKInput(params, event, span, denyOutputPaths), + prepareAISDKCallInput(params, event, span, denyOutputPaths), extractOutput: (result, endEvent) => { finalizeAISDKChildTracing(endEvent as { [key: string]: unknown }); return processAISDKOutput( @@ -192,7 +197,7 @@ export class AISDKPlugin extends BasePlugin { name: "streamObject", type: SpanTypeAttribute.LLM, extractInput: ([params], event, span) => - prepareAISDKInput(params, event, span, denyOutputPaths), + prepareAISDKCallInput(params, event, span, denyOutputPaths), extractOutput: (result, endEvent) => processAISDKOutput( result, @@ -218,7 +223,7 @@ export class AISDKPlugin extends BasePlugin { name: "streamObject", type: SpanTypeAttribute.LLM, extractInput: ([params], event, span) => - prepareAISDKInput(params, event, span, denyOutputPaths), + prepareAISDKCallInput(params, event, span, denyOutputPaths), patchResult: ({ endEvent, result, span, startTime }) => patchAISDKStreamingResult({ defaultDenyOutputPaths: denyOutputPaths, @@ -230,13 +235,47 @@ export class AISDKPlugin extends BasePlugin { }), ); + // embed - async embedding function + this.unsubscribers.push( + traceAsyncChannel(aiSDKChannels.embed, { + name: "embed", + type: SpanTypeAttribute.LLM, + extractInput: ([params], event) => + prepareAISDKEmbedInput(params, event.self), + extractOutput: (result, endEvent) => + processAISDKEmbeddingOutput( + result, + resolveDenyOutputPaths(endEvent, denyOutputPaths), + ), + extractMetrics: (result, _startTime, endEvent) => + extractTopLevelAISDKMetrics(result, endEvent), + }), + ); + + // embedMany - async embedding batch function + this.unsubscribers.push( + traceAsyncChannel(aiSDKChannels.embedMany, { + name: "embedMany", + type: SpanTypeAttribute.LLM, + extractInput: ([params], event) => + prepareAISDKEmbedInput(params, event.self), + extractOutput: (result, endEvent) => + processAISDKEmbeddingOutput( + result, + resolveDenyOutputPaths(endEvent, denyOutputPaths), + ), + extractMetrics: (result, _startTime, endEvent) => + extractTopLevelAISDKMetrics(result, endEvent), + }), + ); + // Agent.generate - async method this.unsubscribers.push( traceStreamingChannel(aiSDKChannels.agentGenerate, { name: "Agent.generate", type: SpanTypeAttribute.LLM, extractInput: ([params], event, span) => - prepareAISDKInput(params, event, span, denyOutputPaths), + prepareAISDKCallInput(params, event, span, denyOutputPaths), extractOutput: (result, endEvent) => { finalizeAISDKChildTracing(endEvent as { [key: string]: unknown }); return processAISDKOutput( @@ -256,7 +295,7 @@ export class AISDKPlugin extends BasePlugin { name: "Agent.stream", type: SpanTypeAttribute.LLM, extractInput: ([params], event, span) => - prepareAISDKInput(params, event, span, denyOutputPaths), + prepareAISDKCallInput(params, event, span, denyOutputPaths), extractOutput: (result, endEvent) => processAISDKOutput( result, @@ -282,7 +321,7 @@ export class AISDKPlugin extends BasePlugin { name: "Agent.stream", type: SpanTypeAttribute.LLM, extractInput: ([params], event, span) => - prepareAISDKInput(params, event, span, denyOutputPaths), + prepareAISDKCallInput(params, event, span, denyOutputPaths), patchResult: ({ endEvent, result, span, startTime }) => patchAISDKStreamingResult({ defaultDenyOutputPaths: denyOutputPaths, @@ -300,7 +339,7 @@ export class AISDKPlugin extends BasePlugin { name: "ToolLoopAgent.generate", type: SpanTypeAttribute.LLM, extractInput: ([params], event, span) => - prepareAISDKInput(params, event, span, denyOutputPaths), + prepareAISDKCallInput(params, event, span, denyOutputPaths), extractOutput: (result, endEvent) => { finalizeAISDKChildTracing(endEvent as { [key: string]: unknown }); return processAISDKOutput( @@ -320,7 +359,7 @@ export class AISDKPlugin extends BasePlugin { name: "ToolLoopAgent.stream", type: SpanTypeAttribute.LLM, extractInput: ([params], event, span) => - prepareAISDKInput(params, event, span, denyOutputPaths), + prepareAISDKCallInput(params, event, span, denyOutputPaths), extractOutput: (result, endEvent) => processAISDKOutput( result, @@ -376,7 +415,7 @@ function resolveDenyOutputPaths( return defaultDenyOutputPaths; } -interface ProcessInputSyncResult { +interface ProcessCallInputSyncResult { input: AISDKCallParams; outputPromise?: Promise<{ output: { @@ -510,7 +549,7 @@ const serializeOutputObject = ( const processInputAttachmentsSync = ( input: AISDKCallParams, -): ProcessInputSyncResult => { +): ProcessCallInputSyncResult => { if (!input) return { input }; const processed: AISDKCallParams = { ...input }; @@ -774,11 +813,13 @@ const convertDataToAttachment = ( /** * Process AI SDK input parameters, converting attachments as needed. */ -function processAISDKInput(params: AISDKCallParams): ProcessInputSyncResult { +function processAISDKCallInput( + params: AISDKCallParams, +): ProcessCallInputSyncResult { return processInputAttachmentsSync(params); } -function prepareAISDKInput( +function prepareAISDKCallInput( params: AISDKCallParams, event: { aiSDK?: AISDK; @@ -792,7 +833,7 @@ function prepareAISDKInput( input: unknown; metadata: Record; } { - const { input, outputPromise } = processAISDKInput(params); + const { input, outputPromise } = processAISDKCallInput(params); if (outputPromise && input && typeof input === "object") { outputPromise .then((resolvedData) => { @@ -808,7 +849,7 @@ function prepareAISDKInput( }); } - const metadata = extractMetadataFromParams(params, event.self); + const metadata = extractMetadataFromCallParams(params, event.self); const childTracing = prepareAISDKChildTracing( params, event.self, @@ -827,6 +868,19 @@ function prepareAISDKInput( }; } +function prepareAISDKEmbedInput( + params: AISDKEmbedParams, + self?: unknown, +): { + input: unknown; + metadata: Record; +} { + return { + input: { ...params }, + metadata: extractMetadataFromEmbedParams(params, self), + }; +} + function extractTopLevelAISDKMetrics( result: AISDKResult, event?: { [key: string]: unknown }, @@ -850,51 +904,65 @@ function hasModelChildTracing(event?: { [key: string]: unknown }): boolean { ); } -/** - * Extract metadata from AI SDK parameters. - * Includes model, provider, and integration info. - */ -function extractMetadataFromParams( - params: AISDKCallParams, - self?: unknown, -): Record { - const metadata: Record = { +function createAISDKIntegrationMetadata(): Record { + return { braintrust: { integration_name: "ai-sdk", sdk_language: "typescript", }, }; +} - // Extract model information - const agentModel = - self && +function resolveModelFromSelf(self?: unknown): AISDKModel | undefined { + return self && typeof self === "object" && "model" in self && (self as { model?: AISDKModel }).model - ? (self as { model?: AISDKModel }).model - : self && - typeof self === "object" && - "settings" in self && - (self as { settings?: { model?: AISDKModel } }).settings?.model - ? (self as { settings?: { model?: AISDKModel } }).settings?.model - : undefined; - const { model, provider } = serializeModelWithProvider( - params.model ?? agentModel, + ? (self as { model?: AISDKModel }).model + : self && + typeof self === "object" && + "settings" in self && + (self as { settings?: { model?: AISDKModel } }).settings?.model + ? (self as { settings?: { model?: AISDKModel } }).settings?.model + : undefined; +} + +function extractBaseMetadata( + model: AISDKModel | undefined, + self?: unknown, +): Record { + const metadata: Record = createAISDKIntegrationMetadata(); + const { model: modelId, provider } = serializeModelWithProvider( + model ?? resolveModelFromSelf(self), ); - if (model) { - metadata.model = model; + if (modelId) { + metadata.model = modelId; } if (provider) { metadata.provider = provider; } + return metadata; +} + +function extractMetadataFromCallParams( + params: AISDKCallParams, + self?: unknown, +): Record { + const metadata = extractBaseMetadata(params.model, self); const tools = serializeAISDKToolsForLogging(params.tools); if (tools) { metadata.tools = tools; } - return metadata; } +function extractMetadataFromEmbedParams( + params: AISDKEmbedParams, + self?: unknown, +): Record { + return extractBaseMetadata(params.model, self); +} + function prepareAISDKChildTracing( params: AISDKCallParams, self: unknown, @@ -958,7 +1026,7 @@ function prepareAISDKChildTracing( type: SpanTypeAttribute.LLM, }, event: { - input: processAISDKInput(options).input, + input: processAISDKCallInput(options).input, metadata: baseMetadata, }, }, @@ -975,7 +1043,7 @@ function prepareAISDKChildTracing( type: SpanTypeAttribute.LLM, }, event: { - input: processAISDKInput(options).input, + input: processAISDKCallInput(options).input, metadata: baseMetadata, }, }); @@ -1372,9 +1440,14 @@ function attachKnownResultPromiseHandlers( "content", "text", "object", + "value", + "values", "finishReason", + "embedding", + "embeddings", "usage", "totalUsage", + "responses", "steps", ]; @@ -1615,6 +1688,48 @@ function processAISDKOutput( return normalizeAISDKLoggedOutput(omit(merged, denyOutputPaths)); } +function processAISDKEmbeddingOutput( + output: AISDKEmbeddingResult, + denyOutputPaths: string[], +): Record | AISDKEmbeddingResult { + if (!output || typeof output !== "object") { + return output; + } + + const summarized: Record = {}; + const whitelistedFields = [ + "usage", + "totalUsage", + "warnings", + "providerMetadata", + "experimental_providerMetadata", + ] as const; + + for (const field of whitelistedFields) { + const value = safeSerializableFieldRead(output, field); + if (value !== undefined && isSerializableOutputValue(value)) { + summarized[field] = value; + } + } + + const embedding = safeSerializableFieldRead(output, "embedding"); + if (Array.isArray(embedding)) { + summarized.embedding_length = embedding.length; + } + + const embeddings = safeSerializableFieldRead(output, "embeddings"); + if (Array.isArray(embeddings)) { + summarized.embedding_count = embeddings.length; + + const firstEmbedding = embeddings.find((item) => Array.isArray(item)); + if (Array.isArray(firstEmbedding)) { + summarized.embedding_length = firstEmbedding.length; + } + } + + return normalizeAISDKLoggedOutput(omit(summarized, denyOutputPaths)); +} + /** * Extract token metrics from AI SDK result. */ @@ -1763,12 +1878,17 @@ function extractGetterValues( "content", "text", "object", + "value", + "values", + "embedding", + "embeddings", "finishReason", "usage", "totalUsage", "toolCalls", "toolResults", "warnings", + "responses", "experimental_providerMetadata", "providerMetadata", "rawResponse", diff --git a/js/src/vendor-sdk-types/ai-sdk-common.ts b/js/src/vendor-sdk-types/ai-sdk-common.ts index 15e4c1ed..4885dbac 100644 --- a/js/src/vendor-sdk-types/ai-sdk-common.ts +++ b/js/src/vendor-sdk-types/ai-sdk-common.ts @@ -149,6 +149,13 @@ export interface AISDKTool { export type AISDKTools = AISDKTool[] | Record; +export interface AISDKEmbedParams { + model?: AISDKModel; + value?: unknown; + values?: unknown[]; + [key: string]: unknown; +} + export interface AISDKCallParams { model?: AISDKModel; prompt?: AISDKMessage[] | Record; @@ -186,6 +193,18 @@ export interface AISDKResult { [key: string]: unknown; } +export interface AISDKEmbeddingResult extends AISDKResult { + embedding?: number[]; + embeddings?: number[][]; + value?: unknown; + values?: unknown[]; + responses?: unknown[]; +} + +export type AISDKEmbedFunction = ( + params: AISDKEmbedParams, +) => Promise; + export type AISDKGenerateFunction = ( params: AISDKCallParams, ) => Promise; @@ -216,6 +235,8 @@ export interface AISDKNamespaceBase { streamText: AISDKStreamFunction; generateObject: AISDKGenerateFunction; streamObject: AISDKStreamFunction; + embed: AISDKEmbedFunction; + embedMany: AISDKEmbedFunction; gateway?: AISDKProviderResolver; [key: string]: unknown; } diff --git a/js/src/vendor-sdk-types/ai-sdk.ts b/js/src/vendor-sdk-types/ai-sdk.ts index 7c5c5b11..1b9f4745 100644 --- a/js/src/vendor-sdk-types/ai-sdk.ts +++ b/js/src/vendor-sdk-types/ai-sdk.ts @@ -2,6 +2,9 @@ import type { AISDKAgentClass, AISDKAgentInstance, AISDKCallParams, + AISDKEmbedFunction, + AISDKEmbedParams, + AISDKEmbeddingResult, AISDKGeneratedFile, AISDKGenerateFunction, AISDKLanguageModel, @@ -41,6 +44,9 @@ export type { AISDKAgentInstance, AISDKAsyncOutputObject, AISDKCallParams, + AISDKEmbedFunction, + AISDKEmbedParams, + AISDKEmbeddingResult, AISDKGeneratedFile, AISDKGenerateFunction, AISDKLanguageModel, diff --git a/js/src/wrappers/ai-sdk/ai-sdk.ts b/js/src/wrappers/ai-sdk/ai-sdk.ts index 312f3bfb..97e12d6b 100644 --- a/js/src/wrappers/ai-sdk/ai-sdk.ts +++ b/js/src/wrappers/ai-sdk/ai-sdk.ts @@ -7,6 +7,8 @@ import type { AISDKAgentClass, AISDKAgentInstance, AISDKCallParams, + AISDKEmbedFunction, + AISDKEmbedParams, AISDKGenerateFunction, AISDKStreamFunction, } from "../../vendor-sdk-types/ai-sdk"; @@ -122,6 +124,10 @@ export function wrapAISDK(aiSDK: T, options: WrapAISDKOptions = {}): T { ); case "streamObject": return wrapStreamObject(typedAISDK.streamObject, options, typedAISDK); + case "embed": + return wrapEmbed(typedAISDK.embed, options, typedAISDK); + case "embedMany": + return wrapEmbedMany(typedAISDK.embedMany, options, typedAISDK); case "Agent": case "Experimental_Agent": case "ToolLoopAgent": @@ -269,6 +275,66 @@ const wrapGenerateObject = ( ); }; +const makeEmbedWrapper = ( + channel: typeof aiSDKChannels.embed | typeof aiSDKChannels.embedMany, + name: string, + embed: AISDKEmbedFunction, + contextOptions: { + aiSDK?: AISDK; + self?: unknown; + spanType?: SpanTypeAttribute; + } = {}, + options: WrapAISDKOptions = {}, +) => { + const wrapper = async function (allParams: AISDKEmbedParams & SpanInfo) { + const { span_info, ...params } = allParams; + const tracedParams = { ...params }; + + return channel.tracePromise( + () => embed(tracedParams), + createAISDKChannelContext(tracedParams, { + aiSDK: contextOptions.aiSDK, + denyOutputPaths: options.denyOutputPaths, + self: contextOptions.self, + span_info: mergeSpanInfo(span_info, { + name, + spanType: contextOptions.spanType, + }), + }), + ); + }; + Object.defineProperty(wrapper, "name", { value: name, writable: false }); + return wrapper; +}; + +const wrapEmbed = ( + embed: AISDKEmbedFunction, + options: WrapAISDKOptions = {}, + aiSDK?: AISDK, +) => { + return makeEmbedWrapper( + aiSDKChannels.embed, + "embed", + embed, + { aiSDK }, + options, + ); +}; + +const wrapEmbedMany = ( + embedMany: AISDKEmbedFunction, + options: WrapAISDKOptions = {}, + aiSDK?: AISDK, +) => { + return makeEmbedWrapper( + aiSDKChannels.embedMany, + "embedMany", + embedMany, + { aiSDK }, + options, + ); +}; + const makeStreamWrapper = ( channel: | typeof aiSDKChannels.streamText @@ -362,8 +428,8 @@ function mergeSpanInfo( }; } -function createAISDKChannelContext( - params: AISDKCallParams, +function createAISDKChannelContext>( + params: TParams, context: { aiSDK?: AISDK; denyOutputPaths?: string[]; @@ -372,7 +438,7 @@ function createAISDKChannelContext( } = {}, ) { return { - arguments: [params] as [AISDKCallParams], + arguments: [params] as [TParams], ...(context.aiSDK ? { aiSDK: context.aiSDK } : {}), ...(context.denyOutputPaths ? { denyOutputPaths: context.denyOutputPaths }