Skip to content

Commit dcda0b8

Browse files
authored
feat: Add embedding instrumentation to AI SDK (#1754)
Resolves #1625
1 parent 7338811 commit dcda0b8

10 files changed

Lines changed: 568 additions & 46 deletions

File tree

e2e/scenarios/ai-sdk-instrumentation/assertions.ts

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,20 @@ function findStreamTrace(events: CapturedLogEvent[]) {
189189
return { child, operation, parent };
190190
}
191191

192+
/**
 * Collects the spans recorded for the `embed()` scenario step.
 *
 * @param events - Captured log events to search.
 * @returns `operation`: the latest span named "ai-sdk-embed-operation";
 *   `parent`: whatever `findParentSpan` yields for the name "embed" and that
 *   operation's span id (the id is `undefined` when no operation was found).
 */
function findEmbedTrace(events: CapturedLogEvent[]) {
  const embedOperation = findLatestSpan(events, "ai-sdk-embed-operation");
  const embedSpan = findParentSpan(events, "embed", embedOperation?.span.id);

  return { operation: embedOperation, parent: embedSpan };
}
198+
199+
/**
 * Collects the spans recorded for the `embedMany()` scenario step.
 *
 * @param events - Captured log events to search.
 * @returns `operation`: the latest span named "ai-sdk-embed-many-operation";
 *   `parent`: whatever `findParentSpan` yields for the name "embedMany" and
 *   that operation's span id (the id is `undefined` when no operation was
 *   found).
 */
function findEmbedManyTrace(events: CapturedLogEvent[]) {
  const embedManyOperation = findLatestSpan(
    events,
    "ai-sdk-embed-many-operation",
  );
  const embedManySpan = findParentSpan(
    events,
    "embedMany",
    embedManyOperation?.span.id,
  );

  return { operation: embedManyOperation, parent: embedManySpan };
}
205+
192206
function findToolTrace(events: CapturedLogEvent[]) {
193207
const operation = findLatestSpan(events, "ai-sdk-tool-operation");
194208
const parent = findParentSpan(events, "generateText", operation?.span.id);
@@ -641,6 +655,24 @@ function expectAISDKParentSpan(span: CapturedLogEvent | undefined) {
641655
).toBe("string");
642656
}
643657

658+
/**
 * Asserts that a span carries a positive, numeric embedding token metric.
 *
 * `metrics.tokens` is preferred; `metrics.prompt_tokens` is used only when
 * `tokens` is not a number. The assertion fails when neither is numeric.
 *
 * @param span - Span whose `metrics` record is inspected; may be undefined.
 */
function expectEmbeddingTokenMetrics(span: CapturedLogEvent | undefined) {
  const metrics = span?.metrics as Record<string, unknown> | undefined;

  // Candidates in priority order; the first numeric one wins.
  const tokenMetric = [metrics?.tokens, metrics?.prompt_tokens].find(
    (candidate): candidate is number => typeof candidate === "number",
  );

  expect(tokenMetric).toEqual(expect.any(Number));
  if (tokenMetric !== undefined) {
    expect(tokenMetric).toBeGreaterThan(0);
  }
}
675+
644676
export function defineAISDKInstrumentationAssertions(options: {
645677
agentSpanName?: AgentSpanName;
646678
name: string;
@@ -754,6 +786,50 @@ export function defineAISDKInstrumentationAssertions(options: {
754786
}
755787
});
756788

789+
test("captures trace for embed()", testConfig, () => {
790+
const root = findLatestSpan(events, ROOT_NAME);
791+
const trace = findEmbedTrace(events);
792+
793+
expectOperationParentedByRoot(trace.operation, root);
794+
expectAISDKParentSpan(trace.parent);
795+
expect(operationName(trace.operation)).toBe("embed");
796+
expectEmbeddingTokenMetrics(trace.parent);
797+
const input = isRecord(trace.parent?.input) ? trace.parent.input : null;
798+
expect(typeof input?.value).toBe("string");
799+
const output = extractOutputRecord(trace.parent);
800+
expect(output).toBeDefined();
801+
if (output) {
802+
expect(output.embedding).toBeUndefined();
803+
expect(output.embedding_length).toEqual(expect.any(Number));
804+
expect(output.embedding_length).toBeGreaterThan(0);
805+
}
806+
});
807+
808+
test("captures trace for embedMany()", testConfig, () => {
809+
const root = findLatestSpan(events, ROOT_NAME);
810+
const trace = findEmbedManyTrace(events);
811+
812+
expectOperationParentedByRoot(trace.operation, root);
813+
expectAISDKParentSpan(trace.parent);
814+
expect(operationName(trace.operation)).toBe("embed-many");
815+
expectEmbeddingTokenMetrics(trace.parent);
816+
const input = isRecord(trace.parent?.input) ? trace.parent.input : null;
817+
expect(Array.isArray(input?.values)).toBe(true);
818+
if (Array.isArray(input?.values)) {
819+
expect(input.values.length).toBeGreaterThanOrEqual(2);
820+
}
821+
const output = extractOutputRecord(trace.parent);
822+
expect(output).toBeDefined();
823+
if (output) {
824+
expect(output.embeddings).toBeUndefined();
825+
expect(output.responses).toBeUndefined();
826+
expect(output.embedding_count).toEqual(expect.any(Number));
827+
expect(output.embedding_count).toBeGreaterThanOrEqual(2);
828+
expect(output.embedding_length).toEqual(expect.any(Number));
829+
expect(output.embedding_length).toBeGreaterThan(0);
830+
}
831+
});
832+
757833
if (options.supportsOutputObjectScenario) {
758834
test(
759835
"captures Output.object schema on generateText()",

e2e/scenarios/ai-sdk-instrumentation/scenario.impl.mjs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ async function runAISDKInstrumentationScenario(
125125
) {
126126
const instrumentedAI = decorateAI ? decorateAI(options.ai) : options.ai;
127127
const openaiModel = options.openai("gpt-4o-mini");
128+
const openaiEmbeddingModel = options.openai.textEmbeddingModel(
129+
"text-embedding-3-small",
130+
);
128131
const sdkMajorVersion = parseMajorVersion(options.sdkVersion);
129132
const supportsRichInputScenarios = sdkMajorVersion >= 5;
130133
const outputObject = createOutputObjectIfSupported(options.ai);
@@ -169,6 +172,28 @@ async function runAISDKInstrumentationScenario(
169172
}
170173
});
171174

175+
await runOperation("ai-sdk-embed-operation", "embed", async () => {
176+
await instrumentedAI.embed({
177+
model: openaiEmbeddingModel,
178+
value: "Paris is the capital of France.",
179+
});
180+
});
181+
182+
await runOperation(
183+
"ai-sdk-embed-many-operation",
184+
"embed-many",
185+
async () => {
186+
await instrumentedAI.embedMany({
187+
model: openaiEmbeddingModel,
188+
values: [
189+
"Paris is in France.",
190+
"Berlin is in Germany.",
191+
"Vienna is in Austria.",
192+
],
193+
});
194+
},
195+
);
196+
172197
await runOperation("ai-sdk-tool-operation", "tool", async () => {
173198
const toolRequest = {
174199
model: openaiModel,
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import { describe, expect, it } from "vitest";
2+
import { aiSDKChannels } from "../../instrumentation/plugins/ai-sdk-channels";
3+
import { aiSDKConfigs } from "./ai-sdk";
4+
5+
/**
 * Returns every entry of `aiSDKConfigs` that has a `functionQuery` key whose
 * `functionName` strictly equals the requested name.
 *
 * @param functionName - Name to match against `functionQuery.functionName`.
 */
function findConfigsByFunctionName(functionName: string) {
  return aiSDKConfigs.filter(
    (candidate) =>
      "functionQuery" in candidate &&
      (candidate.functionQuery as { functionName?: unknown }).functionName ===
        functionName,
  );
}
14+
15+
describe("aiSDKConfigs", () => {
  // Files that must each contribute one config per instrumented function
  // (sorted, to match the sorted list built in the assertion below).
  const entrypointFiles = ["dist/index.js", "dist/index.mjs"];

  /**
   * Asserts that `functionName` is configured exactly twice, that both
   * configs publish on `channelName`, and that together they cover both
   * entry-point files.
   */
  const expectInstrumentedInBothEntrypoints = (
    functionName: string,
    channelName: string,
  ) => {
    const matches = findConfigsByFunctionName(functionName);

    expect(matches).toHaveLength(2);
    expect(matches.map((config) => config.channelName)).toEqual([
      channelName,
      channelName,
    ]);
    expect(matches.map((config) => config.module.filePath).sort()).toEqual(
      entrypointFiles,
    );
  };

  it("defines embed channels", () => {
    const { embed, embedMany } = aiSDKChannels;

    expect(embed.channelName).toBe("embed");
    expect(embedMany.channelName).toBe("embedMany");
  });

  it("instruments embed() in both ESM and CJS entrypoints", () => {
    expectInstrumentedInBothEntrypoints(
      "embed",
      aiSDKChannels.embed.channelName,
    );
  });

  it("instruments embedMany() in both ESM and CJS entrypoints", () => {
    expectInstrumentedInBothEntrypoints(
      "embedMany",
      aiSDKChannels.embedMany.channelName,
    );
  });
});

js/src/auto-instrumentations/configs/ai-sdk.ts

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,58 @@ export const aiSDKConfigs: InstrumentationConfig[] = [
105105
},
106106
},
107107

108+
// embed - async function
109+
{
110+
channelName: aiSDKChannels.embed.channelName,
111+
module: {
112+
name: "ai",
113+
versionRange: ">=3.0.0",
114+
filePath: "dist/index.mjs",
115+
},
116+
functionQuery: {
117+
functionName: "embed",
118+
kind: "Async",
119+
},
120+
},
121+
{
122+
channelName: aiSDKChannels.embed.channelName,
123+
module: {
124+
name: "ai",
125+
versionRange: ">=3.0.0",
126+
filePath: "dist/index.js",
127+
},
128+
functionQuery: {
129+
functionName: "embed",
130+
kind: "Async",
131+
},
132+
},
133+
134+
// embedMany - async function
135+
{
136+
channelName: aiSDKChannels.embedMany.channelName,
137+
module: {
138+
name: "ai",
139+
versionRange: ">=3.0.0",
140+
filePath: "dist/index.mjs",
141+
},
142+
functionQuery: {
143+
functionName: "embedMany",
144+
kind: "Async",
145+
},
146+
},
147+
{
148+
channelName: aiSDKChannels.embedMany.channelName,
149+
module: {
150+
name: "ai",
151+
versionRange: ">=3.0.0",
152+
filePath: "dist/index.js",
153+
},
154+
functionQuery: {
155+
functionName: "embedMany",
156+
kind: "Async",
157+
},
158+
},
159+
108160
// streamObject - async function (v3 only, before the sync refactor in v4)
109161
{
110162
channelName: aiSDKChannels.streamObject.channelName,

js/src/instrumentation/plugins/ai-sdk-channels.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ import type { ChannelSpanInfo } from "../core/types";
33
import type {
44
AISDK,
55
AISDKCallParams,
6+
AISDKEmbedParams,
7+
AISDKEmbeddingResult,
68
AISDKResult,
79
} from "../../vendor-sdk-types/ai-sdk";
810

@@ -69,6 +71,20 @@ export const aiSDKChannels = defineChannels("ai", {
6971
channelName: "streamObject.sync",
7072
kind: "sync-stream",
7173
}),
74+
embed: channel<[AISDKEmbedParams], AISDKEmbeddingResult, AISDKChannelContext>(
75+
{
76+
channelName: "embed",
77+
kind: "async",
78+
},
79+
),
80+
embedMany: channel<
81+
[AISDKEmbedParams],
82+
AISDKEmbeddingResult,
83+
AISDKChannelContext
84+
>({
85+
channelName: "embedMany",
86+
kind: "async",
87+
}),
7288
agentGenerate: channel<
7389
[AISDKCallParams],
7490
AISDKStreamResult,

js/src/instrumentation/plugins/ai-sdk-plugin.test.ts

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,53 @@ describe("AI SDK utility functions", () => {
883883
expect(result).toBeUndefined();
884884
});
885885
});
886+
887+
describe("processAISDKEmbeddingOutput", () => {
888+
it("should summarize single embedding length", () => {
889+
const output = {
890+
embedding: [0.1, 0.2, 0.3, 0.4],
891+
usage: {
892+
totalTokens: 10,
893+
},
894+
};
895+
896+
const result = processAISDKEmbeddingOutput(output, []);
897+
expect(result.embedding).toBeUndefined();
898+
expect(result.embedding_length).toBe(4);
899+
expect(result.usage).toMatchObject({
900+
totalTokens: 10,
901+
});
902+
});
903+
904+
it("should summarize embedding batches", () => {
905+
const output = {
906+
embeddings: [
907+
[0.1, 0.2, 0.3],
908+
[0.4, 0.5, 0.6],
909+
],
910+
};
911+
912+
const result = processAISDKEmbeddingOutput(output, []);
913+
expect(result.embeddings).toBeUndefined();
914+
expect(result.embedding_count).toBe(2);
915+
expect(result.embedding_length).toBe(3);
916+
});
917+
918+
it("should omit non-whitelisted fields like responses", () => {
919+
const output = {
920+
embeddings: [[0.1, 0.2, 0.3]],
921+
response: { body: "too much" },
922+
responses: [{ body: "way too much" }],
923+
usage: { totalTokens: 8 },
924+
};
925+
926+
const result = processAISDKEmbeddingOutput(output, []);
927+
expect(result.response).toBeUndefined();
928+
expect(result.responses).toBeUndefined();
929+
expect(result.usage).toMatchObject({ totalTokens: 8 });
930+
expect(result.embedding_count).toBe(1);
931+
});
932+
});
886933
});
887934

888935
// Helper functions exported for testing
@@ -1038,12 +1085,17 @@ function extractGetterValues(obj: any): any {
10381085
const getterNames = [
10391086
"text",
10401087
"object",
1088+
"value",
1089+
"values",
1090+
"embedding",
1091+
"embeddings",
10411092
"finishReason",
10421093
"usage",
10431094
"totalUsage",
10441095
"toolCalls",
10451096
"toolResults",
10461097
"warnings",
1098+
"responses",
10471099
"experimental_providerMetadata",
10481100
"providerMetadata",
10491101
"rawResponse",
@@ -1210,3 +1262,45 @@ function processAISDKOutput(output: any, denyOutputPaths: string[]): any {
12101262

12111263
return omit(merged, denyOutputPaths);
12121264
}
1265+
1266+
/**
 * Reduces an AI SDK embedding result to a compact summary.
 *
 * Only a fixed allow-list of fields is copied through (skipping values that
 * are `undefined` or functions). Raw vectors are never copied: a single
 * `embedding` array is summarised as `embedding_length`, and an `embeddings`
 * array as `embedding_count` plus the `embedding_length` of its first array
 * element. Everything else on `output` is dropped.
 *
 * NOTE(review): `denyOutputPaths` is accepted but not applied here, whereas
 * `processAISDKOutput` finishes with `omit(merged, denyOutputPaths)`. Confirm
 * whether deny paths should also be honoured for embedding output.
 *
 * @param output - Result returned by `embed()` / `embedMany()`.
 * @param denyOutputPaths - Currently unused (see note above).
 * @returns The summary object, or `output` unchanged when it is falsy or not
 *   an object.
 */
function processAISDKEmbeddingOutput(
  output: any,
  denyOutputPaths: string[],
): any {
  const isObjectLike = Boolean(output) && typeof output === "object";
  if (!isObjectLike) {
    return output;
  }

  const passthroughFields = [
    "usage",
    "totalUsage",
    "warnings",
    "providerMetadata",
    "experimental_providerMetadata",
  ];
  const summary: Record<string, unknown> = {};

  passthroughFields.forEach((fieldName) => {
    const fieldValue = output[fieldName];
    if (fieldValue === undefined || typeof fieldValue === "function") {
      return;
    }
    summary[fieldName] = fieldValue;
  });

  const { embedding, embeddings } = output;

  if (Array.isArray(embedding)) {
    summary.embedding_length = embedding.length;
  }

  if (Array.isArray(embeddings)) {
    summary.embedding_count = embeddings.length;

    // Use the first entry that is actually a vector to report dimensionality;
    // for a batch this overrides any length taken from `embedding` above.
    const firstVector = embeddings.find((entry: unknown) =>
      Array.isArray(entry),
    );
    if (firstVector !== undefined) {
      summary.embedding_length = firstVector.length;
    }
  }

  return summary;
}

0 commit comments

Comments
 (0)