Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/app/api/runtimes/[id]/talk/realtime/relay/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,28 @@ describe("POST /api/runtimes/[id]/talk/realtime/relay", () => {
});
});

// Happy path for the "cancelOutput" relay action: the route is expected to hand
// the relay session id and reason to the gateway client and return its reply
// wrapped as { result }.
it("proxies realtime output cancellation through an accessible runtime", async () => {
// Gateway client stub exposing only the method this action should invoke.
const realtimeRelayCancelOutput = vi.fn().mockResolvedValue({ ok: true });
// Register runtime rt_1 so the route can resolve it.
// NOTE(review): presumably "user_1" matches the suite's mocked session user — confirm.
mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_1" });
mockGetGatewayClientForRuntime.mockResolvedValue({ realtimeRelayCancelOutput });

// Call the route handler directly with a JSON body and the dynamic route params.
const response = await POST(
new Request("http://localhost/api/runtimes/rt_1/talk/realtime/relay", {
method: "POST",
body: JSON.stringify({
action: "cancelOutput",
relaySessionId: "relay_1",
reason: "barge-in",
}),
}),
{ params: Promise.resolve({ id: "rt_1" }) },
);

expect(response.status).toBe(200);
await expect(response.json()).resolves.toEqual({ result: { ok: true } });
// The session id and reason are passed positionally, not as an object.
expect(realtimeRelayCancelOutput).toHaveBeenCalledWith("relay_1", "barge-in");
});

it("rejects invalid relay actions before calling the gateway", async () => {
mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_1" });

Expand Down
14 changes: 11 additions & 3 deletions src/app/api/runtimes/[id]/talk/realtime/relay/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { getGatewayClientForRuntime } from "@/lib/gateway-chat-pool";

export const dynamic = "force-dynamic";

/**
 * Actions accepted in the `action` field of a relay request body.
 * `readRelayAction` validates incoming values against this union.
 */
type RelayAction = "audio" | "cancelOutput" | "mark" | "toolResult" | "stop";

export async function POST(
request: Request,
Expand Down Expand Up @@ -53,6 +53,14 @@ export async function POST(
return NextResponse.json({ result });
}

if (action === "cancelOutput") {
const result = await client.realtimeRelayCancelOutput(
relaySessionId,
readOptionalString(body.reason),
);
return NextResponse.json({ result });
}

if (action === "toolResult") {
const callId = readRequiredString(body.callId, "callId");
const result = await client.realtimeRelayToolResult({
Expand Down Expand Up @@ -88,6 +96,6 @@ function readOptionalString(value: unknown) {
}

/**
 * Validates the `action` value of a relay request.
 *
 * @param value - Untrusted value taken from the request body.
 * @returns The value narrowed to a supported relay action.
 * @throws ValidationError when the value is not one of the supported actions.
 */
function readRelayAction(value: unknown): RelayAction {
  // Single check covering every supported action, including "cancelOutput";
  // the stale pre-change check that rejected it has been removed.
  if (
    value === "audio" ||
    value === "cancelOutput" ||
    value === "mark" ||
    value === "toolResult" ||
    value === "stop"
  ) {
    return value;
  }
  throw new ValidationError("action must be audio, cancelOutput, mark, toolResult, or stop");
}
12 changes: 12 additions & 0 deletions src/lib/gateway-client-realtime.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,16 @@ describe("GatewayClient realtime Talk compatibility", () => {

expect(rpc).not.toHaveBeenCalled();
});

// Verifies that the relay-named helper issues the "talk.session.cancelOutput"
// RPC, sending the relay session id under the `sessionId` key.
it("maps output cancellation onto the unified session API", async () => {
const client = new GatewayClient("ws://localhost:18789", null, device);
// Stub the RPC layer so no gateway call is made and the payload can be inspected.
const rpc = vi.spyOn(client, "rpc").mockResolvedValue({ ok: true });

await expect(client.realtimeRelayCancelOutput("relay_1", "barge-in")).resolves.toEqual({ ok: true });

expect(rpc).toHaveBeenCalledWith("talk.session.cancelOutput", {
sessionId: "relay_1",
reason: "barge-in",
});
});
});
12 changes: 12 additions & 0 deletions src/lib/gateway-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,18 @@ export class GatewayClient {
return { ok: true };
}

/**
 * Asks the gateway to cancel the current output of a talk session.
 *
 * Issues the `talk.session.cancelOutput` RPC with the relay session id sent as
 * `sessionId`, plus the optional reason.
 *
 * @param relaySessionId - Session whose output should be cancelled.
 * @param reason - Optional reason forwarded to the gateway.
 * @returns The gateway reply, or `{ ok: true }` when the failure is classified
 *   by `isLikelyMissingGatewayMethod` (the request is then treated as a no-op).
 * @throws Any other RPC error, unchanged.
 */
async realtimeRelayCancelOutput(relaySessionId: string, reason?: string): Promise<{ ok?: boolean }> {
  try {
    // NOTE(review): withoutUndefined presumably drops the `reason` key when it is undefined — confirm.
    const payload = withoutUndefined({ sessionId: relaySessionId, reason });
    return await this.rpc<{ ok?: boolean }>("talk.session.cancelOutput", payload);
  } catch (err) {
    if (isLikelyMissingGatewayMethod(err)) {
      return { ok: true };
    }
    throw err;
  }
}

async realtimeRelayToolResult(params: GatewayRealtimeRelayToolResultParams): Promise<{ ok?: boolean }> {
try {
return await this.rpc<{ ok?: boolean }>("talk.session.submitToolResult", withoutUndefined({
Expand Down
20 changes: 20 additions & 0 deletions src/lib/realtime-voice-client.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
import {
cancelRealtimeRelayOutput,
openRealtimeRelayEvents,
sendRealtimeRelayAudio,
startRealtimeVoiceSession,
Expand Down Expand Up @@ -57,6 +58,25 @@ describe("realtime voice client helpers", () => {
});
});

// Verifies the browser-side helper posts a "cancelOutput" action to the
// runtime's relay route with a JSON body.
it("sends output cancellation through the runtime relay route", async () => {
// Minimal fetch stub: only the `ok` and `json` members are provided.
const fetchMock = vi.spyOn(globalThis, "fetch").mockResolvedValue({
ok: true,
json: () => Promise.resolve({ result: { ok: true } }),
} as Response);

await cancelRealtimeRelayOutput("rt_1", "relay_1", "barge-in");

// The body is compared as a serialized string, so key order matters here.
expect(fetchMock).toHaveBeenCalledWith("/api/runtimes/rt_1/talk/realtime/relay", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
action: "cancelOutput",
relaySessionId: "relay_1",
reason: "barge-in",
}),
});
});

it("opens relay events for the selected runtime and session", () => {
const eventSourceMock = vi.fn();
vi.stubGlobal("EventSource", eventSourceMock);
Expand Down
12 changes: 12 additions & 0 deletions src/lib/realtime-voice-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ export async function sendRealtimeRelayMark(runtimeId: string, relaySessionId: s
});
}

/**
 * Posts a `cancelOutput` action for the given relay session to the runtime's
 * relay endpoint.
 *
 * @param runtimeId - Runtime whose relay route receives the request.
 * @param relaySessionId - Relay session whose output should be cancelled.
 * @param reason - Reason sent with the request; defaults to `"barge-in"`.
 */
export async function cancelRealtimeRelayOutput(
  runtimeId: string,
  relaySessionId: string,
  reason = "barge-in",
): Promise<void> {
  // Keys stay in this order: action, relaySessionId, reason.
  await postRealtimeRelay(runtimeId, { action: "cancelOutput", relaySessionId, reason });
}

export async function sendRealtimeRelayToolResult(
runtimeId: string,
relaySessionId: string,
Expand Down
49 changes: 48 additions & 1 deletion src/lib/realtime-voice-gateway-relay.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import {
cancelRealtimeRelayOutput,
openRealtimeRelayEvents,
sendRealtimeRelayAudio,
sendRealtimeRelayMark,
Expand All @@ -10,6 +11,10 @@ import { base64ToBytes, bytesToBase64, floatToPcm16, pcm16ToFloat, rmsLevel } fr

export type RealtimeVoiceStatus = "idle" | "listening" | "processing" | "speaking" | "error";

// Barge-in detection tuning, read by detectBargeInSpeech.
// Minimum RMS of an input frame for it to count towards barge-in.
const BARGE_IN_RMS_THRESHOLD = 0.02;
// Minimum absolute sample peak the same frame must also reach.
const BARGE_IN_PEAK_THRESHOLD = 0.08;
// Number of consecutive qualifying frames required before barge-in is reported.
const BARGE_IN_FRAMES = 2;

export interface RealtimeGatewayRelayCallbacks {
onStatus?: (status: RealtimeVoiceStatus, message?: string) => void;
onTranscript?: (event: { role: "user" | "assistant"; text: string; final: boolean }) => void;
Expand Down Expand Up @@ -50,6 +55,8 @@ export class RealtimeGatewayRelaySession {
private readonly sources = new Set<AudioBufferSourceNode>();
private playhead = 0;
private closed = false;
private cancelRequestedForPlayback = false;
private speechFramesDuringPlayback = 0;

constructor(
private readonly runtimeId: string,
Expand Down Expand Up @@ -128,6 +135,7 @@ export class RealtimeGatewayRelaySession {
}
this.sources.clear();
this.playhead = this.outputContext?.currentTime ?? 0;
this.speechFramesDuringPlayback = 0;
this.callbacks.onSpeakingChange?.(false);
}

Expand All @@ -141,6 +149,7 @@ export class RealtimeGatewayRelaySession {
const input = event.inputBuffer.getChannelData(0);
this.callbacks.onVoiceLevel?.(rmsLevel(input));
const pcm = floatToPcm16(input);
if (this.detectBargeInSpeech(input)) this.cancelOutputForBargeIn();
void sendRealtimeRelayAudio(this.runtimeId, {
relaySessionId: this.session.relaySessionId,
audioBase64: bytesToBase64(pcm),
Expand All @@ -162,7 +171,10 @@ export class RealtimeGatewayRelaySession {
this.callbacks.onStatus?.("listening");
return;
case "audio":
if (event.audioBase64) this.playPcm16(event.audioBase64);
if (event.audioBase64) {
this.cancelRequestedForPlayback = false;
this.playPcm16(event.audioBase64);
}
return;
case "clear":
this.stopOutput();
Expand Down Expand Up @@ -235,4 +247,39 @@ export class RealtimeGatewayRelaySession {
name: event.name ?? null,
}).catch(() => {});
}

/**
 * Interrupts playback because speech was detected on the input.
 *
 * Stops local output, reports the "listening" status with the message
 * "Interrupted", and asks the relay to cancel its output. It does nothing when
 * there is no relay session id or a cancellation has already been requested.
 * A failed cancel request is surfaced through the `onError` callback.
 */
private cancelOutputForBargeIn(): void {
  if (!this.session.relaySessionId) return;
  if (this.cancelRequestedForPlayback) return;

  // Latch first so repeated detections do not send further cancel requests.
  this.cancelRequestedForPlayback = true;
  this.stopOutput();
  this.callbacks.onStatus?.("listening", "Interrupted");

  const reportFailure = (error: unknown): void => {
    const message = error instanceof Error ? error.message : String(error);
    this.callbacks.onError?.(message);
  };
  // Fire-and-forget: the request is not awaited.
  void cancelRealtimeRelayOutput(this.runtimeId, this.session.relaySessionId, "barge-in").catch(reportFailure);
}

/**
 * Decides whether the input frame completes a barge-in.
 *
 * A frame qualifies when its RMS reaches BARGE_IN_RMS_THRESHOLD and its
 * absolute peak reaches BARGE_IN_PEAK_THRESHOLD. Qualifying frames are counted
 * consecutively; any other frame resets the count.
 *
 * @param input - Samples of the current input frame.
 * @returns `true` once BARGE_IN_FRAMES consecutive frames have qualified.
 *   Always `false` (and the count is reset) when no output sources are
 *   tracked, a cancellation is already pending, or the frame is empty.
 */
private detectBargeInSpeech(input: Float32Array): boolean {
  const nothingToInterrupt = this.sources.size === 0 || this.cancelRequestedForPlayback;
  if (nothingToInterrupt || input.length === 0) {
    this.speechFramesDuringPlayback = 0;
    return false;
  }

  // One pass over the frame, accumulating peak magnitude and signal energy.
  let peakAmplitude = 0;
  let energy = 0;
  input.forEach((sample) => {
    peakAmplitude = Math.max(peakAmplitude, Math.abs(sample));
    energy += sample * sample;
  });

  const rootMeanSquare = Math.sqrt(energy / input.length);
  const frameQualifies =
    rootMeanSquare >= BARGE_IN_RMS_THRESHOLD && peakAmplitude >= BARGE_IN_PEAK_THRESHOLD;

  this.speechFramesDuringPlayback = frameQualifies ? this.speechFramesDuringPlayback + 1 : 0;
  return this.speechFramesDuringPlayback >= BARGE_IN_FRAMES;
}
}
Loading