From 070bb09aa3322c49300d7438a357de99e6d2d51b Mon Sep 17 00:00:00 2001 From: User Date: Wed, 20 May 2026 15:40:22 -0700 Subject: [PATCH] fix: unblock active streams on session close Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com> --- src/session.ts | 13 ++- tests/session.test.ts | 183 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 195 insertions(+), 1 deletion(-) diff --git a/src/session.ts b/src/session.ts index 90407af..8295522 100644 --- a/src/session.ts +++ b/src/session.ts @@ -140,6 +140,7 @@ export class DroidSession { private _closed = false; private _cleanupAbortSignal: (() => void) | null = null; private _cleanupCallbacks: Array<() => Promise | void> = []; + private readonly _activeBridges = new Set(); /** @internal */ constructor( @@ -188,12 +189,17 @@ export class DroidSession { throwIfAborted(options?.abortSignal); const startedAt = Date.now(); - const bridge = new MessageBridge(undefined, { + let resolveDone: () => void = () => {}; + const donePromise = new Promise((resolve) => { + resolveDone = resolve; + }); + const bridge = new MessageBridge(resolveDone, { includePartialMessages: options?.includePartialMessages, sessionId: this._sessionId, startedAt, outputFormat: options?.outputFormat, }); + this._activeBridges.add(bridge); const unsubscribe = this._client.onNotification(bridge.notificationHandler); let resolveAbort: () => void = () => {}; const abortPromise = new Promise((resolve) => { @@ -213,6 +219,7 @@ export class DroidSession { files: options?.files, outputFormat: options?.outputFormat, }), + donePromise, abortPromise, ]); throwIfAborted(options?.abortSignal); @@ -225,6 +232,7 @@ export class DroidSession { } finally { cleanupAbortSignal(); unsubscribe(); + this._activeBridges.delete(bridge); } } @@ -240,6 +248,9 @@ export class DroidSession { this._closed = true; this._cleanupAbortSignal?.(); this._cleanupAbortSignal = null; + for (const bridge of this._activeBridges) { + bridge.signalDone(); + } try { await this._client.closeSession({ reason: 'other' }).catch(() => {}); diff --git a/tests/session.test.ts b/tests/session.test.ts index 6dd559d..15fafc0 100644 --- a/tests/session.test.ts +++ b/tests/session.test.ts @@ -73,6 +73,27 @@ async function expectStreamToThrow( }).rejects.toThrow(ConnectionError); } +async function expectToResolveWithin( + promise: Promise, + timeoutMs: number +): Promise { + let timeout: ReturnType | undefined; + try { + return await Promise.race([ + promise, + new Promise((_, reject) => { + timeout = setTimeout(() => { + reject(new Error(`Promise did not resolve within ${timeoutMs}ms`)); + }, timeoutMs); + }), + ]); + } finally { + if (timeout !== undefined) { + clearTimeout(timeout); + } + } +} + /** * Set up transport to auto-respond to loadSession. */ @@ -691,6 +712,168 @@ describe('DroidSession', () => { expect(closeMessage?.['params']).toEqual({ reason: 'other' }); expect(transport.isConnected).toBe(false); }); + + it('unblocks an active stream with no completion notification when closed', async () => { + const transport = new InMemoryTransport(); + await transport.connect(); + + let resolveAddSent: () => void = () => {}; + const addSent = new Promise((resolve) => { + resolveAddSent = resolve; + }); + + wireTransportSend(transport, ({ method, id }) => { + if (method === DroidServerMethod.INITIALIZE_SESSION) { + queueMicrotask(() => { + transport.injectMessage( + makeSuccessResponse(id, { + sessionId: 'sess-close-active-stream', + session: {}, + settings: { + modelId: 'test-model', + reasoningEffort: 'medium', + }, + }) + ); + }); + } else if (method === DroidServerMethod.ADD_USER_MESSAGE) { + resolveAddSent(); + queueMicrotask(() => { + transport.injectMessage(makeSuccessResponse(id, {})); + }); + } else if (method === DroidServerMethod.CLOSE_SESSION) { + queueMicrotask(() => { + transport.injectMessage(makeSuccessResponse(id, {})); + }); + } + }); + + const session = await createSession({ transport }); + const messages: DroidMessage[] = []; + const streamPromise = (async () => { + for await (const msg of session.stream('wait forever')) { + messages.push(msg); + } + })(); + + await addSent; + + const closePromise = session.close(); + + await expectToResolveWithin(streamPromise, 100); + await closePromise; + expect(messages).toEqual([]); + }); + + it('unblocks an active stream even when addUserMessage has no response', async () => { + const transport = new InMemoryTransport(); + await transport.connect(); + + let resolveAddSent: () => void = () => {}; + const addSent = new Promise((resolve) => { + resolveAddSent = resolve; + }); + + wireTransportSend(transport, ({ method, id }) => { + if (method === DroidServerMethod.INITIALIZE_SESSION) { + queueMicrotask(() => { + transport.injectMessage( + makeSuccessResponse(id, { + sessionId: 'sess-close-pending-add', + session: {}, + settings: { + modelId: 'test-model', + reasoningEffort: 'medium', + }, + }) + ); + }); + } else if (method === DroidServerMethod.ADD_USER_MESSAGE) { + resolveAddSent(); + } else if (method === DroidServerMethod.CLOSE_SESSION) { + queueMicrotask(() => { + transport.injectMessage(makeSuccessResponse(id, {})); + }); + } + }); + + const session = await createSession({ transport }); + const streamPromise = (async () => { + const collected: DroidMessage[] = []; + for await (const msg of session.stream('pending add')) { + collected.push(msg); + } + return collected; + })(); + + await addSent; + + const closePromise = session.close(); + const messages = await expectToResolveWithin(streamPromise, 100); + await closePromise; + + expect(messages).toEqual([]); + }); + + it('unblocks all active streams when closed', async () => { + const transport = new InMemoryTransport(); + await transport.connect(); + + let addSentCount = 0; + let resolveBothAddSent: () => void = () => {}; + const bothAddSent = new Promise((resolve) => { + resolveBothAddSent = resolve; + }); + + wireTransportSend(transport, ({ method, id }) => { + if (method === DroidServerMethod.INITIALIZE_SESSION) { + queueMicrotask(() => { + transport.injectMessage( + makeSuccessResponse(id, { + sessionId: 'sess-close-all-active-streams', + session: {}, + settings: { + modelId: 'test-model', + reasoningEffort: 'medium', + }, + }) + ); + }); + } else if (method === DroidServerMethod.ADD_USER_MESSAGE) { + addSentCount += 1; + if (addSentCount === 2) { + resolveBothAddSent(); + } + } else if (method === DroidServerMethod.CLOSE_SESSION) { + queueMicrotask(() => { + transport.injectMessage(makeSuccessResponse(id, {})); + }); + } + }); + + const session = await createSession({ transport }); + const collect = async (prompt: string): Promise => { + const messages: DroidMessage[] = []; + for await (const msg of session.stream(prompt)) { + messages.push(msg); + } + return messages; + }; + const firstStream = collect('first pending add'); + const secondStream = collect('second pending add'); + + await bothAddSent; + + const closePromise = session.close(); + const [firstMessages, secondMessages] = await expectToResolveWithin( + Promise.all([firstStream, secondStream]), + 100 + ); + await closePromise; + + expect(firstMessages).toEqual([]); + expect(secondMessages).toEqual([]); + }); }); describe('MCP methods (VAL-API-011)', () => {