diff --git a/src/app/api/runtimes/[id]/talk/realtime/relay/route.test.ts b/src/app/api/runtimes/[id]/talk/realtime/relay/route.test.ts index e4384a7e..79b89335 100644 --- a/src/app/api/runtimes/[id]/talk/realtime/relay/route.test.ts +++ b/src/app/api/runtimes/[id]/talk/realtime/relay/route.test.ts @@ -83,6 +83,28 @@ describe("POST /api/runtimes/[id]/talk/realtime/relay", () => { }); }); + it("proxies realtime output cancellation through an accessible runtime", async () => { + const realtimeRelayCancelOutput = vi.fn().mockResolvedValue({ ok: true }); + mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_1" }); + mockGetGatewayClientForRuntime.mockResolvedValue({ realtimeRelayCancelOutput }); + + const response = await POST( + new Request("http://localhost/api/runtimes/rt_1/talk/realtime/relay", { + method: "POST", + body: JSON.stringify({ + action: "cancelOutput", + relaySessionId: "relay_1", + reason: "barge-in", + }), + }), + { params: Promise.resolve({ id: "rt_1" }) }, + ); + + expect(response.status).toBe(200); + await expect(response.json()).resolves.toEqual({ result: { ok: true } }); + expect(realtimeRelayCancelOutput).toHaveBeenCalledWith("relay_1", "barge-in"); + }); + it("rejects invalid relay actions before calling the gateway", async () => { mockRuntimeRows.push({ id: "rt_1", ownerUserId: "user_1" }); diff --git a/src/app/api/runtimes/[id]/talk/realtime/relay/route.ts b/src/app/api/runtimes/[id]/talk/realtime/relay/route.ts index f96ba8c3..8293166b 100644 --- a/src/app/api/runtimes/[id]/talk/realtime/relay/route.ts +++ b/src/app/api/runtimes/[id]/talk/realtime/relay/route.ts @@ -7,7 +7,7 @@ import { getGatewayClientForRuntime } from "@/lib/gateway-chat-pool"; export const dynamic = "force-dynamic"; -type RelayAction = "audio" | "mark" | "toolResult" | "stop"; +type RelayAction = "audio" | "cancelOutput" | "mark" | "toolResult" | "stop"; export async function POST( request: Request, @@ -53,6 +53,14 @@ export async function POST( return NextResponse.json({ result }); } + if (action === "cancelOutput") { + const result = await client.realtimeRelayCancelOutput( + relaySessionId, + readOptionalString(body.reason), + ); + return NextResponse.json({ result }); + } + if (action === "toolResult") { const callId = readRequiredString(body.callId, "callId"); const result = await client.realtimeRelayToolResult({ @@ -88,6 +96,6 @@ function readOptionalString(value: unknown) { } function readRelayAction(value: unknown): RelayAction { - if (value === "audio" || value === "mark" || value === "toolResult" || value === "stop") return value; - throw new ValidationError("action must be audio, mark, toolResult, or stop"); + if (value === "audio" || value === "cancelOutput" || value === "mark" || value === "toolResult" || value === "stop") return value; + throw new ValidationError("action must be audio, cancelOutput, mark, toolResult, or stop"); } diff --git a/src/lib/gateway-client-realtime.test.ts b/src/lib/gateway-client-realtime.test.ts index e39ea9d5..b59f977f 100644 --- a/src/lib/gateway-client-realtime.test.ts +++ b/src/lib/gateway-client-realtime.test.ts @@ -76,4 +76,16 @@ describe("GatewayClient realtime Talk compatibility", () => { expect(rpc).not.toHaveBeenCalled(); }); + + it("maps output cancellation onto the unified session API", async () => { + const client = new GatewayClient("ws://localhost:18789", null, device); + const rpc = vi.spyOn(client, "rpc").mockResolvedValue({ ok: true }); + + await expect(client.realtimeRelayCancelOutput("relay_1", "barge-in")).resolves.toEqual({ ok: true }); + + expect(rpc).toHaveBeenCalledWith("talk.session.cancelOutput", { + sessionId: "relay_1", + reason: "barge-in", + }); + }); }); diff --git a/src/lib/gateway-client.ts b/src/lib/gateway-client.ts index c3c896d7..e8dfea8d 100644 --- a/src/lib/gateway-client.ts +++ b/src/lib/gateway-client.ts @@ -768,6 +768,18 @@ export class GatewayClient { return { ok: true }; } + async realtimeRelayCancelOutput(relaySessionId: string, reason?: string): Promise<{ ok?: boolean }> { + try { + return await this.rpc<{ ok?: boolean }>("talk.session.cancelOutput", withoutUndefined({ + sessionId: relaySessionId, + reason, + })); + } catch (err) { + if (!isLikelyMissingGatewayMethod(err)) throw err; + return { ok: true }; + } + } + async realtimeRelayToolResult(params: GatewayRealtimeRelayToolResultParams): Promise<{ ok?: boolean }> { try { return await this.rpc<{ ok?: boolean }>("talk.session.submitToolResult", withoutUndefined({ diff --git a/src/lib/realtime-voice-client.test.ts b/src/lib/realtime-voice-client.test.ts index 00c54014..05c50e2d 100644 --- a/src/lib/realtime-voice-client.test.ts +++ b/src/lib/realtime-voice-client.test.ts @@ -1,5 +1,6 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; import { + cancelRealtimeRelayOutput, openRealtimeRelayEvents, sendRealtimeRelayAudio, startRealtimeVoiceSession, @@ -57,6 +58,25 @@ describe("realtime voice client helpers", () => { }); }); + it("sends output cancellation through the runtime relay route", async () => { + const fetchMock = vi.spyOn(globalThis, "fetch").mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ result: { ok: true } }), + } as Response); + + await cancelRealtimeRelayOutput("rt_1", "relay_1", "barge-in"); + + expect(fetchMock).toHaveBeenCalledWith("/api/runtimes/rt_1/talk/realtime/relay", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + action: "cancelOutput", + relaySessionId: "relay_1", + reason: "barge-in", + }), + }); + }); + it("opens relay events for the selected runtime and session", () => { const eventSourceMock = vi.fn(); vi.stubGlobal("EventSource", eventSourceMock); diff --git a/src/lib/realtime-voice-client.ts b/src/lib/realtime-voice-client.ts index c1c6c475..d69989a7 100644 --- a/src/lib/realtime-voice-client.ts +++ b/src/lib/realtime-voice-client.ts @@ -75,6 +75,18 @@ export async function sendRealtimeRelayMark(runtimeId: string, relaySessionId: s }); } +export async function cancelRealtimeRelayOutput( + runtimeId: string, + relaySessionId: string, + reason = "barge-in", +): Promise { + await postRealtimeRelay(runtimeId, { + action: "cancelOutput", + relaySessionId, + reason, + }); +} + export async function sendRealtimeRelayToolResult( runtimeId: string, relaySessionId: string, diff --git a/src/lib/realtime-voice-gateway-relay.ts b/src/lib/realtime-voice-gateway-relay.ts index 7fcc571a..56f798cf 100644 --- a/src/lib/realtime-voice-gateway-relay.ts +++ b/src/lib/realtime-voice-gateway-relay.ts @@ -1,4 +1,5 @@ import { + cancelRealtimeRelayOutput, openRealtimeRelayEvents, sendRealtimeRelayAudio, sendRealtimeRelayMark, @@ -10,6 +11,10 @@ import { base64ToBytes, bytesToBase64, floatToPcm16, pcm16ToFloat, rmsLevel } fr export type RealtimeVoiceStatus = "idle" | "listening" | "processing" | "speaking" | "error"; +const BARGE_IN_RMS_THRESHOLD = 0.02; +const BARGE_IN_PEAK_THRESHOLD = 0.08; +const BARGE_IN_FRAMES = 2; + export interface RealtimeGatewayRelayCallbacks { onStatus?: (status: RealtimeVoiceStatus, message?: string) => void; onTranscript?: (event: { role: "user" | "assistant"; text: string; final: boolean }) => void; @@ -50,6 +55,8 @@ export class RealtimeGatewayRelaySession { private readonly sources = new Set(); private playhead = 0; private closed = false; + private cancelRequestedForPlayback = false; + private speechFramesDuringPlayback = 0; constructor( private readonly runtimeId: string, @@ -128,6 +135,7 @@ export class RealtimeGatewayRelaySession { } this.sources.clear(); this.playhead = this.outputContext?.currentTime ?? 0; + this.speechFramesDuringPlayback = 0; this.callbacks.onSpeakingChange?.(false); } @@ -141,6 +149,7 @@ export class RealtimeGatewayRelaySession { const input = event.inputBuffer.getChannelData(0); this.callbacks.onVoiceLevel?.(rmsLevel(input)); const pcm = floatToPcm16(input); + if (this.detectBargeInSpeech(input)) this.cancelOutputForBargeIn(); void sendRealtimeRelayAudio(this.runtimeId, { relaySessionId: this.session.relaySessionId, audioBase64: bytesToBase64(pcm), @@ -162,7 +171,10 @@ export class RealtimeGatewayRelaySession { this.callbacks.onStatus?.("listening"); return; case "audio": - if (event.audioBase64) this.playPcm16(event.audioBase64); + if (event.audioBase64) { + this.cancelRequestedForPlayback = false; + this.playPcm16(event.audioBase64); + } return; case "clear": this.stopOutput(); @@ -235,4 +247,39 @@ export class RealtimeGatewayRelaySession { name: event.name ?? null, }).catch(() => {}); } + + private cancelOutputForBargeIn(): void { + if (!this.session.relaySessionId || this.cancelRequestedForPlayback) return; + this.cancelRequestedForPlayback = true; + this.stopOutput(); + this.callbacks.onStatus?.("listening", "Interrupted"); + void cancelRealtimeRelayOutput(this.runtimeId, this.session.relaySessionId, "barge-in").catch((error) => { + const message = error instanceof Error ? error.message : String(error); + this.callbacks.onError?.(message); + }); + } + + private detectBargeInSpeech(input: Float32Array): boolean { + if (this.sources.size === 0 || this.cancelRequestedForPlayback || input.length === 0) { + this.speechFramesDuringPlayback = 0; + return false; + } + + let peak = 0; + let sum = 0; + for (const sample of input) { + const abs = Math.abs(sample); + peak = Math.max(peak, abs); + sum += sample * sample; + } + + const rms = Math.sqrt(sum / input.length); + if (rms >= BARGE_IN_RMS_THRESHOLD && peak >= BARGE_IN_PEAK_THRESHOLD) { + this.speechFramesDuringPlayback += 1; + } else { + this.speechFramesDuringPlayback = 0; + } + + return this.speechFramesDuringPlayback >= BARGE_IN_FRAMES; + } }