Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ export class DroidSession {
private _closed = false;
private _cleanupAbortSignal: (() => void) | null = null;
private _cleanupCallbacks: Array<() => Promise<void> | void> = [];
private readonly _activeBridges = new Set<MessageBridge>();

/** @internal */
constructor(
Expand Down Expand Up @@ -188,12 +189,17 @@ export class DroidSession {
throwIfAborted(options?.abortSignal);

const startedAt = Date.now();
const bridge = new MessageBridge(undefined, {
let resolveDone: () => void = () => {};
const donePromise = new Promise<void>((resolve) => {
resolveDone = resolve;
});
const bridge = new MessageBridge(resolveDone, {
includePartialMessages: options?.includePartialMessages,
sessionId: this._sessionId,
startedAt,
outputFormat: options?.outputFormat,
});
this._activeBridges.add(bridge);
const unsubscribe = this._client.onNotification(bridge.notificationHandler);
let resolveAbort: () => void = () => {};
const abortPromise = new Promise<void>((resolve) => {
Expand All @@ -213,6 +219,7 @@ export class DroidSession {
files: options?.files,
outputFormat: options?.outputFormat,
}),
donePromise,
abortPromise,
]);
throwIfAborted(options?.abortSignal);
Expand All @@ -225,6 +232,7 @@ export class DroidSession {
} finally {
cleanupAbortSignal();
unsubscribe();
this._activeBridges.delete(bridge);
}
}

Expand All @@ -240,6 +248,9 @@ export class DroidSession {
this._closed = true;
this._cleanupAbortSignal?.();
this._cleanupAbortSignal = null;
for (const bridge of this._activeBridges) {
bridge.signalDone();
}

try {
await this._client.closeSession({ reason: 'other' }).catch(() => {});
Expand Down
183 changes: 183 additions & 0 deletions tests/session.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,27 @@ async function expectStreamToThrow(
}).rejects.toThrow(ConnectionError);
}

async function expectToResolveWithin<T>(
promise: Promise<T>,
timeoutMs: number
): Promise<T> {
let timeout: ReturnType<typeof setTimeout> | undefined;
try {
return await Promise.race([
promise,
new Promise<never>((_, reject) => {
timeout = setTimeout(() => {
reject(new Error(`Promise did not resolve within ${timeoutMs}ms`));
}, timeoutMs);
}),
]);
} finally {
if (timeout !== undefined) {
clearTimeout(timeout);
}
}
}

/**
* Set up transport to auto-respond to loadSession.
*/
Expand Down Expand Up @@ -691,6 +712,168 @@ describe('DroidSession', () => {
expect(closeMessage?.['params']).toEqual({ reason: 'other' });
expect(transport.isConnected).toBe(false);
});

it('unblocks an active stream with no completion notification when closed', async () => {
const transport = new InMemoryTransport();
await transport.connect();

let resolveAddSent: () => void = () => {};
const addSent = new Promise<void>((resolve) => {
resolveAddSent = resolve;
});

wireTransportSend(transport, ({ method, id }) => {
if (method === DroidServerMethod.INITIALIZE_SESSION) {
queueMicrotask(() => {
transport.injectMessage(
makeSuccessResponse(id, {
sessionId: 'sess-close-active-stream',
session: {},
settings: {
modelId: 'test-model',
reasoningEffort: 'medium',
},
})
);
});
} else if (method === DroidServerMethod.ADD_USER_MESSAGE) {
resolveAddSent();
queueMicrotask(() => {
transport.injectMessage(makeSuccessResponse(id, {}));
});
} else if (method === DroidServerMethod.CLOSE_SESSION) {
queueMicrotask(() => {
transport.injectMessage(makeSuccessResponse(id, {}));
});
}
});

const session = await createSession({ transport });
const messages: DroidMessage[] = [];
const streamPromise = (async () => {
for await (const msg of session.stream('wait forever')) {
messages.push(msg);
}
})();

await addSent;

const closePromise = session.close();

await expectToResolveWithin(streamPromise, 100);
await closePromise;
expect(messages).toEqual([]);
});

it('unblocks an active stream even when addUserMessage has no response', async () => {
const transport = new InMemoryTransport();
await transport.connect();

let resolveAddSent: () => void = () => {};
const addSent = new Promise<void>((resolve) => {
resolveAddSent = resolve;
});

wireTransportSend(transport, ({ method, id }) => {
if (method === DroidServerMethod.INITIALIZE_SESSION) {
queueMicrotask(() => {
transport.injectMessage(
makeSuccessResponse(id, {
sessionId: 'sess-close-pending-add',
session: {},
settings: {
modelId: 'test-model',
reasoningEffort: 'medium',
},
})
);
});
} else if (method === DroidServerMethod.ADD_USER_MESSAGE) {
resolveAddSent();
} else if (method === DroidServerMethod.CLOSE_SESSION) {
queueMicrotask(() => {
transport.injectMessage(makeSuccessResponse(id, {}));
});
}
});

const session = await createSession({ transport });
const streamPromise = (async () => {
const collected: DroidMessage[] = [];
for await (const msg of session.stream('pending add')) {
collected.push(msg);
}
return collected;
})();

await addSent;

const closePromise = session.close();
const messages = await expectToResolveWithin(streamPromise, 100);
await closePromise;

expect(messages).toEqual([]);
});

it('unblocks all active streams when closed', async () => {
const transport = new InMemoryTransport();
await transport.connect();

let addSentCount = 0;
let resolveBothAddSent: () => void = () => {};
const bothAddSent = new Promise<void>((resolve) => {
resolveBothAddSent = resolve;
});

wireTransportSend(transport, ({ method, id }) => {
if (method === DroidServerMethod.INITIALIZE_SESSION) {
queueMicrotask(() => {
transport.injectMessage(
makeSuccessResponse(id, {
sessionId: 'sess-close-all-active-streams',
session: {},
settings: {
modelId: 'test-model',
reasoningEffort: 'medium',
},
})
);
});
} else if (method === DroidServerMethod.ADD_USER_MESSAGE) {
addSentCount += 1;
if (addSentCount === 2) {
resolveBothAddSent();
}
} else if (method === DroidServerMethod.CLOSE_SESSION) {
queueMicrotask(() => {
transport.injectMessage(makeSuccessResponse(id, {}));
});
}
});

const session = await createSession({ transport });
const collect = async (prompt: string): Promise<DroidMessage[]> => {
const messages: DroidMessage[] = [];
for await (const msg of session.stream(prompt)) {
messages.push(msg);
}
return messages;
};
const firstStream = collect('first pending add');
const secondStream = collect('second pending add');

await bothAddSent;

const closePromise = session.close();
const [firstMessages, secondMessages] = await expectToResolveWithin(
Promise.all([firstStream, secondStream]),
100
);
await closePromise;

expect(firstMessages).toEqual([]);
expect(secondMessages).toEqual([]);
});
});

describe('MCP methods (VAL-API-011)', () => {
Expand Down
Loading