diff --git a/.changeset/extract-task-manager.md b/.changeset/extract-task-manager.md new file mode 100644 index 000000000..c70a7a6e1 --- /dev/null +++ b/.changeset/extract-task-manager.md @@ -0,0 +1,14 @@ +--- +"@modelcontextprotocol/core": minor +"@modelcontextprotocol/client": minor +"@modelcontextprotocol/server": minor +--- + +refactor: extract task orchestration from Protocol into TaskManager + +**Breaking changes:** +- `extra.taskId` → `extra.task?.taskId` +- `extra.taskStore` → `extra.task?.taskStore` +- `extra.taskRequestedTtl` → `extra.task?.requestedTtl` +- `ProtocolOptions` no longer accepts `taskStore`/`taskMessageQueue` — pass via `TaskManagerOptions` in `ClientOptions`/`ServerOptions` +- Abstract methods `assertTaskCapability`/`assertTaskHandlerCapability` removed from Protocol diff --git a/examples/client/src/simpleStreamableHttp.ts b/examples/client/src/simpleStreamableHttp.ts index ac8491584..f22d16ba4 100644 --- a/examples/client/src/simpleStreamableHttp.ts +++ b/examples/client/src/simpleStreamableHttp.ts @@ -265,14 +265,14 @@ async function connect(url?: string): Promise { form: {} }, tasks: { + taskStore: clientTaskStore, requests: { elicitation: { create: {} } } } - }, - taskStore: clientTaskStore + } } ); client.onerror = error => { diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index be025c04c..1263f4bb5 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -41,9 +41,14 @@ const getServer = () => { websiteUrl: 'https://github.com/modelcontextprotocol/typescript-sdk' }, { - capabilities: { logging: {}, tasks: { requests: { tools: { call: {} } } } }, - taskStore, // Enable task support - taskMessageQueue: new InMemoryTaskMessageQueue() + capabilities: { + logging: {}, + tasks: { + requests: { tools: { call: {} } }, + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() + } + } } ); diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index edb08ee58..716e9244f 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -30,6 +30,7 @@ import type { ResultTypeMap, ServerCapabilities, SubscribeRequest, + TaskManagerOptions, Tool, Transport, UnsubscribeRequest @@ -61,7 +62,8 @@ import { ProtocolErrorCode, ReadResourceResultSchema, SdkError, - SdkErrorCode + SdkErrorCode, + TaskManager } from '@modelcontextprotocol/core'; import { ExperimentalClientTasks } from '../experimental/tasks/client.js'; @@ -140,11 +142,20 @@ export function getSupportedElicitationModes(capabilities: ClientCapabilities['e return { supportsFormMode, supportsUrlMode }; } +/** + * Extended tasks capability that includes runtime configuration (store, messageQueue). + * The runtime-only fields are stripped before advertising capabilities to servers. + */ +export type ClientTasksCapabilityWithRuntime = NonNullable & + Pick; + export type ClientOptions = ProtocolOptions & { /** * Capabilities to advertise as being supported by this client. */ - capabilities?: ClientCapabilities; + capabilities?: Omit & { + tasks?: ClientTasksCapabilityWithRuntime; + }; /** * JSON Schema validator for tool output validation. @@ -204,6 +215,7 @@ export class Client extends Protocol { private _listChangedDebounceTimers: Map> = new Map(); private _pendingListChangedConfig?: ListChangedHandlers; private _enforceStrictCapabilities: boolean; + private _taskModule?: TaskManager; /** * Initializes this client with the given name and version information. @@ -213,16 +225,39 @@ export class Client extends Protocol { options?: ClientOptions ) { super(options); - this._capabilities = options?.capabilities ?? {}; + this._capabilities = options?.capabilities ? { ...options.capabilities } : {}; this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new DefaultJsonSchemaValidator(); this._enforceStrictCapabilities = options?.enforceStrictCapabilities ?? false; + // If tasks capability is declared, create and register the task module + if (options?.capabilities?.tasks) { + const { taskStore, taskMessageQueue, ...wireCapabilities } = options.capabilities.tasks; + // Strip runtime-only config from advertised capabilities + this._capabilities.tasks = wireCapabilities; + this._taskModule = new TaskManager({ + taskStore, + taskMessageQueue, + enforceStrictCapabilities: options?.enforceStrictCapabilities, + assertTaskCapability: method => assertToolsCallTaskCapability(this._serverCapabilities?.tasks?.requests, method, 'Server'), + assertTaskHandlerCapability: method => + assertClientRequestTaskCapability(this._capabilities.tasks?.requests, method, 'Client') + }); + this.registerModule(this._taskModule); + } + // Store list changed config for setup after connection (when we know server capabilities) if (options?.listChanged) { this._pendingListChangedConfig = options.listChanged; } } + /** + * Access the task module, if tasks capability is configured. + */ + get taskModule(): TaskManager | undefined { + return this._taskModule; + } + protected override buildContext(ctx: BaseContext, _transportInfo?: MessageExtraInfo): ClientContext { return ctx; } @@ -635,12 +670,6 @@ export class Client extends Protocol { } protected assertRequestHandlerCapability(method: string): void { - // Task handlers are registered in Protocol constructor before _capabilities is initialized - // Skip capability check for task methods during initialization - if (!this._capabilities) { - return; - } - switch (method) { case 'sampling/createMessage': { if (!this._capabilities.sampling) { @@ -672,19 +701,6 @@ export class Client extends Protocol { break; } - case 'tasks/get': - case 'tasks/list': - case 'tasks/result': - case 'tasks/cancel': { - if (!this._capabilities.tasks) { - throw new SdkError( - SdkErrorCode.CapabilityNotSupported, - `Client does not support tasks capability (required for ${method})` - ); - } - break; - } - case 'ping': { // No specific capability required for ping break; @@ -692,21 +708,6 @@ export class Client extends Protocol { } } - protected assertTaskCapability(method: string): void { - assertToolsCallTaskCapability(this._serverCapabilities?.tasks?.requests, method, 'Server'); - } - - protected assertTaskHandlerCapability(method: string): void { - // Task handlers are registered in Protocol constructor before _capabilities is initialized - // Skip capability check for task methods during initialization - if (!this._capabilities) { - return; - } - - assertClientRequestTaskCapability(this._capabilities.tasks?.requests, method, 'Client'); - } - - /** Sends a ping to the server to check connectivity. */ async ping(options?: RequestOptions) { return this._requestWithSchema({ method: 'ping' }, EmptyResultSchema, options); } diff --git a/packages/client/src/experimental/tasks/client.ts b/packages/client/src/experimental/tasks/client.ts index 696862c9c..736d89ff4 100644 --- a/packages/client/src/experimental/tasks/client.ts +++ b/packages/client/src/experimental/tasks/client.ts @@ -13,13 +13,14 @@ import type { CreateTaskResult, GetTaskResult, ListTasksResult, + Request, RequestMethod, RequestOptions, ResponseMessage, ResultTypeMap, SchemaOutput } from '@modelcontextprotocol/core'; -import { ProtocolError, ProtocolErrorCode } from '@modelcontextprotocol/core'; +import { CallToolResultSchema, getResultSchema, ProtocolError, ProtocolErrorCode } from '@modelcontextprotocol/core'; import type { Client } from '../../client/client.js'; @@ -28,10 +29,6 @@ import type { Client } from '../../client/client.js'; * @internal */ interface ClientInternal { - requestStream( - request: { method: M; params?: Record }, - options?: RequestOptions - ): AsyncGenerator, void, void>; isToolTask(toolName: string): boolean; getToolOutputValidator(toolName: string): ((data: unknown) => { valid: boolean; errorMessage?: string }) | undefined; } @@ -50,6 +47,14 @@ interface ClientInternal { export class ExperimentalClientTasks { constructor(private readonly _client: Client) {} + private get _module() { + const module = this._client.taskModule; + if (!module) { + throw new Error('Tasks capability is not configured. Declare tasks in capabilities to use task features.'); + } + return module; + } + /** * Calls a tool and returns an AsyncGenerator that yields response messages. * The generator is guaranteed to end with either a `'result'` or `'error'` message. @@ -104,7 +109,7 @@ export class ExperimentalClientTasks { task: options?.task ?? (clientInternal.isToolTask(params.name) ? {} : undefined) }; - const stream = clientInternal.requestStream({ method: 'tools/call', params }, optionsWithTask); + const stream = this._module.requestStream({ method: 'tools/call', params }, CallToolResultSchema, optionsWithTask); // Get the validator for this tool (if it has an output schema) const validator = clientInternal.getToolOutputValidator(params.name); @@ -176,9 +181,7 @@ export class ExperimentalClientTasks { * @experimental */ async getTask(taskId: string, options?: RequestOptions): Promise { - // Delegate to the client's underlying Protocol method - type ClientWithGetTask = { getTask(params: { taskId: string }, options?: RequestOptions): Promise }; - return (this._client as unknown as ClientWithGetTask).getTask({ taskId }, options); + return this._module.getTask({ taskId }, options); } /** @@ -192,16 +195,7 @@ export class ExperimentalClientTasks { * @experimental */ async getTaskResult(taskId: string, resultSchema?: T, options?: RequestOptions): Promise> { - // Delegate to the client's underlying Protocol method - return ( - this._client as unknown as { - getTaskResult: ( - params: { taskId: string }, - resultSchema?: U, - options?: RequestOptions - ) => Promise>; - } - ).getTaskResult({ taskId }, resultSchema, options); + return this._module.getTaskResult({ taskId }, resultSchema!, options); } /** @@ -214,12 +208,7 @@ export class ExperimentalClientTasks { * @experimental */ async listTasks(cursor?: string, options?: RequestOptions): Promise { - // Delegate to the client's underlying Protocol method - return ( - this._client as unknown as { - listTasks: (params?: { cursor?: string }, options?: RequestOptions) => Promise; - } - ).listTasks(cursor ? { cursor } : undefined, options); + return this._module.listTasks(cursor ? { cursor } : undefined, options); } /** @@ -231,12 +220,7 @@ export class ExperimentalClientTasks { * @experimental */ async cancelTask(taskId: string, options?: RequestOptions): Promise { - // Delegate to the client's underlying Protocol method - return ( - this._client as unknown as { - cancelTask: (params: { taskId: string }, options?: RequestOptions) => Promise; - } - ).cancelTask({ taskId }, options); + return this._module.cancelTask({ taskId }, options); } /** @@ -281,7 +265,11 @@ export class ExperimentalClientTasks { request: { method: M; params?: Record }, options?: RequestOptions ): AsyncGenerator, void, void> { - // Delegate to the client's underlying Protocol method - return (this._client as unknown as ClientInternal).requestStream(request, options); + const resultSchema = getResultSchema(request.method) as unknown as AnyObjectSchema; + return this._module.requestStream(request as Request, resultSchema, options) as AsyncGenerator< + ResponseMessage, + void, + void + >; } } diff --git a/packages/core/src/experimental/tasks/helpers.ts b/packages/core/src/experimental/tasks/helpers.ts index 0d3fce84d..98440be43 100644 --- a/packages/core/src/experimental/tasks/helpers.ts +++ b/packages/core/src/experimental/tasks/helpers.ts @@ -17,7 +17,7 @@ interface TaskRequestsCapability { /** * Asserts that task creation is supported for `tools/call`. - * Used by {@linkcode @modelcontextprotocol/client!client/client.Client.assertTaskCapability | Client.assertTaskCapability} and {@linkcode @modelcontextprotocol/server!server/server.Server.assertTaskHandlerCapability | Server.assertTaskHandlerCapability}. + * Used as the `assertTaskCapability` or `assertTaskHandlerCapability` callback in `TaskManagerOptions`. * * @param requests - The task requests capability object * @param method - The method being checked @@ -52,7 +52,7 @@ export function assertToolsCallTaskCapability( /** * Asserts that task creation is supported for `sampling/createMessage` or `elicitation/create`. - * Used by {@linkcode @modelcontextprotocol/server!server/server.Server.assertTaskCapability | Server.assertTaskCapability} and {@linkcode @modelcontextprotocol/client!client/client.Client.assertTaskHandlerCapability | Client.assertTaskHandlerCapability}. + * Used as the `assertTaskCapability` or `assertTaskHandlerCapability` callback in `TaskManagerOptions`. * * @param requests - The task requests capability object * @param method - The method being checked diff --git a/packages/core/src/experimental/tasks/interfaces.ts b/packages/core/src/experimental/tasks/interfaces.ts index 8b3459b7c..4522a4120 100644 --- a/packages/core/src/experimental/tasks/interfaces.ts +++ b/packages/core/src/experimental/tasks/interfaces.ts @@ -3,7 +3,8 @@ * WARNING: These APIs are experimental and may change without notice. */ -import type { RequestTaskStore, ServerContext } from '../../shared/protocol.js'; +import type { ServerContext } from '../../shared/protocol.js'; +import type { RequestTaskStore } from '../../shared/taskManager.js'; import type { JSONRPCErrorResponse, JSONRPCNotification, diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 56769c575..33b860d6f 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -6,6 +6,8 @@ export * from './shared/metadataUtils.js'; export * from './shared/protocol.js'; export * from './shared/responseMessage.js'; export * from './shared/stdio.js'; +export type { RequestTaskStore, TaskContext, TaskManagerOptions, TaskRequestOptions } from './shared/taskManager.js'; +export { TaskManager } from './shared/taskManager.js'; export * from './shared/toolNameValidation.js'; export * from './shared/transport.js'; export * from './shared/uriTemplate.js'; diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index b82731582..69ac123a5 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -1,6 +1,4 @@ import { SdkError, SdkErrorCode } from '../errors/sdkErrors.js'; -import type { CreateTaskOptions, QueuedMessage, TaskMessageQueue, TaskStore } from '../experimental/tasks/interfaces.js'; -import { isTerminal } from '../experimental/tasks/interfaces.js'; import type { AuthInfo, CancelledNotification, @@ -11,9 +9,6 @@ import type { ElicitRequestFormParams, ElicitRequestURLParams, ElicitResult, - GetTaskPayloadRequest, - GetTaskRequest, - GetTaskResult, JSONRPCErrorResponse, JSONRPCNotification, JSONRPCRequest, @@ -36,32 +31,24 @@ import type { Result, ResultTypeMap, ServerCapabilities, - Task, - TaskCreationParams, - TaskStatusNotification + TaskCreationParams } from '../types/types.js'; import { - CancelTaskResultSchema, - CreateTaskResultSchema, getNotificationSchema, getRequestSchema, getResultSchema, - GetTaskResultSchema, isJSONRPCErrorResponse, isJSONRPCNotification, isJSONRPCRequest, isJSONRPCResultResponse, - isTaskAugmentedRequestParams, - ListTasksResultSchema, ProtocolError, ProtocolErrorCode, - RELATED_TASK_META_KEY, - SUPPORTED_PROTOCOL_VERSIONS, - TaskStatusNotificationSchema + SUPPORTED_PROTOCOL_VERSIONS } from '../types/types.js'; -import type { AnyObjectSchema, AnySchema, SchemaOutput } from '../util/schema.js'; +import type { AnySchema, SchemaOutput } from '../util/schema.js'; import { parseSchema } from '../util/schema.js'; -import type { ResponseMessage } from './responseMessage.js'; +import type { ProtocolModule, ProtocolModuleHost } from './protocolModule.js'; +import type { TaskContext, TaskRequestOptions } from './taskManager.js'; import type { Transport, TransportSendOptions } from './transport.js'; /** @@ -96,29 +83,6 @@ export type ProtocolOptions = { * e.g., `['notifications/tools/list_changed']` */ debouncedNotificationMethods?: string[]; - /** - * Optional task storage implementation. If provided, enables task-related request handlers - * and provides task storage capabilities to request handlers. - */ - taskStore?: TaskStore; - /** - * Optional task message queue implementation for managing server-initiated messages - * that will be delivered through the tasks/result response stream. - */ - taskMessageQueue?: TaskMessageQueue; - /** - * Default polling interval (in milliseconds) for task status checks when no `pollInterval` - * is provided by the server. Defaults to 5000ms if not specified. - */ - defaultTaskPollInterval?: number; - /** - * Maximum number of messages that can be queued per task for side-channel delivery. - * If undefined, the queue size is unbounded. - * When the limit is exceeded, the {@linkcode TaskMessageQueue} implementation's {@linkcode TaskMessageQueue.enqueue | enqueue()} method - * will throw an error. It's the implementation's responsibility to handle overflow - * appropriately (e.g., by failing the task, dropping messages, etc.). - */ - maxTaskQueueSize?: number; }; /** @@ -189,78 +153,6 @@ export type NotificationOptions = { relatedTask?: RelatedTaskMetadata; }; -/** - * Options that can be given per request. - */ -// relatedTask is excluded as the SDK controls if this is sent according to if the source is a task. -export type TaskRequestOptions = Omit; - -/** - * Request-scoped {@linkcode TaskStore} interface. - */ -export interface RequestTaskStore { - /** - * Creates a new task with the given creation parameters. - * The implementation generates a unique `taskId` and `createdAt` timestamp. - * - * @param taskParams - The task creation parameters from the request - * @returns The created {@linkcode Task} object - */ - createTask(taskParams: CreateTaskOptions): Promise; - - /** - * Gets the current status of a task. - * - * @param taskId - The task identifier - * @returns The {@linkcode Task} object - * @throws If the task does not exist - */ - getTask(taskId: string): Promise; - - /** - * Stores the result of a task and sets its final status. - * - * @param taskId - The task identifier - * @param status - The final status: `'completed'` for success, `'failed'` for errors - * @param result - The result to store - */ - storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result): Promise; - - /** - * Retrieves the stored result of a task. - * - * @param taskId - The task identifier - * @returns The stored result - */ - getTaskResult(taskId: string): Promise; - - /** - * Updates a task's status (e.g., to `'cancelled'`, `'failed'`, `'completed'`). - * - * @param taskId - The task identifier - * @param status - The new status - * @param statusMessage - Optional diagnostic message for failed tasks or other status information - */ - updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string): Promise; - - /** - * Lists tasks, optionally starting from a pagination cursor. - * - * @param cursor - Optional cursor for pagination - * @returns An object containing the `tasks` array and an optional `nextCursor` - */ - listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; -} - -/** - * Task context provided to request handlers when task storage is configured. - */ -export type TaskContext = { - id?: string; - store: RequestTaskStore; - requestedTtl?: number | null; -}; - /** * Base context provided to all request handlers. */ @@ -390,6 +282,10 @@ type TimeoutInfo = { onTimeout: () => void; }; +async function _noopRouteResponse(): Promise { + return false; +} + /** * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. @@ -405,13 +301,7 @@ export abstract class Protocol { private _timeoutInfo: Map = new Map(); private _pendingDebouncedNotifications = new Set(); - // Maps task IDs to progress tokens to keep handlers alive after CreateTaskResult - private _taskProgressTokens: Map = new Map(); - - private _taskStore?: TaskStore; - private _taskMessageQueue?: TaskMessageQueue; - - private _requestResolvers: Map void> = new Map(); + private _modules: ProtocolModule[] = []; protected _supportedProtocolVersions: string[]; @@ -455,175 +345,27 @@ export abstract class Protocol { // Automatic pong by default. _request => ({}) as Result ); + } - // Install task handlers if TaskStore is provided - this._taskStore = _options?.taskStore; - this._taskMessageQueue = _options?.taskMessageQueue; - if (this._taskStore) { - this.setRequestHandler('tasks/get', async (request, ctx) => { - const task = await this._taskStore!.getTask(request.params.taskId, ctx.sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); - } - - // Per spec: tasks/get responses SHALL NOT include related-task metadata - // as the taskId parameter is the source of truth - return { - ...task - } as Result; - }); - - this.setRequestHandler('tasks/result', async (request, ctx) => { - const handleTaskResult = async (): Promise => { - const taskId = request.params.taskId; - - // Deliver queued messages - if (this._taskMessageQueue) { - let queuedMessage: QueuedMessage | undefined; - while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, ctx.sessionId))) { - // Handle response and error messages by routing them to the appropriate resolver - if (queuedMessage.type === 'response' || queuedMessage.type === 'error') { - const message = queuedMessage.message; - const requestId = message.id; - - // Lookup resolver in _requestResolvers map - const resolver = this._requestResolvers.get(requestId as RequestId); - - if (resolver) { - // Remove resolver from map after invocation - this._requestResolvers.delete(requestId as RequestId); - - // Invoke resolver with response or error - if (queuedMessage.type === 'response') { - resolver(message as JSONRPCResultResponse); - } else { - // Convert JSONRPCError to ProtocolError - const errorMessage = message as JSONRPCErrorResponse; - const error = new ProtocolError( - errorMessage.error.code, - errorMessage.error.message, - errorMessage.error.data - ); - resolver(error); - } - } else { - // Handle missing resolver gracefully with error logging - const messageType = queuedMessage.type === 'response' ? 'Response' : 'Error'; - this._onerror(new Error(`${messageType} handler missing for request ${requestId}`)); - } - - // Continue to next message - continue; - } - - // Send the message on the response stream by passing the relatedRequestId - // This tells the transport to write the message to the tasks/result response stream - await this._transport?.send(queuedMessage.message, { relatedRequestId: ctx.mcpReq.id }); - } - } - - // Now check task status - const task = await this._taskStore!.getTask(taskId, ctx.sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found: ${taskId}`); - } - - // Block if task is not terminal (we've already delivered all queued messages above) - if (!isTerminal(task.status)) { - // Wait for status change or new messages - await this._waitForTaskUpdate(taskId, ctx.mcpReq.signal); - - // After waking up, recursively call to deliver any new messages or result - return await handleTaskResult(); - } - - // If task is terminal, return the result - if (isTerminal(task.status)) { - const result = await this._taskStore!.getTaskResult(taskId, ctx.sessionId); - - this._clearTaskQueue(taskId); - - return { - ...result, - _meta: { - ...result._meta, - [RELATED_TASK_META_KEY]: { - taskId: taskId - } - } - } as Result; - } - - return await handleTaskResult(); - }; - - return await handleTaskResult(); - }); - - this.setRequestHandler('tasks/list', async (request, ctx) => { - try { - const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, ctx.sessionId); - return { - tasks, - nextCursor, - _meta: {} - } as Result; - } catch (error) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}` - ); - } - }); - - this.setRequestHandler('tasks/cancel', async (request, ctx) => { - try { - // Get the current task to check if it's in a terminal state, in case the implementation is not atomic - const task = await this._taskStore!.getTask(request.params.taskId, ctx.sessionId); - - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found: ${request.params.taskId}`); - } - - // Reject cancellation of terminal tasks - if (isTerminal(task.status)) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); - } - - await this._taskStore!.updateTaskStatus( - request.params.taskId, - 'cancelled', - 'Client cancelled task execution.', - ctx.sessionId - ); - - this._clearTaskQueue(request.params.taskId); - - const cancelledTask = await this._taskStore!.getTask(request.params.taskId, ctx.sessionId); - if (!cancelledTask) { - // Task was deleted during cancellation (e.g., cleanup happened) - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Task not found after cancellation: ${request.params.taskId}` - ); - } - - return { - _meta: {}, - ...cancelledTask - } as Result; - } catch (error) { - // Re-throw ProtocolError as-is - if (error instanceof ProtocolError) { - throw error; - } - throw new ProtocolError( - ProtocolErrorCode.InvalidRequest, - `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` - ); - } - }); - } + /** + * Registers a ProtocolModule that hooks into the message lifecycle. + * The module is bound to this Protocol and can register handlers, send messages, etc. + */ + protected registerModule(module: ProtocolModule): void { + this._modules.push(module); + const host: ProtocolModuleHost = { + request: (request, resultSchema, options) => this._requestWithSchema(request, resultSchema, options), + notification: (notification, options) => this.notification(notification, options), + reportError: error => this._onerror(error), + removeProgressHandler: token => this._progressHandlers.delete(token), + registerHandler: (method, handler) => { + this._requestHandlers.set(method, (request, ctx) => handler(request, ctx)); + }, + sendOnResponseStream: async (message, relatedRequestId) => { + await this._transport?.send(message, { relatedRequestId }); + } + }; + module.bind(host); } /** @@ -727,7 +469,9 @@ export abstract class Protocol { const responseHandlers = this._responseHandlers; this._responseHandlers = new Map(); this._progressHandlers.clear(); - this._taskProgressTokens.clear(); + for (const module of this._modules) { + module.onClose(); + } this._pendingDebouncedNotifications.clear(); const error = new SdkError(SdkErrorCode.ConnectionClosed, 'Connection closed'); @@ -764,8 +508,53 @@ export abstract class Protocol { // Capture the current transport at request time to ensure responses go to the correct client const capturedTransport = this._transport; - // Extract taskId from request metadata if present (needed early for method not found case) - const relatedTaskId = request.params?._meta?.[RELATED_TASK_META_KEY]?.taskId; + // Delegate context extraction to module (if registered) + const inboundCtx = { + sessionId: capturedTransport?.sessionId, + sendNotification: (notification: Notification, options?: NotificationOptions) => + this.notification(notification, { ...options, relatedRequestId: request.id }), + sendRequest: (r: Request, resultSchema: U, options?: RequestOptions) => + this._requestWithSchema(r, resultSchema, { ...options, relatedRequestId: request.id }) + }; + + // Compose results from all modules + let sendNotification: (notification: Notification) => Promise = (notification: Notification) => + inboundCtx.sendNotification(notification); + let sendRequest = inboundCtx.sendRequest; + let routeResponse: (message: JSONRPCResponse | JSONRPCErrorResponse) => Promise = _noopRouteResponse; + let taskContext: BaseContext['task'] | undefined; + let hasTaskCreationParams = false; + const validators: Array<() => void> = []; + + for (const module of this._modules) { + const moduleResult = module.processInboundRequest(request, inboundCtx); + + // Chain sendNotification/sendRequest wrappers + sendNotification = moduleResult.sendNotification; + sendRequest = moduleResult.sendRequest; + + // Last non-undefined taskContext wins + if (moduleResult.taskContext !== undefined) { + taskContext = moduleResult.taskContext; + } + + // OR for hasTaskCreationParams + hasTaskCreationParams = hasTaskCreationParams || moduleResult.hasTaskCreationParams; + + // Collect deferred validations (e.g., assertTaskHandlerCapability) to run + // inside the async handler chain so errors produce proper JSON-RPC error responses. + if (moduleResult.validateInbound) { + validators.push(moduleResult.validateInbound); + } + + // Compose routeResponse as OR-chain (first returning true wins) + const prevRouteResponse = routeResponse; + const moduleRouteResponse = moduleResult.routeResponse; + routeResponse = async (message: JSONRPCResponse | JSONRPCErrorResponse) => { + if (await prevRouteResponse(message)) return true; + return moduleRouteResponse(message); + }; + } if (handler === undefined) { const errorResponse: JSONRPCErrorResponse = { @@ -778,34 +567,21 @@ export abstract class Protocol { }; // Queue or send the error response based on whether this is a task-related request - if (relatedTaskId && this._taskMessageQueue) { - this._enqueueTaskMessage( - relatedTaskId, - { - type: 'error', - message: errorResponse, - timestamp: Date.now() - }, - capturedTransport?.sessionId - ).catch(error => this._onerror(new Error(`Failed to enqueue error response: ${error}`))); - } else { - capturedTransport - ?.send(errorResponse) - .catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); - } + routeResponse(errorResponse) + .then(routed => { + if (!routed) { + capturedTransport + ?.send(errorResponse) + .catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); + } + }) + .catch(error => this._onerror(new Error(`Failed to enqueue error response: ${error}`))); return; } const abortController = new AbortController(); this._requestHandlerAbortControllers.set(request.id, abortController); - const taskCreationParams = isTaskAugmentedRequestParams(request.params) ? request.params.task : undefined; - const taskStore = this._taskStore ? this.requestTaskStore(request, capturedTransport?.sessionId) : undefined; - - const task: TaskContext | undefined = taskStore - ? { id: relatedTaskId, store: taskStore, requestedTtl: taskCreationParams?.ttl } - : undefined; - const baseCtx: BaseContext = { sessionId: capturedTransport?.sessionId, mcpReq: { @@ -813,37 +589,22 @@ export abstract class Protocol { method: request.method, _meta: request.params?._meta, signal: abortController.signal, - send: async (r, options?) => { - const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; - if (relatedTaskId && !requestOptions.relatedTask) { - requestOptions.relatedTask = { taskId: relatedTaskId }; - } - const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; - if (effectiveTaskId && taskStore) { - await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); - } - return await this.request(r, requestOptions); + send: (r: { method: M; params?: Record }, options?: TaskRequestOptions) => { + const resultSchema = getResultSchema(r.method); + return sendRequest(r as Request, resultSchema, options) as Promise; }, - notify: async notification => { - const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; - if (relatedTaskId) { - notificationOptions.relatedTask = { taskId: relatedTaskId }; - } - await this.notification(notification, notificationOptions); - } + notify: sendNotification }, http: extra?.authInfo ? { authInfo: extra.authInfo } : undefined, - task + task: taskContext }; const ctx = this.buildContext(baseCtx, extra); // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() .then(() => { - // If this request asked for task creation, check capability first - if (taskCreationParams) { - // Check if the request method supports task creation - this.assertTaskHandlerCapability(request.method); + for (const validate of validators) { + validate(); } }) .then(() => handler(request, ctx)) @@ -861,17 +622,10 @@ export abstract class Protocol { }; // Queue or send the response based on whether this is a task-related request - await (relatedTaskId && this._taskMessageQueue - ? this._enqueueTaskMessage( - relatedTaskId, - { - type: 'response', - message: response, - timestamp: Date.now() - }, - capturedTransport?.sessionId - ) - : capturedTransport?.send(response)); + const routed = await routeResponse(response); + if (!routed) { + await capturedTransport?.send(response); + } }, async error => { if (abortController.signal.aborted) { @@ -890,17 +644,10 @@ export abstract class Protocol { }; // Queue or send the error response based on whether this is a task-related request - await (relatedTaskId && this._taskMessageQueue - ? this._enqueueTaskMessage( - relatedTaskId, - { - type: 'error', - message: errorResponse, - timestamp: Date.now() - }, - capturedTransport?.sessionId - ) - : capturedTransport?.send(errorResponse)); + const routed = await routeResponse(errorResponse); + if (!routed) { + await capturedTransport?.send(errorResponse); + } } ) .catch(error => this._onerror(new Error(`Failed to send response: ${error}`))) @@ -941,17 +688,15 @@ export abstract class Protocol { private _onresponse(response: JSONRPCResponse | JSONRPCErrorResponse): void { const messageId = Number(response.id); - // Check if this is a response to a queued request - const resolver = this._requestResolvers.get(messageId); - if (resolver) { - this._requestResolvers.delete(messageId); - if (isJSONRPCResultResponse(response)) { - resolver(response); - } else { - const error = new ProtocolError(response.error.code, response.error.message, response.error.data); - resolver(error); + // Delegate to modules for task-related response handling + let preserveProgress = false; + for (const module of this._modules) { + const moduleResult = module.processInboundResponse(response, messageId); + if (moduleResult.consumed) { + return; } - return; + // OR preserveProgress across non-consuming modules + preserveProgress = preserveProgress || moduleResult.preserveProgress; } const handler = this._responseHandlers.get(messageId); @@ -964,19 +709,7 @@ export abstract class Protocol { this._cleanupTimeout(messageId); // Keep progress handler alive for CreateTaskResult responses - let isTaskResponse = false; - if (isJSONRPCResultResponse(response) && response.result && typeof response.result === 'object') { - const result = response.result as Record; - if (result.task && typeof result.task === 'object') { - const task = result.task as Record; - if (typeof task.taskId === 'string') { - isTaskResponse = true; - this._taskProgressTokens.set(task.taskId, messageId); - } - } - } - - if (!isTaskResponse) { + if (!preserveProgress) { this._progressHandlers.delete(messageId); } @@ -1020,142 +753,6 @@ export abstract class Protocol { */ protected abstract assertRequestHandlerCapability(method: string): void; - /** - * A method to check if task creation is supported for the given request method. - * - * This should be implemented by subclasses. - */ - protected abstract assertTaskCapability(method: string): void; - - /** - * A method to check if a task handler is supported by the local side, for the given method to be handled. - * - * This should be implemented by subclasses. - */ - protected abstract assertTaskHandlerCapability(method: string): void; - - /** - * Sends a request and returns an AsyncGenerator that yields response messages, - * resolving the result schema automatically from the method name. - * The generator is guaranteed to end with either a `'result'` or `'error'` message. - * - * @experimental Use `client.experimental.tasks.requestStream()` to access this method. - */ - protected async *requestStream( - request: { method: M; params?: Record }, - options?: RequestOptions - ): AsyncGenerator, void, void> { - const resultSchema = getResultSchema(request.method) as unknown as AnyObjectSchema; - yield* this._requestStreamWithSchema(request as Request, resultSchema, options) as AsyncGenerator< - ResponseMessage, - void, - void - >; - } - - /** - * Sends a request and returns an AsyncGenerator that yields response messages, - * using the provided schema for validation. - * - * This is the internal implementation used by SDK methods that need to specify - * a particular result schema. - */ - protected async *_requestStreamWithSchema( - request: Request, - resultSchema: T, - options?: RequestOptions - ): AsyncGenerator>, void, void> { - const { task } = options ?? {}; - - // For non-task requests, just yield the result - if (!task) { - try { - const result = await this._requestWithSchema(request, resultSchema, options); - yield { type: 'result', result }; - } catch (error) { - yield { - type: 'error', - error: error instanceof Error ? error : new Error(String(error)) - }; - } - return; - } - - // For task-augmented requests, we need to poll for status - // First, make the request to create the task - let taskId: string | undefined; - try { - // Send the request and get the CreateTaskResult - const createResult = await this._requestWithSchema(request, CreateTaskResultSchema, options); - - // Extract taskId from the result - if (createResult.task) { - taskId = createResult.task.taskId; - yield { type: 'taskCreated', task: createResult.task }; - } else { - throw new ProtocolError(ProtocolErrorCode.InternalError, 'Task creation did not return a task'); - } - - // Poll for task completion - while (true) { - // Get current task status - const task = await this.getTask({ taskId }, options); - yield { type: 'taskStatus', task }; - - // Check if task is terminal - if (isTerminal(task.status)) { - switch (task.status) { - case 'completed': { - // Get the final result - const result = await this.getTaskResult({ taskId }, resultSchema, options); - yield { type: 'result', result }; - - break; - } - case 'failed': { - yield { - type: 'error', - error: new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} failed`) - }; - - break; - } - case 'cancelled': { - yield { - type: 'error', - error: new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} was cancelled`) - }; - - break; - } - // No default - } - return; - } - - // When input_required, call tasks/result to deliver queued messages - // (elicitation, sampling) via SSE and block until terminal - if (task.status === 'input_required') { - const result = await this.getTaskResult({ taskId }, resultSchema, options); - yield { type: 'result', result }; - return; - } - - // Wait before polling again - const pollInterval = task.pollInterval ?? this._options?.defaultTaskPollInterval ?? 1000; - await new Promise(resolve => setTimeout(resolve, pollInterval)); - - // Check if cancelled - options?.signal?.throwIfAborted(); - } - } catch (error) { - yield { - type: 'error', - error: error instanceof Error ? error : new Error(String(error)) - }; - } - } - /** * Sends a request and waits for a response, resolving the result schema * automatically from the method name. @@ -1181,7 +778,7 @@ export abstract class Protocol { resultSchema: T, options?: RequestOptions ): Promise> { - const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {}; + const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; // Send the request return new Promise>((resolve, reject) => { @@ -1197,11 +794,6 @@ export abstract class Protocol { if (this._options?.enforceStrictCapabilities === true) { try { this.assertCapabilityForMethod(request.method as RequestMethod); - - // If task creation is requested, also check task capabilities - if (task) { - this.assertTaskCapability(request.method); - } } catch (error) { earlyReject(error); return; @@ -1228,25 +820,6 @@ export abstract class Protocol { }; } - // Augment with task creation parameters if provided - if (task) { - jsonrpcRequest.params = { - ...jsonrpcRequest.params, - task: task - }; - } - - // Augment with related task metadata if relatedTask is provided - if (relatedTask) { - jsonrpcRequest.params = { - ...jsonrpcRequest.params, - _meta: { - ...jsonrpcRequest.params?._meta, - [RELATED_TASK_META_KEY]: relatedTask - } - }; - } - const cancel = (reason: unknown) => { this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); @@ -1301,34 +874,30 @@ export abstract class Protocol { this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); - // Queue request if related to a task - const relatedTaskId = relatedTask?.taskId; - if (relatedTaskId) { - // Store the response resolver for this request so responses can be routed back - const responseResolver = (response: JSONRPCResultResponse | Error) => { - const handler = this._responseHandlers.get(messageId); - if (handler) { - handler(response); - } else { - // Log error when resolver is missing, but don't fail - this._onerror(new Error(`Response handler missing for side-channeled request ${messageId}`)); - } - }; - this._requestResolvers.set(messageId, responseResolver); + // Delegate task augmentation and routing to module (if registered) + const responseHandler = (response: JSONRPCResultResponse | Error) => { + const handler = this._responseHandlers.get(messageId); + if (handler) { + handler(response); + } else { + this._onerror(new Error(`Response handler missing for side-channeled request ${messageId}`)); + } + }; - this._enqueueTaskMessage(relatedTaskId, { - type: 'request', - message: jsonrpcRequest, - timestamp: Date.now() - }).catch(error => { + let outboundQueued = false; + for (const module of this._modules) { + const moduleResult = module.processOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, error => { this._cleanupTimeout(messageId); reject(error); }); + if (moduleResult.queued) { + outboundQueued = true; + break; + } + } - // Don't send through transport - queued messages are delivered via tasks/result only - // This prevents duplicate delivery for bidirectional transports - } else { - // No related task - send through transport normally + if (!outboundQueued) { + // No related task or no module - send through transport normally this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { this._cleanupTimeout(messageId); reject(error); @@ -1337,46 +906,6 @@ export abstract class Protocol { }); } - /** - * Gets the current status of a task. - * - * @experimental Use `client.experimental.tasks.getTask()` to access this method. - */ - protected async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { - return this._requestWithSchema({ method: 'tasks/get', params }, GetTaskResultSchema, options); - } - - /** - * Retrieves the result of a completed task. - * - * @experimental Use `client.experimental.tasks.getTaskResult()` to access this method. - */ - protected async getTaskResult( - params: GetTaskPayloadRequest['params'], - resultSchema: T, - options?: RequestOptions - ): Promise> { - return this._requestWithSchema({ method: 'tasks/result', params }, resultSchema, options); - } - - /** - * Lists tasks, optionally starting from a pagination cursor. - * - * @experimental Use `client.experimental.tasks.listTasks()` to access this method. - */ - protected async listTasks(params?: { cursor?: string }, options?: RequestOptions): Promise> { - return this._requestWithSchema({ method: 'tasks/list', params }, ListTasksResultSchema, options); - } - - /** - * Cancels a specific task. - * - * @experimental Use `client.experimental.tasks.cancelTask()` to access this method. - */ - protected async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise> { - return this._requestWithSchema({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); - } - /** * Emits a notification, which is a one-way message that does not expect a response. */ @@ -1387,30 +916,27 @@ export abstract class Protocol { this.assertNotificationCapability(notification.method as NotificationMethod); - // Queue notification if related to a task - const relatedTaskId = options?.relatedTask?.taskId; - if (relatedTaskId) { - // Build the JSONRPC notification with metadata - const jsonrpcNotification: JSONRPCNotification = { - ...notification, - jsonrpc: '2.0', - params: { - ...notification.params, - _meta: { - ...notification.params?._meta, - [RELATED_TASK_META_KEY]: options.relatedTask - } - } - }; + // Delegate task-related notification routing and JSONRPC building to modules (if registered) + let queued = false; + let jsonrpcNotification: JSONRPCNotification | undefined; - await this._enqueueTaskMessage(relatedTaskId, { - type: 'notification', - message: jsonrpcNotification, - timestamp: Date.now() - }); + if (this._modules.length > 0) { + for (const module of this._modules) { + const result = await module.processOutboundNotification(notification, options); + if (result.queued) { + queued = true; + break; + } + // Last module's jsonrpcNotification is used + jsonrpcNotification = result.jsonrpcNotification; + } + } else { + // No modules — build JSONRPC notification directly + jsonrpcNotification = { ...notification, jsonrpc: '2.0' }; + } + if (queued) { // Don't send through transport - queued messages are delivered via tasks/result only - // This prevents duplicate delivery for bidirectional transports return; } @@ -1440,54 +966,16 @@ export abstract class Protocol { return; } - let jsonrpcNotification: JSONRPCNotification = { - ...notification, - jsonrpc: '2.0' - }; - - // Augment with related task metadata if relatedTask is provided - if (options?.relatedTask) { - jsonrpcNotification = { - ...jsonrpcNotification, - params: { - ...jsonrpcNotification.params, - _meta: { - ...jsonrpcNotification.params?._meta, - [RELATED_TASK_META_KEY]: options.relatedTask - } - } - }; - } - // Send the notification, but don't await it here to avoid blocking. // Handle potential errors with a .catch(). - this._transport?.send(jsonrpcNotification, options).catch(error => this._onerror(error)); + this._transport?.send(jsonrpcNotification!, options).catch(error => this._onerror(error)); }); // Return immediately. return; } - let jsonrpcNotification: JSONRPCNotification = { - ...notification, - jsonrpc: '2.0' - }; - - // Augment with related task metadata if relatedTask is provided - if (options?.relatedTask) { - jsonrpcNotification = { - ...jsonrpcNotification, - params: { - ...jsonrpcNotification.params, - _meta: { - ...jsonrpcNotification.params?._meta, - [RELATED_TASK_META_KEY]: options.relatedTask - } - } - }; - } - - await this._transport.send(jsonrpcNotification, options); + await this._transport.send(jsonrpcNotification!, options); } /** @@ -1547,194 +1035,6 @@ export abstract class Protocol { removeNotificationHandler(method: NotificationMethod): void { this._notificationHandlers.delete(method); } - - /** - * Cleans up the progress handler associated with a task. - * This should be called when a task reaches a terminal status. - */ - private _cleanupTaskProgressHandler(taskId: string): void { - const progressToken = this._taskProgressTokens.get(taskId); - if (progressToken !== undefined) { - this._progressHandlers.delete(progressToken); - this._taskProgressTokens.delete(taskId); - } - } - - /** - * Enqueues a task-related message for side-channel delivery via `tasks/result`. - * @param taskId The task ID to associate the message with - * @param message The message to enqueue - * @param sessionId Optional session ID for binding the operation to a specific session - * @throws Error if `taskStore` is not configured or if enqueue fails (e.g., queue overflow) - * - * Note: If enqueue fails, it's the {@linkcode TaskMessageQueue} implementation's responsibility to handle - * the error appropriately (e.g., by failing the task, logging, etc.). The Protocol layer - * simply propagates the error. - */ - private async _enqueueTaskMessage(taskId: string, message: QueuedMessage, sessionId?: string): Promise { - // Task message queues are only used when taskStore is configured - if (!this._taskStore || !this._taskMessageQueue) { - throw new Error('Cannot enqueue task message: taskStore and taskMessageQueue are not configured'); - } - - const maxQueueSize = this._options?.maxTaskQueueSize; - await this._taskMessageQueue.enqueue(taskId, message, sessionId, maxQueueSize); - } - - /** - * Clears the message queue for a task and rejects any pending request resolvers. - * @param taskId The task ID whose queue should be cleared - * @param sessionId Optional session ID for binding the operation to a specific session - */ - private async _clearTaskQueue(taskId: string, sessionId?: string): Promise { - if (this._taskMessageQueue) { - // Reject any pending request resolvers - const messages = await this._taskMessageQueue.dequeueAll(taskId, sessionId); - for (const message of messages) { - if (message.type === 'request' && isJSONRPCRequest(message.message)) { - // Extract request ID from the message - const requestId = message.message.id as RequestId; - const resolver = this._requestResolvers.get(requestId); - if (resolver) { - resolver(new ProtocolError(ProtocolErrorCode.InternalError, 'Task cancelled or completed')); - this._requestResolvers.delete(requestId); - } else { - // Log error when resolver is missing during cleanup for better observability - this._onerror(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); - } - } - } - } - } - - /** - * Waits for a task update (new messages or status change) with abort signal support. - * Uses polling to check for updates at the task's configured poll interval. - * @param taskId The task ID to wait for - * @param signal Abort signal to cancel the wait - * @returns Promise that resolves when an update occurs or rejects if aborted - */ - private async _waitForTaskUpdate(taskId: string, signal: AbortSignal): Promise { - // Get the task's poll interval, falling back to default - let interval = this._options?.defaultTaskPollInterval ?? 1000; - try { - const task = await this._taskStore?.getTask(taskId); - if (task?.pollInterval) { - interval = task.pollInterval; - } - } catch { - // Use default interval if task lookup fails - } - - return new Promise((resolve, reject) => { - if (signal.aborted) { - reject(new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Request cancelled')); - return; - } - - // Wait for the poll interval, then resolve so caller can check for updates - const timeoutId = setTimeout(resolve, interval); - - // Clean up timeout and reject if aborted - signal.addEventListener( - 'abort', - () => { - clearTimeout(timeoutId); - reject(new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Request cancelled')); - }, - { once: true } - ); - }); - } - - private requestTaskStore(request?: JSONRPCRequest, sessionId?: string): RequestTaskStore { - const taskStore = this._taskStore; - if (!taskStore) { - throw new Error('No task store configured'); - } - - return { - createTask: async taskParams => { - if (!request) { - throw new Error('No request provided'); - } - - return await taskStore.createTask( - taskParams, - request.id, - { - method: request.method, - params: request.params - }, - sessionId - ); - }, - getTask: async taskId => { - const task = await taskStore.getTask(taskId, sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); - } - - return task; - }, - storeTaskResult: async (taskId, status, result) => { - await taskStore.storeTaskResult(taskId, status, result, sessionId); - - // Get updated task state and send notification - const task = await taskStore.getTask(taskId, sessionId); - if (task) { - const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ - method: 'notifications/tasks/status', - params: task - }); - await this.notification(notification as Notification); - - if (isTerminal(task.status)) { - this._cleanupTaskProgressHandler(taskId); - // Don't clear queue here - it will be cleared after delivery via tasks/result - } - } - }, - getTaskResult: taskId => { - return taskStore.getTaskResult(taskId, sessionId); - }, - updateTaskStatus: async (taskId, status, statusMessage) => { - // Check if task exists - const task = await taskStore.getTask(taskId, sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task "${taskId}" not found - it may have been cleaned up`); - } - - // Don't allow transitions from terminal states - if (isTerminal(task.status)) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Cannot update task "${taskId}" from terminal status "${task.status}" to "${status}". Terminal states (completed, failed, cancelled) cannot transition to other states.` - ); - } - - await taskStore.updateTaskStatus(taskId, status, statusMessage, sessionId); - - // Get updated task state and send notification - const updatedTask = await taskStore.getTask(taskId, sessionId); - if (updatedTask) { - const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ - method: 'notifications/tasks/status', - params: updatedTask - }); - await this.notification(notification as Notification); - - if (isTerminal(updatedTask.status)) { - this._cleanupTaskProgressHandler(taskId); - // Don't clear queue here - it will be cleared after delivery via tasks/result - } - } - }, - listTasks: cursor => { - return taskStore.listTasks(cursor, sessionId); - } - }; - } } function isPlainObject(value: unknown): value is Record { diff --git a/packages/core/src/shared/protocolModule.ts b/packages/core/src/shared/protocolModule.ts new file mode 100644 index 000000000..df5944cb2 --- /dev/null +++ b/packages/core/src/shared/protocolModule.ts @@ -0,0 +1,122 @@ +import type { + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResultResponse, + Notification, + Request, + RequestId, + Result +} from '../types/types.js'; +import type { AnySchema, SchemaOutput } from '../util/schema.js'; +import type { BaseContext, NotificationOptions, RequestOptions } from './protocol.js'; + +/** + * Host interface that a ProtocolModule uses to interact with the Protocol instance. + * Provided to the module via bind(). + * @internal + */ +export interface ProtocolModuleHost { + request(request: Request, resultSchema: T, options?: RequestOptions): Promise>; + notification(notification: Notification, options?: NotificationOptions): Promise; + reportError(error: Error): void; + removeProgressHandler(token: number): void; + registerHandler(method: string, handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise): void; + sendOnResponseStream(message: JSONRPCNotification | JSONRPCRequest, relatedRequestId: RequestId): Promise; +} + +/** + * Context provided to a module when processing an inbound request. + * @internal + */ +export interface InboundContext { + sessionId?: string; + sendNotification: (notification: Notification, options?: NotificationOptions) => Promise; + sendRequest: (request: Request, resultSchema: U, options?: RequestOptions) => Promise>; +} + +/** + * Result returned by a module after processing an inbound request. + * Provides wrapped send functions and routing for task-related responses. + * @internal + */ +export interface InboundResult { + taskContext?: BaseContext['task']; + sendNotification: (notification: Notification) => Promise; + sendRequest: ( + request: Request, + resultSchema: U, + options?: Omit + ) => Promise>; + routeResponse: (message: JSONRPCResponse | JSONRPCErrorResponse) => Promise; + hasTaskCreationParams: boolean; + /** + * Optional validation to run inside the async handler chain (before the request handler). + * Throwing here produces a proper JSON-RPC error response, matching the behavior of + * capability checks on main. + */ + validateInbound?: () => void; +} + +/** + * Interface for pluggable protocol modules that extend Protocol behavior. + * + * A ProtocolModule hooks into Protocol's message lifecycle to intercept, + * augment, or route messages. Modules are registered via Protocol.registerModule(). + * @internal + */ +export interface ProtocolModule { + /** + * Binds this module to a Protocol host, allowing it to send messages + * and register handlers. + */ + bind(host: ProtocolModuleHost): void; + + /** + * Processes an inbound request, extracting module-specific context + * and wrapping send functions for routing. + */ + processInboundRequest(request: JSONRPCRequest, ctx: InboundContext): InboundResult; + + /** + * Processes an outbound request, potentially augmenting it or routing + * it through a side channel. + * + * @returns { queued: true } if the request was routed and should not be sent via transport. + */ + processOutboundRequest( + jsonrpcRequest: JSONRPCRequest, + options: RequestOptions | undefined, + messageId: number, + responseHandler: (response: JSONRPCResultResponse | Error) => void, + onError: (error: unknown) => void + ): { queued: boolean }; + + /** + * Processes an inbound response, potentially consuming it (e.g., for side-channel responses). + * + * @returns consumed=true if the response was handled and should not be dispatched normally. + * preserveProgress=true if the progress handler should be kept alive after dispatch. + */ + processInboundResponse( + response: JSONRPCResponse | JSONRPCErrorResponse, + messageId: number + ): { consumed: boolean; preserveProgress: boolean }; + + /** + * Processes an outbound notification, potentially routing it through a side channel. + * + * @returns queued=true if the notification was routed and should not be sent via transport. + * jsonrpcNotification is the JSONRPC-wrapped notification to send if not queued. + */ + processOutboundNotification( + notification: Notification, + options?: NotificationOptions + ): Promise<{ queued: boolean; jsonrpcNotification?: JSONRPCNotification }>; + + /** + * Called when the protocol connection is closed. Cleans up module state. + */ + onClose(): void; +} diff --git a/packages/core/src/shared/taskManager.ts b/packages/core/src/shared/taskManager.ts new file mode 100644 index 000000000..d634a53ea --- /dev/null +++ b/packages/core/src/shared/taskManager.ts @@ -0,0 +1,817 @@ +import type { CreateTaskOptions, QueuedMessage, TaskMessageQueue, TaskStore } from '../experimental/tasks/interfaces.js'; +import { isTerminal } from '../experimental/tasks/interfaces.js'; +import type { + GetTaskPayloadRequest, + GetTaskRequest, + GetTaskResult, + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResultResponse, + Notification, + Request, + RequestId, + Result, + Task, + TaskCreationParams, + TaskStatusNotification +} from '../types/types.js'; +import { + CancelTaskResultSchema, + CreateTaskResultSchema, + GetTaskResultSchema, + isJSONRPCErrorResponse, + isJSONRPCRequest, + isJSONRPCResultResponse, + isTaskAugmentedRequestParams, + ListTasksResultSchema, + ProtocolError, + ProtocolErrorCode, + RELATED_TASK_META_KEY, + TaskStatusNotificationSchema +} from '../types/types.js'; +import type { AnyObjectSchema, AnySchema, SchemaOutput } from '../util/schema.js'; +import type { NotificationOptions, RequestOptions } from './protocol.js'; +import type { InboundContext, InboundResult, ProtocolModule, ProtocolModuleHost } from './protocolModule.js'; +import type { ResponseMessage } from './responseMessage.js'; + +/** + * Options that can be given per request. + */ +// relatedTask is excluded as the SDK controls if this is sent according to if the source is a task. +export type TaskRequestOptions = Omit; + +/** + * Request-scoped TaskStore interface. + */ +export interface RequestTaskStore { + /** + * Creates a new task with the given creation parameters. + * The implementation generates a unique taskId and createdAt timestamp. + * + * @param taskParams - The task creation parameters from the request + * @returns The created task object + */ + createTask(taskParams: CreateTaskOptions): Promise; + + /** + * Gets the current status of a task. + * + * @param taskId - The task identifier + * @returns The task object + * @throws If the task does not exist + */ + getTask(taskId: string): Promise; + + /** + * Stores the result of a task and sets its final status. + * + * @param taskId - The task identifier + * @param status - The final status: 'completed' for success, 'failed' for errors + * @param result - The result to store + */ + storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result): Promise; + + /** + * Retrieves the stored result of a task. + * + * @param taskId - The task identifier + * @returns The stored result + */ + getTaskResult(taskId: string): Promise; + + /** + * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). + * + * @param taskId - The task identifier + * @param status - The new status + * @param statusMessage - Optional diagnostic message for failed tasks or other status information + */ + updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string): Promise; + + /** + * Lists tasks, optionally starting from a pagination cursor. + * + * @param cursor - Optional cursor for pagination + * @returns An object containing the tasks array and an optional nextCursor + */ + listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; +} + +/** + * Task context provided to request handlers when task storage is configured. + */ +export type TaskContext = { + id?: string; + store: RequestTaskStore; + requestedTtl?: number | null; +}; + +export type TaskManagerOptions = { + /** + * Task storage implementation. Required for handling incoming task requests (server-side). + * Not required for sending task requests (client-side outbound API). + */ + taskStore?: TaskStore; + /** + * Optional task message queue implementation for managing server-initiated messages + * that will be delivered through the tasks/result response stream. + */ + taskMessageQueue?: TaskMessageQueue; + /** + * Default polling interval (in milliseconds) for task status checks when no pollInterval + * is provided by the server. Defaults to 1000ms if not specified. + */ + defaultTaskPollInterval?: number; + /** + * Maximum number of messages that can be queued per task for side-channel delivery. + * If undefined, the queue size is unbounded. + */ + maxTaskQueueSize?: number; + /** + * When true, outbound task-augmented requests will be validated against + * the remote side's declared task capabilities before sending. + */ + enforceStrictCapabilities?: boolean; + /** + * Assert that the remote side supports task creation for the given method. + * Called when sending a task-augmented outbound request (only when enforceStrictCapabilities is true). + */ + assertTaskCapability?: (method: string) => void; + /** + * Assert that this side supports handling task creation for the given method. + * Called when receiving a task-augmented inbound request. + */ + assertTaskHandlerCapability?: (method: string) => void; +}; + +/** + * Manages task orchestration: state, message queuing, and polling. + * Capability checking is provided via optional constructor callbacks. + * @internal + */ +export class TaskManager implements ProtocolModule { + private _taskStore?: TaskStore; + private _taskMessageQueue?: TaskMessageQueue; + private _taskProgressTokens: Map = new Map(); + private _requestResolvers: Map void> = new Map(); + private _options: TaskManagerOptions; + private _host?: ProtocolModuleHost; + + constructor(options: TaskManagerOptions) { + this._options = options; + this._taskStore = options.taskStore; + this._taskMessageQueue = options.taskMessageQueue; + } + + bind(host: ProtocolModuleHost): void { + this._host = host; + + if (this._taskStore) { + host.registerHandler('tasks/get', async (request, ctx) => { + const params = request.params as { taskId: string }; + const task = await this.handleGetTask(params.taskId, ctx.sessionId); + // Per spec: tasks/get responses SHALL NOT include related-task metadata + // as the taskId parameter is the source of truth + return { + ...task + } as Result; + }); + + host.registerHandler('tasks/result', async (request, ctx) => { + const params = request.params as { taskId: string }; + return await this.handleGetTaskPayload(params.taskId, ctx.sessionId, ctx.mcpReq.signal, async message => { + // Send the message on the response stream by passing the relatedRequestId + // This tells the transport to write the message to the tasks/result response stream + await host.sendOnResponseStream(message, ctx.mcpReq.id); + }); + }); + + host.registerHandler('tasks/list', async (request, ctx) => { + const params = request.params as { cursor?: string } | undefined; + return (await this.handleListTasks(params?.cursor, ctx.sessionId)) as Result; + }); + + host.registerHandler('tasks/cancel', async (request, ctx) => { + const params = request.params as { taskId: string }; + return await this.handleCancelTask(params.taskId, ctx.sessionId); + }); + } + } + + private get _requireHost(): ProtocolModuleHost { + if (!this._host) { + throw new ProtocolError(ProtocolErrorCode.InternalError, 'TaskManager is not bound to a Protocol host — call bind() first'); + } + return this._host; + } + + get taskStore(): TaskStore | undefined { + return this._taskStore; + } + + private get _requireTaskStore(): TaskStore { + if (!this._taskStore) { + throw new ProtocolError(ProtocolErrorCode.InternalError, 'TaskStore is not configured'); + } + return this._taskStore; + } + + get taskMessageQueue(): TaskMessageQueue | undefined { + return this._taskMessageQueue; + } + + // -- Public API (client-facing) -- + async *requestStream( + request: Request, + resultSchema: T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + const host = this._requireHost; + const { task } = options ?? {}; + + if (!task) { + try { + const result = await host.request(request, resultSchema, options); + yield { type: 'result', result }; + } catch (error) { + yield { + type: 'error', + error: error instanceof Error ? error : new Error(String(error)) + }; + } + return; + } + + let taskId: string | undefined; + try { + const createResult = await host.request(request, CreateTaskResultSchema, options); + + if (createResult.task) { + taskId = createResult.task.taskId; + yield { type: 'taskCreated', task: createResult.task }; + } else { + throw new ProtocolError(ProtocolErrorCode.InternalError, 'Task creation did not return a task'); + } + + while (true) { + const task = await this.getTask({ taskId }, options); + yield { type: 'taskStatus', task }; + + if (isTerminal(task.status)) { + switch (task.status) { + case 'completed': { + const result = await this.getTaskResult({ taskId }, resultSchema, options); + yield { type: 'result', result }; + break; + } + case 'failed': { + yield { type: 'error', error: new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} failed`) }; + break; + } + case 'cancelled': { + yield { + type: 'error', + error: new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} was cancelled`) + }; + break; + } + } + return; + } + + if (task.status === 'input_required') { + const result = await this.getTaskResult({ taskId }, resultSchema, options); + yield { type: 'result', result }; + return; + } + + const pollInterval = task.pollInterval ?? this._options.defaultTaskPollInterval ?? 1000; + await new Promise(resolve => setTimeout(resolve, pollInterval)); + options?.signal?.throwIfAborted(); + } + } catch (error) { + yield { + type: 'error', + error: error instanceof Error ? error : new Error(String(error)) + }; + } + } + + async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { + return this._requireHost.request({ method: 'tasks/get', params }, GetTaskResultSchema, options); + } + + async getTaskResult( + params: GetTaskPayloadRequest['params'], + resultSchema: T, + options?: RequestOptions + ): Promise> { + return this._requireHost.request({ method: 'tasks/result', params }, resultSchema, options); + } + + async listTasks(params?: { cursor?: string }, options?: RequestOptions): Promise> { + return this._requireHost.request({ method: 'tasks/list', params }, ListTasksResultSchema, options); + } + + async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise> { + return this._requireHost.request({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); + } + + // -- Handler bodies (delegated from Protocol's registered handlers) -- + + private async handleGetTask(taskId: string, sessionId?: string): Promise { + const task = await this._requireTaskStore.getTask(taskId, sessionId); + if (!task) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + } + return task; + } + + private async handleGetTaskPayload( + taskId: string, + sessionId: string | undefined, + signal: AbortSignal, + sendOnResponseStream: (message: JSONRPCNotification | JSONRPCRequest) => Promise + ): Promise { + const handleTaskResult = async (): Promise => { + if (this._taskMessageQueue) { + let queuedMessage: QueuedMessage | undefined; + while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, sessionId))) { + if (queuedMessage.type === 'response' || queuedMessage.type === 'error') { + const message = queuedMessage.message; + const requestId = message.id; + const resolver = this._requestResolvers.get(requestId as RequestId); + + if (resolver) { + this._requestResolvers.delete(requestId as RequestId); + if (queuedMessage.type === 'response') { + resolver(message as JSONRPCResultResponse); + } else { + const errorMessage = message as JSONRPCErrorResponse; + resolver(new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data)); + } + } else { + const messageType = queuedMessage.type === 'response' ? 'Response' : 'Error'; + this._host?.reportError(new Error(`${messageType} handler missing for request ${requestId}`)); + } + continue; + } + + await sendOnResponseStream(queuedMessage.message as JSONRPCNotification | JSONRPCRequest); + } + } + + const task = await this._requireTaskStore.getTask(taskId, sessionId); + if (!task) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found: ${taskId}`); + } + + if (!isTerminal(task.status)) { + await this._waitForTaskUpdate(task.pollInterval, signal); + return await handleTaskResult(); + } + + const result = await this._requireTaskStore.getTaskResult(taskId, sessionId); + await this._clearTaskQueue(taskId); + + return { + ...result, + _meta: { + ...result._meta, + [RELATED_TASK_META_KEY]: { taskId } + } + }; + }; + + return await handleTaskResult(); + } + + private async handleListTasks( + cursor: string | undefined, + sessionId?: string + ): Promise<{ tasks: Task[]; nextCursor?: string; _meta: Record }> { + try { + const { tasks, nextCursor } = await this._requireTaskStore.listTasks(cursor, sessionId); + return { tasks, nextCursor, _meta: {} }; + } catch (error) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + + private async handleCancelTask(taskId: string, sessionId?: string): Promise { + try { + const task = await this._requireTaskStore.getTask(taskId, sessionId); + if (!task) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found: ${taskId}`); + } + + if (isTerminal(task.status)) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); + } + + await this._requireTaskStore.updateTaskStatus(taskId, 'cancelled', 'Client cancelled task execution.', sessionId); + await this._clearTaskQueue(taskId); + + const cancelledTask = await this._requireTaskStore.getTask(taskId, sessionId); + if (!cancelledTask) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found after cancellation: ${taskId}`); + } + + return { _meta: {}, ...cancelledTask }; + } catch (error) { + if (error instanceof ProtocolError) throw error; + throw new ProtocolError( + ProtocolErrorCode.InvalidRequest, + `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + + // -- Internal delegation methods -- + + private prepareOutboundRequest( + jsonrpcRequest: JSONRPCRequest, + options: RequestOptions | undefined, + messageId: number, + responseHandler: (response: JSONRPCResultResponse | Error) => void, + onError: (error: unknown) => void + ): boolean { + const { task, relatedTask } = options ?? {}; + + if (task) { + jsonrpcRequest.params = { + ...jsonrpcRequest.params, + task: task + }; + } + + if (relatedTask) { + jsonrpcRequest.params = { + ...jsonrpcRequest.params, + _meta: { + ...jsonrpcRequest.params?._meta, + [RELATED_TASK_META_KEY]: relatedTask + } + }; + } + + const relatedTaskId = relatedTask?.taskId; + if (relatedTaskId) { + this._requestResolvers.set(messageId, responseHandler); + + this._enqueueTaskMessage(relatedTaskId, { + type: 'request', + message: jsonrpcRequest, + timestamp: Date.now() + }).catch(error => { + onError(error); + }); + + return true; + } + + return false; + } + + private extractInboundTaskContext( + request: JSONRPCRequest, + sessionId?: string + ): { + relatedTaskId?: string; + taskCreationParams?: TaskCreationParams; + taskContext?: TaskContext; + } { + const relatedTaskId = (request.params?._meta as Record | undefined)?.[RELATED_TASK_META_KEY]?.taskId; + const taskCreationParams = isTaskAugmentedRequestParams(request.params) ? request.params.task : undefined; + + // Provide task context whenever a task store is configured, + // not just for task-related requests — tools need ctx.task.store + let taskContext: TaskContext | undefined; + if (this._taskStore) { + const store = this.createRequestTaskStore(request, sessionId); + taskContext = { + id: relatedTaskId, + store, + requestedTtl: taskCreationParams?.ttl + }; + } + + if (!relatedTaskId && !taskCreationParams && !taskContext) { + return {}; + } + + return { + relatedTaskId, + taskCreationParams, + taskContext + }; + } + + private wrapSendNotification( + relatedTaskId: string, + originalSendNotification: (notification: Notification, options?: NotificationOptions) => Promise + ): (notification: Notification) => Promise { + return async (notification: Notification) => { + const notificationOptions: NotificationOptions = { relatedTask: { taskId: relatedTaskId } }; + await originalSendNotification(notification, notificationOptions); + }; + } + + private wrapSendRequest( + relatedTaskId: string, + taskStore: RequestTaskStore | undefined, + originalSendRequest: (request: Request, resultSchema: V, options?: RequestOptions) => Promise> + ): (request: Request, resultSchema: V, options?: TaskRequestOptions) => Promise> { + return async (request: Request, resultSchema: V, options?: TaskRequestOptions) => { + const requestOptions: RequestOptions = { ...options }; + if (relatedTaskId && !requestOptions.relatedTask) { + requestOptions.relatedTask = { taskId: relatedTaskId }; + } + + const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; + if (effectiveTaskId && taskStore) { + await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); + } + + return await originalSendRequest(request, resultSchema, requestOptions); + }; + } + + private handleResponse(response: JSONRPCResponse | JSONRPCErrorResponse): boolean { + const messageId = Number(response.id); + const resolver = this._requestResolvers.get(messageId); + if (resolver) { + this._requestResolvers.delete(messageId); + if (isJSONRPCResultResponse(response)) { + resolver(response); + } else { + resolver(new ProtocolError(response.error.code, response.error.message, response.error.data)); + } + return true; + } + return false; + } + + private shouldPreserveProgressHandler(response: JSONRPCResponse | JSONRPCErrorResponse, messageId: number): boolean { + if (isJSONRPCResultResponse(response) && response.result && typeof response.result === 'object') { + const result = response.result as Record; + if (result.task && typeof result.task === 'object') { + const task = result.task as Record; + if (typeof task.taskId === 'string') { + this._taskProgressTokens.set(task.taskId, messageId); + return true; + } + } + } + return false; + } + + private async routeNotification(notification: Notification, options?: NotificationOptions): Promise { + const relatedTaskId = options?.relatedTask?.taskId; + if (!relatedTaskId) return false; + + const jsonrpcNotification: JSONRPCNotification = { + ...notification, + jsonrpc: '2.0', + params: { + ...notification.params, + _meta: { + ...notification.params?._meta, + [RELATED_TASK_META_KEY]: options!.relatedTask + } + } + }; + + await this._enqueueTaskMessage(relatedTaskId, { + type: 'notification', + message: jsonrpcNotification, + timestamp: Date.now() + }); + + return true; + } + + private async routeResponse( + relatedTaskId: string | undefined, + message: JSONRPCResponse | JSONRPCErrorResponse, + sessionId?: string + ): Promise { + if (!relatedTaskId || !this._taskMessageQueue) return false; + + await (isJSONRPCErrorResponse(message) + ? this._enqueueTaskMessage(relatedTaskId, { type: 'error', message, timestamp: Date.now() }, sessionId) + : this._enqueueTaskMessage( + relatedTaskId, + { type: 'response', message: message as JSONRPCResultResponse, timestamp: Date.now() }, + sessionId + )); + return true; + } + + private createRequestTaskStore(request?: JSONRPCRequest, sessionId?: string): RequestTaskStore { + const taskStore = this._requireTaskStore; + const host = this._host; + + return { + createTask: async taskParams => { + if (!request) throw new Error('No request provided'); + return await taskStore.createTask(taskParams, request.id, { method: request.method, params: request.params }, sessionId); + }, + getTask: async taskId => { + const task = await taskStore.getTask(taskId, sessionId); + if (!task) throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + return task; + }, + storeTaskResult: async (taskId, status, result) => { + await taskStore.storeTaskResult(taskId, status, result, sessionId); + const task = await taskStore.getTask(taskId, sessionId); + if (task) { + const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ + method: 'notifications/tasks/status', + params: task + }); + await host?.notification(notification as Notification); + if (isTerminal(task.status)) { + this._cleanupTaskProgressHandler(taskId); + } + } + }, + getTaskResult: taskId => taskStore.getTaskResult(taskId, sessionId), + updateTaskStatus: async (taskId, status, statusMessage) => { + const task = await taskStore.getTask(taskId, sessionId); + if (!task) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task "${taskId}" not found - it may have been cleaned up`); + } + if (isTerminal(task.status)) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + `Cannot update task "${taskId}" from terminal status "${task.status}" to "${status}". Terminal states (completed, failed, cancelled) cannot transition to other states.` + ); + } + await taskStore.updateTaskStatus(taskId, status, statusMessage, sessionId); + const updatedTask = await taskStore.getTask(taskId, sessionId); + if (updatedTask) { + const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ + method: 'notifications/tasks/status', + params: updatedTask + }); + await host?.notification(notification as Notification); + if (isTerminal(updatedTask.status)) { + this._cleanupTaskProgressHandler(taskId); + } + } + }, + listTasks: cursor => taskStore.listTasks(cursor, sessionId) + }; + } + + // -- Consolidated lifecycle methods (called by Protocol via ProtocolModule) -- + + processInboundRequest(request: JSONRPCRequest, ctx: InboundContext): InboundResult { + const taskInfo = this.extractInboundTaskContext(request, ctx.sessionId); + const relatedTaskId = taskInfo?.relatedTaskId; + + const sendNotification = relatedTaskId + ? this.wrapSendNotification(relatedTaskId, ctx.sendNotification) + : (notification: Notification) => ctx.sendNotification(notification); + + const sendRequest = taskInfo?.taskContext + ? this.wrapSendRequest(relatedTaskId ?? '', taskInfo.taskContext.store, ctx.sendRequest) + : ctx.sendRequest; + + const hasTaskCreationParams = !!taskInfo?.taskCreationParams; + + return { + taskContext: taskInfo?.taskContext, + sendNotification, + sendRequest, + routeResponse: async (message: JSONRPCResponse | JSONRPCErrorResponse) => { + if (relatedTaskId) { + return this.routeResponse(relatedTaskId, message, ctx.sessionId); + } + return false; + }, + hasTaskCreationParams, + // Deferred validation: runs inside the async handler chain so errors + // produce proper JSON-RPC error responses (matching main's behavior). + validateInbound: hasTaskCreationParams ? () => this._options.assertTaskHandlerCapability?.(request.method) : undefined + }; + } + + processOutboundRequest( + jsonrpcRequest: JSONRPCRequest, + options: RequestOptions | undefined, + messageId: number, + responseHandler: (response: JSONRPCResultResponse | Error) => void, + onError: (error: unknown) => void + ): { queued: boolean } { + // Check task capability when sending a task-augmented request (matches main's enforceStrictCapabilities gate) + if (this._options.enforceStrictCapabilities && options?.task) { + this._options.assertTaskCapability?.(jsonrpcRequest.method); + } + + const queued = this.prepareOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, onError); + return { queued }; + } + + processInboundResponse( + response: JSONRPCResponse | JSONRPCErrorResponse, + messageId: number + ): { consumed: boolean; preserveProgress: boolean } { + const consumed = this.handleResponse(response); + if (consumed) { + return { consumed: true, preserveProgress: false }; + } + const preserveProgress = this.shouldPreserveProgressHandler(response, messageId); + return { consumed: false, preserveProgress }; + } + + async processOutboundNotification( + notification: Notification, + options?: NotificationOptions + ): Promise<{ queued: boolean; jsonrpcNotification?: JSONRPCNotification }> { + // Try queuing first + const queued = await this.routeNotification(notification, options); + if (queued) return { queued: true }; + + // Build JSONRPC notification with optional relatedTask metadata + let jsonrpcNotification: JSONRPCNotification = { ...notification, jsonrpc: '2.0' }; + if (options?.relatedTask) { + jsonrpcNotification = { + ...jsonrpcNotification, + params: { + ...jsonrpcNotification.params, + _meta: { + ...jsonrpcNotification.params?._meta, + [RELATED_TASK_META_KEY]: options.relatedTask + } + } + }; + } + return { queued: false, jsonrpcNotification }; + } + + onClose(): void { + this._taskProgressTokens.clear(); + } + + // -- Private helpers -- + + private async _enqueueTaskMessage(taskId: string, message: QueuedMessage, sessionId?: string): Promise { + if (!this._taskStore || !this._taskMessageQueue) { + throw new Error('Cannot enqueue task message: taskStore and taskMessageQueue are not configured'); + } + await this._taskMessageQueue.enqueue(taskId, message, sessionId, this._options.maxTaskQueueSize); + } + + private async _clearTaskQueue(taskId: string, sessionId?: string): Promise { + if (this._taskMessageQueue) { + const messages = await this._taskMessageQueue.dequeueAll(taskId, sessionId); + for (const message of messages) { + if (message.type === 'request' && isJSONRPCRequest(message.message)) { + const requestId = message.message.id as RequestId; + const resolver = this._requestResolvers.get(requestId); + if (resolver) { + resolver(new ProtocolError(ProtocolErrorCode.InternalError, 'Task cancelled or completed')); + this._requestResolvers.delete(requestId); + } else { + this._host?.reportError(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); + } + } + } + } + } + + private async _waitForTaskUpdate(pollInterval: number | undefined, signal: AbortSignal): Promise { + const interval = pollInterval ?? this._options.defaultTaskPollInterval ?? 1000; + + return new Promise((resolve, reject) => { + if (signal.aborted) { + reject(new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Request cancelled')); + return; + } + const timeoutId = setTimeout(resolve, interval); + signal.addEventListener( + 'abort', + () => { + clearTimeout(timeoutId); + reject(new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Request cancelled')); + }, + { once: true } + ); + }); + } + + private _cleanupTaskProgressHandler(taskId: string): void { + const progressToken = this._taskProgressTokens.get(taskId); + if (progressToken !== undefined) { + this._host?.removeProgressHandler(progressToken); + this._taskProgressTokens.delete(taskId); + } + } +} diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index 8675c1e03..522f041ba 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -13,14 +13,19 @@ import type { import { InMemoryTaskMessageQueue } from '../../src/experimental/tasks/stores/inMemory.js'; import type { BaseContext } from '../../src/shared/protocol.js'; import { mergeCapabilities, Protocol } from '../../src/shared/protocol.js'; +import type { InboundContext, InboundResult, ProtocolModule, ProtocolModuleHost } from '../../src/shared/protocolModule.js'; import type { ErrorMessage, ResponseMessage } from '../../src/shared/responseMessage.js'; import { toArrayAsync } from '../../src/shared/responseMessage.js'; +import type { TaskManagerOptions } from '../../src/shared/taskManager.js'; +import { TaskManager } from '../../src/shared/taskManager.js'; import type { Transport, TransportSendOptions } from '../../src/shared/transport.js'; import type { ClientCapabilities, JSONRPCErrorResponse, JSONRPCMessage, + JSONRPCNotification, JSONRPCRequest, + JSONRPCResponse, JSONRPCResultResponse, Notification, Request, @@ -33,23 +38,40 @@ import type { import { ProtocolError, ProtocolErrorCode, RELATED_TASK_META_KEY } from '../../src/types/types.js'; import { SdkError, SdkErrorCode } from '../../src/errors/sdkErrors.js'; +// Test Protocol subclass that exposes registerModule for testing +class TestProtocolImpl extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected buildContext(ctx: BaseContext): BaseContext { + return ctx; + } + public registerTestModule(module: ProtocolModule): void { + this.registerModule(module); + } +} + +function createTestProtocol(taskOptions?: TaskManagerOptions): TestProtocolImpl { + const p = new TestProtocolImpl(); + if (taskOptions) { + p.registerTestModule(new TaskManager(taskOptions)); + } + return p; +} + // Type helper for accessing private/protected Protocol properties in tests -interface TestProtocol { - _taskMessageQueue?: TaskMessageQueue; - _requestResolvers: Map void>; +interface TestProtocolInternals { _responseHandlers: Map void>; - _taskProgressTokens: Map; - _clearTaskQueue: (taskId: string, sessionId?: string) => Promise; - requestTaskStore: (request: Request, authInfo: unknown) => TaskStore; - // Protected methods (exposed for testing) - _requestWithSchema: (request: Request, resultSchema: T, options?: unknown) => Promise>; - listTasks: (params?: { cursor?: string }) => Promise<{ tasks: Task[]; nextCursor?: string }>; - cancelTask: (params: { taskId: string }) => Promise; - _requestStreamWithSchema: ( - request: Request, - schema: ZodType, - options?: unknown - ) => AsyncGenerator>; + // TaskManager modules are accessible via the `_modules` private field + _modules: Array<{ + _taskMessageQueue?: TaskMessageQueue; + _requestResolvers: Map void>; + _taskProgressTokens: Map; + _clearTaskQueue: (taskId: string, sessionId?: string) => Promise; + listTasks: (params?: { cursor?: string }) => Promise<{ tasks: Task[]; nextCursor?: string }>; + cancelTask: (params: { taskId: string }) => Promise; + requestStream: (request: Request, schema: ZodType, options?: unknown) => AsyncGenerator>; + }>; } // Mock Transport class @@ -160,7 +182,9 @@ function assertQueuedRequest(o?: QueuedMessage): asserts o is QueuedRequest { */ // eslint-disable-next-line @typescript-eslint/no-explicit-any function testRequest(proto: Protocol, request: Request, resultSchema: ZodType, options?: any) { - return (proto as unknown as TestProtocol)._requestWithSchema(request, resultSchema, options); + return ( + proto as unknown as { _requestWithSchema: (request: Request, resultSchema: ZodType, options?: unknown) => Promise } + )._requestWithSchema(request, resultSchema, options); } describe('protocol tests', () => { @@ -171,16 +195,7 @@ describe('protocol tests', () => { beforeEach(() => { transport = new MockTransport(); sendSpy = vi.spyOn(transport, 'send'); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })(); + protocol = createTestProtocol(); }); test('should throw a timeout error if the request exceeds the timeout', async () => { @@ -642,16 +657,7 @@ describe('protocol tests', () => { it('should NOT debounce a notification that has parameters', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced_with_params'] }); + protocol = new TestProtocolImpl({ debouncedNotificationMethods: ['test/debounced_with_params'] }); await protocol.connect(transport); // ACT @@ -668,16 +674,7 @@ describe('protocol tests', () => { it('should NOT debounce a notification that has a relatedRequestId', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced_with_options'] }); + protocol = new TestProtocolImpl({ debouncedNotificationMethods: ['test/debounced_with_options'] }); await protocol.connect(transport); // ACT @@ -692,16 +689,7 @@ describe('protocol tests', () => { it('should clear pending debounced notifications on connection close', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); + protocol = new TestProtocolImpl({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); // ACT @@ -721,16 +709,7 @@ describe('protocol tests', () => { it('should debounce multiple synchronous calls when params property is omitted', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); + protocol = new TestProtocolImpl({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); // ACT @@ -753,16 +732,7 @@ describe('protocol tests', () => { it('should debounce calls when params is explicitly undefined', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); + protocol = new TestProtocolImpl({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); // ACT @@ -783,16 +753,7 @@ describe('protocol tests', () => { it('should send non-debounced notifications immediately and multiple times', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); // Configure for a different method + protocol = new TestProtocolImpl({ debouncedNotificationMethods: ['test/debounced'] }); // Configure for a different method await protocol.connect(transport); // ACT @@ -821,16 +782,7 @@ describe('protocol tests', () => { it('should handle sequential batches of debounced notifications correctly', async () => { // ARRANGE - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ debouncedNotificationMethods: ['test/debounced'] }); + protocol = new TestProtocolImpl({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); // ACT (Batch 1) @@ -1043,16 +995,7 @@ describe('Task-based execution', () => { beforeEach(() => { transport = new MockTransport(); sendSpy = vi.spyOn(transport, 'send'); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: createMockTaskStore(), taskMessageQueue: new InMemoryTaskMessageQueue() }); + protocol = createTestProtocol({ taskStore: createMockTaskStore(), taskMessageQueue: new InMemoryTaskMessageQueue() }); }); describe('request with task metadata', () => { @@ -1184,7 +1127,7 @@ describe('Task-based execution', () => { expect(sendSpy).not.toHaveBeenCalled(); // Verify the message was queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); }); @@ -1208,7 +1151,7 @@ describe('Task-based execution', () => { expect(sendSpy).not.toHaveBeenCalled(); // Verify the message was queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('parent-task-456'); @@ -1253,7 +1196,7 @@ describe('Task-based execution', () => { expect(sendSpy).not.toHaveBeenCalled(); // Verify the message was queued with all metadata combined - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('parent-task'); @@ -1275,9 +1218,32 @@ describe('Task-based execution', () => { }); describe('task status transitions', () => { - it('should be handled by tool implementors, not protocol layer', () => { - // Task status management is now the responsibility of tool implementors - expect(true).toBe(true); + it('should not auto-update task status when a task-augmented request completes', async () => { + const mockTaskStore = createMockTaskStore(); + const localProtocol = createTestProtocol({ taskStore: mockTaskStore }); + const localTransport = new MockTransport(); + await localProtocol.connect(localTransport); + + localProtocol.setRequestHandler('tools/call', async () => { + return { content: [{ type: 'text', text: 'done' }] }; + }); + + localTransport.onmessage?.({ + jsonrpc: '2.0', + id: 42, + method: 'tools/call', + params: { + name: 'test-tool', + arguments: {}, + task: { ttl: 60000, pollInterval: 1000 } + } + }); + + // Allow the request to be processed + await new Promise(resolve => setTimeout(resolve, 20)); + + // The protocol layer must not call updateTaskStatus — that is solely the tool implementor's responsibility + expect(mockTaskStore.updateTaskStatus).not.toHaveBeenCalled(); }); it('should handle requests with task creation parameters in top-level task field', async () => { @@ -1285,16 +1251,7 @@ describe('Task-based execution', () => { // rather than in _meta, and that task management is handled by tool implementors const mockTaskStore = createMockTaskStore(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = createTestProtocol({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1326,6 +1283,216 @@ describe('Task-based execution', () => { }); }); + describe('assertTaskHandlerCapability', () => { + it('should invoke assertTaskHandlerCapability callback when an inbound task-augmented request arrives', async () => { + const assertTaskHandlerCapability = vi.fn(); + const localProtocol = createTestProtocol({ + taskStore: createMockTaskStore(), + assertTaskHandlerCapability + }); + const localTransport = new MockTransport(); + await localProtocol.connect(localTransport); + + localProtocol.setRequestHandler('tools/call', async () => { + return { content: [{ type: 'text', text: 'ok' }] }; + }); + + localTransport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { + name: 'my-tool', + arguments: {}, + task: { ttl: 30000, pollInterval: 500 } + } + }); + + await new Promise(resolve => setTimeout(resolve, 20)); + + expect(assertTaskHandlerCapability).toHaveBeenCalledOnce(); + expect(assertTaskHandlerCapability).toHaveBeenCalledWith('tools/call'); + }); + + it('should not invoke assertTaskHandlerCapability for non-task-augmented requests', async () => { + const assertTaskHandlerCapability = vi.fn(); + const localProtocol = createTestProtocol({ + taskStore: createMockTaskStore(), + assertTaskHandlerCapability + }); + const localTransport = new MockTransport(); + await localProtocol.connect(localTransport); + + localProtocol.setRequestHandler('tools/call', async () => { + return { content: [{ type: 'text', text: 'ok' }] }; + }); + + localTransport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tools/call', + params: { name: 'my-tool', arguments: {} } + }); + + await new Promise(resolve => setTimeout(resolve, 20)); + + expect(assertTaskHandlerCapability).not.toHaveBeenCalled(); + }); + + it('should not throw when assertTaskHandlerCapability is not provided', async () => { + // No assertTaskHandlerCapability callback — protocol must not error + const localProtocol = createTestProtocol({ taskStore: createMockTaskStore() }); + const localTransport = new MockTransport(); + const localSendSpy = vi.spyOn(localTransport, 'send'); + await localProtocol.connect(localTransport); + + localProtocol.setRequestHandler('tools/call', async () => { + return { content: [{ type: 'text', text: 'ok' }] }; + }); + + localTransport.onmessage?.({ + jsonrpc: '2.0', + id: 3, + method: 'tools/call', + params: { + name: 'my-tool', + arguments: {}, + task: { ttl: 30000, pollInterval: 500 } + } + }); + + await new Promise(resolve => setTimeout(resolve, 20)); + + // The response should be a success, not an error + expect(localSendSpy).toHaveBeenCalledOnce(); + const response = localSendSpy.mock.calls[0]![0] as { error?: unknown }; + expect(response.error).toBeUndefined(); + }); + + it('should send a JSON-RPC error response when assertTaskHandlerCapability throws', async () => { + const assertTaskHandlerCapability = vi.fn(() => { + throw new Error('Task handler capability not declared'); + }); + const localProtocol = createTestProtocol({ + taskStore: createMockTaskStore(), + assertTaskHandlerCapability + }); + const localTransport = new MockTransport(); + const sendSpy = vi.spyOn(localTransport, 'send'); + await localProtocol.connect(localTransport); + + localProtocol.setRequestHandler('tools/call', async () => { + return { content: [{ type: 'text', text: 'ok' }] }; + }); + + localTransport.onmessage?.({ + jsonrpc: '2.0', + id: 4, + method: 'tools/call', + params: { + name: 'my-tool', + arguments: {}, + task: { ttl: 30000, pollInterval: 500 } + } + }); + + await new Promise(resolve => setTimeout(resolve, 20)); + + expect(assertTaskHandlerCapability).toHaveBeenCalledOnce(); + // Verify the error was sent back as a JSON-RPC error response (matching main's behavior) + expect(sendSpy).toHaveBeenCalledOnce(); + const response = sendSpy.mock.calls[0]![0] as { error?: { message?: string } }; + expect(response.error).toBeDefined(); + expect(response.error!.message).toBe('Task handler capability not declared'); + }); + }); + + describe('pollInterval fallback in _waitForTaskUpdate', () => { + it('should fall back to defaultTaskPollInterval when task has no pollInterval', async () => { + const mockTaskStore = createMockTaskStore(); + + const task = await mockTaskStore.createTask({ pollInterval: undefined as unknown as number }, 1, { + method: 'test/method', + params: {} + }); + // Override pollInterval to be undefined on the stored task + const storedTask = await mockTaskStore.getTask(task.taskId); + if (storedTask) { + storedTask.pollInterval = undefined as unknown as number; + } + + const localProtocol = createTestProtocol({ + taskStore: mockTaskStore, + defaultTaskPollInterval: 100 + }); + const localTransport = new MockTransport(); + const sendSpy = vi.spyOn(localTransport, 'send'); + await localProtocol.connect(localTransport); + + // Send tasks/result request — task is non-terminal so it will poll + localTransport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'tasks/result', + params: { taskId: task.taskId } + }); + + // Use a macrotask to complete the task AFTER the handler has entered polling + setTimeout(() => { + mockTaskStore.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: 'done' }] }); + }, 10); + + // At 50ms the 100ms poll hasn't fired yet + await new Promise(resolve => setTimeout(resolve, 50)); + expect(sendSpy).not.toHaveBeenCalled(); + + // At 200ms the poll should have fired and found the completed task + await new Promise(resolve => setTimeout(resolve, 150)); + expect(sendSpy).toHaveBeenCalled(); + }); + + it('should fall back to 1000ms when both pollInterval and defaultTaskPollInterval are absent', async () => { + const mockTaskStore = createMockTaskStore(); + + const task = await mockTaskStore.createTask({ pollInterval: undefined as unknown as number }, 1, { + method: 'test/method', + params: {} + }); + const storedTask = await mockTaskStore.getTask(task.taskId); + if (storedTask) { + storedTask.pollInterval = undefined as unknown as number; + } + + // No defaultTaskPollInterval — should fall back to 1000ms + const localProtocol = createTestProtocol({ + taskStore: mockTaskStore + }); + const localTransport = new MockTransport(); + const sendSpy = vi.spyOn(localTransport, 'send'); + await localProtocol.connect(localTransport); + + localTransport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'tasks/result', + params: { taskId: task.taskId } + }); + + // Complete the task via macrotask so the handler enters polling first + setTimeout(() => { + mockTaskStore.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: 'done' }] }); + }, 10); + + // At 500ms the 1000ms poll hasn't fired yet + await new Promise(resolve => setTimeout(resolve, 500)); + expect(sendSpy).not.toHaveBeenCalled(); + + // At 1100ms the poll should have fired + await new Promise(resolve => setTimeout(resolve, 600)); + expect(sendSpy).toHaveBeenCalled(); + }); + }); + describe('listTasks', () => { it('should handle tasks/list requests and return tasks from TaskStore', async () => { const listedTasks = createLatch(); @@ -1357,16 +1524,7 @@ describe('Task-based execution', () => { } ); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = createTestProtocol({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1421,16 +1579,7 @@ describe('Task-based execution', () => { } ); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = createTestProtocol({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1470,16 +1619,7 @@ describe('Task-based execution', () => { onList: () => listedTasks.releaseLatch() }); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = createTestProtocol({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1506,16 +1646,7 @@ describe('Task-based execution', () => { const mockTaskStore = createMockTaskStore(); mockTaskStore.listTasks.mockRejectedValue(new Error('Invalid cursor: bad-cursor')); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + protocol = createTestProtocol({ taskStore: mockTaskStore }); await protocol.connect(transport); @@ -1544,7 +1675,7 @@ describe('Task-based execution', () => { it('should call listTasks method from client side', async () => { await protocol.connect(transport); - const listTasksPromise = (protocol as unknown as TestProtocol).listTasks(); + const listTasksPromise = (protocol as unknown as TestProtocolInternals)._modules[0]!.listTasks(); // Simulate server response setTimeout(() => { @@ -1584,7 +1715,7 @@ describe('Task-based execution', () => { it('should call listTasks with cursor from client side', async () => { await protocol.connect(transport); - const listTasksPromise = (protocol as unknown as TestProtocol).listTasks({ cursor: 'task-10' }); + const listTasksPromise = (protocol as unknown as TestProtocolInternals)._modules[0]!.listTasks({ cursor: 'task-10' }); // Simulate server response setTimeout(() => { @@ -1643,16 +1774,7 @@ describe('Task-based execution', () => { throw new Error('Task not found'); }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1688,16 +1810,7 @@ describe('Task-based execution', () => { mockTaskStore.getTask.mockResolvedValue(null); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1739,16 +1852,7 @@ describe('Task-based execution', () => { mockTaskStore.updateTaskStatus.mockClear(); mockTaskStore.getTask.mockResolvedValue(completedTask); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1779,7 +1883,7 @@ describe('Task-based execution', () => { it('should call cancelTask method from client side', async () => { await protocol.connect(transport); - const deleteTaskPromise = (protocol as unknown as TestProtocol).cancelTask({ taskId: 'task-to-delete' }); + const deleteTaskPromise = (protocol as unknown as TestProtocolInternals)._modules[0]!.cancelTask({ taskId: 'task-to-delete' }); // Simulate server response - per MCP spec, CancelTaskResult is Result & Task setTimeout(() => { @@ -1824,16 +1928,7 @@ describe('Task-based execution', () => { params: {} }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); await serverProtocol.connect(serverTransport); @@ -1877,16 +1972,7 @@ describe('Task-based execution', () => { params: {} }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1929,16 +2015,7 @@ describe('Task-based execution', () => { params: {} }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -1969,16 +2046,7 @@ describe('Task-based execution', () => { params: {} }); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -2017,16 +2085,7 @@ describe('Task-based execution', () => { await mockTaskStore.storeTaskResult(task.taskId, 'completed', testResult); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -2063,16 +2122,7 @@ describe('Task-based execution', () => { it('should propagate related-task metadata to handler sendRequest and sendNotification', async () => { const mockTaskStore = createMockTaskStore(); - const serverProtocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const serverProtocol = createTestProtocol({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); const serverTransport = new MockTransport(); const sendSpy = vi.spyOn(serverTransport, 'send'); @@ -2124,7 +2174,7 @@ describe('Task-based execution', () => { // Verify the notification was QUEUED (not sent via transport) // Messages with relatedTask metadata should be queued for delivery via tasks/result // to prevent duplicate delivery for bidirectional transports - const queue = (serverProtocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (serverProtocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('parent-task-123'); @@ -2149,16 +2199,7 @@ describe('Request Cancellation vs Task Cancellation', () => { beforeEach(() => { transport = new MockTransport(); taskStore = createMockTaskStore(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + protocol = createTestProtocol({ taskStore }); }); describe('notifications/cancelled behavior', () => { @@ -2437,30 +2478,12 @@ describe('Progress notification support for tasks', () => { beforeEach(() => { transport = new MockTransport(); sendSpy = vi.spyOn(transport, 'send'); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })(); + protocol = createTestProtocol({ taskStore: createMockTaskStore() }); }); it('should maintain progress token association after CreateTaskResult is returned', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = createTestProtocol({ taskStore }); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2545,16 +2568,7 @@ describe('Progress notification support for tasks', () => { it('should stop progress notifications when task reaches terminal status (completed)', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = createTestProtocol({ taskStore }); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2644,15 +2658,22 @@ describe('Progress notification support for tasks', () => { expect(progressCallback).toHaveBeenCalledTimes(1); // Verify the task-progress association was created - const taskProgressTokens = (protocol as unknown as TestProtocol)._taskProgressTokens as Map; + const taskProgressTokens = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskProgressTokens as Map; expect(taskProgressTokens.has(taskId)).toBe(true); expect(taskProgressTokens.get(taskId)).toBe(progressToken); - // Simulate task completion by calling through the protocol's task store - // This will trigger the cleanup logic - const mockRequest = { jsonrpc: '2.0' as const, id: 999, method: 'test', params: {} }; - const requestTaskStore = (protocol as unknown as TestProtocol).requestTaskStore(mockRequest, undefined); - await requestTaskStore.storeTaskResult(taskId, 'completed', { content: [] }); + // Simulate task completion by triggering an inbound request whose handler + // calls storeTaskResult through the task context (the public RequestTaskStore API). + // This is equivalent to how a real server handler would complete a task. + protocol.setRequestHandler('ping', async (_request, ctx) => { + if (ctx.task?.store) { + await ctx.task.store.storeTaskResult(taskId, 'completed', { content: [] }); + } + return {}; + }); + if (transport.onmessage) { + transport.onmessage({ jsonrpc: '2.0', id: 999, method: 'ping', params: {} }); + } // Wait for all async operations including notification sending to complete await new Promise(resolve => setTimeout(resolve, 50)); @@ -2682,16 +2703,7 @@ describe('Progress notification support for tasks', () => { it('should stop progress notifications when task reaches terminal status (failed)', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = createTestProtocol({ taskStore }); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2783,16 +2795,7 @@ describe('Progress notification support for tasks', () => { it('should stop progress notifications when task is cancelled', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = createTestProtocol({ taskStore }); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -2881,16 +2884,7 @@ describe('Progress notification support for tasks', () => { it('should use the same progressToken throughout task lifetime', async () => { const taskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore }); + const protocol = createTestProtocol({ taskStore }); const transport = new MockTransport(); const sendSpy = vi.spyOn(transport, 'send'); @@ -3153,16 +3147,7 @@ describe('Message interception for task-related notifications', () => { it('should queue notifications with io.modelcontextprotocol/related-task metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3181,7 +3166,7 @@ describe('Message interception for task-related notifications', () => { ); // Access the private queue to verify the message was queued - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(task.taskId); @@ -3193,16 +3178,7 @@ describe('Message interception for task-related notifications', () => { it('should not queue notifications without related-task metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3223,16 +3199,7 @@ describe('Message interception for task-related notifications', () => { it('should propagate queue overflow errors without failing the task', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); + const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); await server.connect(transport); @@ -3273,16 +3240,7 @@ describe('Message interception for task-related notifications', () => { it('should extract task ID correctly from metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3300,7 +3258,7 @@ describe('Message interception for task-related notifications', () => { ); // Verify the message was queued under the correct task ID - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(taskId); expect(queuedMessage).toBeDefined(); @@ -3309,16 +3267,7 @@ describe('Message interception for task-related notifications', () => { it('should preserve message order when queuing multiple notifications', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3339,7 +3288,7 @@ describe('Message interception for task-related notifications', () => { } // Verify messages are in FIFO order - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); for (let i = 0; i < 5; i++) { @@ -3354,16 +3303,7 @@ describe('Message interception for task-related requests', () => { it('should queue requests with io.modelcontextprotocol/related-task metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3384,7 +3324,7 @@ describe('Message interception for task-related requests', () => { ); // Access the private queue to verify the message was queued - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(task.taskId); @@ -3394,7 +3334,7 @@ describe('Message interception for task-related requests', () => { // Verify resolver is stored in _requestResolvers map (not in the message) const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; - const resolvers = (server as unknown as TestProtocol)._requestResolvers; + const resolvers = (server as unknown as TestProtocolInternals)._modules[0]!._requestResolvers; expect(resolvers.has(requestId)).toBe(true); // Clean up - send a response to prevent hanging promise @@ -3410,16 +3350,7 @@ describe('Message interception for task-related requests', () => { it('should not queue requests without related-task metadata', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3434,7 +3365,7 @@ describe('Message interception for task-related requests', () => { ); // Verify queue exists (but we don't track size in the new API) - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Clean up - send a response @@ -3453,16 +3384,7 @@ describe('Message interception for task-related requests', () => { it('should store request resolver for response routing', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); await server.connect(transport); @@ -3483,11 +3405,11 @@ describe('Message interception for task-related requests', () => { ); // Verify the resolver was stored - const resolvers = (server as unknown as TestProtocol)._requestResolvers; + const resolvers = (server as unknown as TestProtocolInternals)._modules[0]!._requestResolvers; expect(resolvers.size).toBe(1); // Get the request ID from the queue - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; const queuedMessage = await queue!.dequeue(task.taskId); const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; @@ -3510,16 +3432,7 @@ describe('Message interception for task-related requests', () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); const queue = new InMemoryTaskMessageQueue(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: queue }); + const server = createTestProtocol({ taskStore, taskMessageQueue: queue }); await server.connect(transport); @@ -3579,16 +3492,7 @@ describe('Message interception for task-related requests', () => { it('should log error when resolver is missing for side-channeled request', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); const errors: Error[] = []; server.onerror = (error: Error) => { @@ -3614,12 +3518,12 @@ describe('Message interception for task-related requests', () => { ); // Get the request ID from the queue - const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queue = (server as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; const queuedMessage = await queue!.dequeue(task.taskId); const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; // Manually delete the resolver to simulate missing resolver - (server as unknown as TestProtocol)._requestResolvers.delete(requestId); + (server as unknown as TestProtocolInternals)._modules[0]!._requestResolvers.delete(requestId); // Enqueue a response message - this should trigger the error logging when processed await queue!.enqueue(task.taskId, { @@ -3659,16 +3563,7 @@ describe('Message interception for task-related requests', () => { it('should propagate queue overflow errors for requests without failing the task', async () => { const taskStore = createMockTaskStore(); const transport = new MockTransport(); - const server = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(_method: string): void {} - })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); + const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); await server.connect(transport); @@ -3723,16 +3618,7 @@ describe('Message Interception', () => { beforeEach(() => { transport = new MockTransport(); mockTaskStore = createMockTaskStore(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + protocol = createTestProtocol({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); }); describe('messages with relatedTask metadata are queued', () => { @@ -3753,7 +3639,7 @@ describe('Message Interception', () => { ); // Access the private _taskMessageQueue to verify the message was queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('task-123'); @@ -3782,7 +3668,7 @@ describe('Message Interception', () => { ); // Access the private _taskMessageQueue to verify the message was queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue('task-456'); @@ -3791,7 +3677,7 @@ describe('Message Interception', () => { // Verify resolver is stored in _requestResolvers map (not in the message) const requestId = queuedMessage.message.id as RequestId; - const resolvers = (protocol as unknown as TestProtocol)._requestResolvers; + const resolvers = (protocol as unknown as TestProtocolInternals)._modules[0]!._requestResolvers; expect(resolvers.has(requestId)).toBe(true); // Clean up the pending request @@ -3831,7 +3717,7 @@ describe('Message Interception', () => { await new Promise(resolve => setTimeout(resolve, 50)); // Verify the response was queued instead of sent directly - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(taskId); @@ -3869,7 +3755,7 @@ describe('Message Interception', () => { await new Promise(resolve => setTimeout(resolve, 50)); // Verify the error was queued instead of sent directly - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(taskId); @@ -3903,7 +3789,7 @@ describe('Message Interception', () => { await new Promise(resolve => setTimeout(resolve, 50)); // Verify the error was queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); const queuedMessage = await queue!.dequeue(taskId); @@ -3960,7 +3846,7 @@ describe('Message Interception', () => { // Access the private _taskMessageQueue to verify no messages were queued // Since we can't check if queues exist without messages, we verify that // attempting to dequeue returns undefined (no messages queued) - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); }); @@ -3983,7 +3869,7 @@ describe('Message Interception', () => { // Access the private _taskMessageQueue to verify no messages were queued // Since we can't check if queues exist without messages, we verify that // attempting to dequeue returns undefined (no messages queued) - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Clean up the pending request @@ -4017,7 +3903,7 @@ describe('Message Interception', () => { ); // Verify the message was queued under the correct task ID - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Verify a message was queued for this task @@ -4048,7 +3934,7 @@ describe('Message Interception', () => { ); // Verify the message was queued under the correct task ID - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Clean up the pending request @@ -4072,7 +3958,7 @@ describe('Message Interception', () => { await protocol.notification({ method: 'test3', params: {} }, { relatedTask: { taskId: 'task-A' } }); // Verify messages are queued under correct task IDs - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Verify two messages for task-A @@ -4095,7 +3981,7 @@ describe('Message Interception', () => { it('should queue messages for a task', async () => { await protocol.connect(transport); - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Send first message for a task @@ -4110,7 +3996,7 @@ describe('Message Interception', () => { it('should queue multiple messages for the same task', async () => { await protocol.connect(transport); - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Send first message @@ -4131,7 +4017,7 @@ describe('Message Interception', () => { it('should queue messages for different tasks separately', async () => { await protocol.connect(transport); - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Send messages for different tasks @@ -4162,7 +4048,7 @@ describe('Message Interception', () => { { relatedTask } ); - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; const queuedMessage = await queue!.dequeue('task-meta-123'); // Verify the metadata is preserved in the queued message @@ -4188,7 +4074,7 @@ describe('Message Interception', () => { { relatedTask } ); - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; const queuedMessage = await queue!.dequeue('task-meta-456'); // Verify the metadata is preserved in the queued message @@ -4225,7 +4111,7 @@ describe('Message Interception', () => { } ); - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; const queuedMessage = await queue!.dequeue('task-preserve-meta'); // Verify both existing and new metadata are preserved @@ -4248,16 +4134,7 @@ describe('Queue lifecycle management', () => { beforeEach(() => { transport = new MockTransport(); mockTaskStore = createMockTaskStore(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + protocol = createTestProtocol({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); }); describe('queue cleanup on task completion', () => { @@ -4273,7 +4150,7 @@ describe('Queue lifecycle management', () => { await protocol.notification({ method: 'test/notification', params: { data: 'test2' } }, { relatedTask: { taskId } }); // Verify messages are queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Verify messages can be dequeued @@ -4283,7 +4160,7 @@ describe('Queue lifecycle management', () => { expect(msg2).toBeDefined(); // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + (protocol as unknown as TestProtocolInternals)._modules[0]!._clearTaskQueue(taskId); // After cleanup, no more messages should be available const msg3 = await queue!.dequeue(taskId); @@ -4319,7 +4196,7 @@ describe('Queue lifecycle management', () => { await resultPromise; // Verify queue is cleared after delivery (no messages available) - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; const msg = await queue!.dequeue(taskId); expect(msg).toBeUndefined(); }); @@ -4337,7 +4214,7 @@ describe('Queue lifecycle management', () => { await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); // Verify message is queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; const msg1 = await queue!.dequeue(taskId); expect(msg1).toBeDefined(); @@ -4381,7 +4258,7 @@ describe('Queue lifecycle management', () => { ).catch(err => err); // Verify request is queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Mock task as non-terminal @@ -4422,7 +4299,7 @@ describe('Queue lifecycle management', () => { await protocol.notification({ method: 'test/notification', params: { data: 'test2' } }, { relatedTask: { taskId } }); // Verify messages are queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Verify messages can be dequeued @@ -4432,7 +4309,7 @@ describe('Queue lifecycle management', () => { expect(msg2).toBeDefined(); // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + (protocol as unknown as TestProtocolInternals)._modules[0]!._clearTaskQueue(taskId); // After cleanup, no more messages should be available const msg3 = await queue!.dequeue(taskId); @@ -4457,11 +4334,11 @@ describe('Queue lifecycle management', () => { ).catch(err => err); // Verify request is queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + (protocol as unknown as TestProtocolInternals)._modules[0]!._clearTaskQueue(taskId); // Verify the request promise is rejected const result = (await requestPromise) as Error; @@ -4511,11 +4388,11 @@ describe('Queue lifecycle management', () => { ).catch(err => err); // Verify requests are queued - const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queue = (protocol as unknown as TestProtocolInternals)._modules[0]!._taskMessageQueue; expect(queue).toBeDefined(); // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + (protocol as unknown as TestProtocolInternals)._modules[0]!._clearTaskQueue(taskId); // Verify all request promises are rejected const result1 = (await request1Promise) as Error; @@ -4552,7 +4429,7 @@ describe('Queue lifecycle management', () => { ).catch(err => err); // Get the request ID that was sent - const requestResolvers = (protocol as unknown as TestProtocol)._requestResolvers; + const requestResolvers = (protocol as unknown as TestProtocolInternals)._modules[0]!._requestResolvers; const initialResolverCount = requestResolvers.size; expect(initialResolverCount).toBeGreaterThan(0); @@ -4561,7 +4438,7 @@ describe('Queue lifecycle management', () => { mockTaskStore.getTask.mockResolvedValue(completedTask); // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + (protocol as unknown as TestProtocolInternals)._modules[0]!._clearTaskQueue(taskId); // Verify request promise is rejected const result = (await requestPromise) as Error; @@ -4583,22 +4460,13 @@ describe('requestStream() method', () => { test('should yield result immediately for non-task requests', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = createTestProtocol({}); await protocol.connect(transport); // Start the request stream const streamPromise = (async () => { const messages = []; - const stream = (protocol as unknown as TestProtocol)._requestStreamWithSchema( + const stream = (protocol as unknown as TestProtocolInternals)._modules[0]!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ); @@ -4629,22 +4497,13 @@ describe('requestStream() method', () => { test('should yield error message on request failure', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = createTestProtocol({}); await protocol.connect(transport); // Start the request stream const streamPromise = (async () => { const messages = []; - const stream = (protocol as unknown as TestProtocol)._requestStreamWithSchema( + const stream = (protocol as unknown as TestProtocolInternals)._modules[0]!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ); @@ -4678,16 +4537,7 @@ describe('requestStream() method', () => { test('should handle cancellation via AbortSignal', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = createTestProtocol({}); await protocol.connect(transport); const abortController = new AbortController(); @@ -4697,7 +4547,7 @@ describe('requestStream() method', () => { // Start the request stream with already-aborted signal const messages = []; - const stream = (protocol as unknown as TestProtocol)._requestStreamWithSchema( + const stream = (protocol as unknown as TestProtocolInternals)._modules[0]!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { @@ -4719,20 +4569,11 @@ describe('requestStream() method', () => { describe('Error responses', () => { test('should yield error as terminal message for server error response', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = createTestProtocol({}); await protocol.connect(transport); const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocol)._requestStreamWithSchema( + (protocol as unknown as TestProtocolInternals)._modules[0]!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ) @@ -4764,20 +4605,11 @@ describe('requestStream() method', () => { vi.useFakeTimers(); try { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = createTestProtocol({}); await protocol.connect(transport); const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocol)._requestStreamWithSchema( + (protocol as unknown as TestProtocolInternals)._modules[0]!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { @@ -4806,16 +4638,7 @@ describe('requestStream() method', () => { test('should yield error as terminal message for cancellation', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = createTestProtocol({}); await protocol.connect(transport); const abortController = new AbortController(); @@ -4823,7 +4646,7 @@ describe('requestStream() method', () => { // Collect messages const messages = await toArrayAsync( - (protocol as unknown as TestProtocol)._requestStreamWithSchema( + (protocol as unknown as TestProtocolInternals)._modules[0]!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { @@ -4842,20 +4665,11 @@ describe('requestStream() method', () => { test('should not yield any messages after error message', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = createTestProtocol({}); await protocol.connect(transport); const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocol)._requestStreamWithSchema( + (protocol as unknown as TestProtocolInternals)._modules[0]!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ) @@ -4897,20 +4711,11 @@ describe('requestStream() method', () => { test('should yield error as terminal message for task failure', async () => { const transport = new MockTransport(); const mockTaskStore = createMockTaskStore(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })({ taskStore: mockTaskStore }); + const protocol = createTestProtocol({ taskStore: mockTaskStore }); await protocol.connect(transport); const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocol)._requestStreamWithSchema( + (protocol as unknown as TestProtocolInternals)._modules[0]!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ) @@ -4960,23 +4765,14 @@ describe('requestStream() method', () => { test('should yield error as terminal message for network error', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = createTestProtocol({}); await protocol.connect(transport); // Override send to simulate network error transport.send = vi.fn().mockRejectedValue(new Error('Network error')); const messages = await toArrayAsync( - (protocol as unknown as TestProtocol)._requestStreamWithSchema( + (protocol as unknown as TestProtocolInternals)._modules[0]!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ) @@ -4991,20 +4787,11 @@ describe('requestStream() method', () => { test('should ensure error is always the final message', async () => { const transport = new MockTransport(); - const protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(): void {} - })(); + const protocol = createTestProtocol({}); await protocol.connect(transport); const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocol)._requestStreamWithSchema( + (protocol as unknown as TestProtocolInternals)._modules[0]!.requestStream( { method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema ) @@ -5050,20 +4837,7 @@ describe('Error handling for missing resolvers', () => { taskMessageQueue = new InMemoryTaskMessageQueue(); errorHandler = vi.fn(); - protocol = new (class extends Protocol { - protected assertCapabilityForMethod(_method: string): void {} - protected assertNotificationCapability(_method: string): void {} - protected assertRequestHandlerCapability(_method: string): void {} - protected assertTaskCapability(_method: string): void {} - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertTaskHandlerCapability(_method: string): void {} - })({ - taskStore, - taskMessageQueue, - defaultTaskPollInterval: 100 - }); + protocol = createTestProtocol({ taskStore, taskMessageQueue, defaultTaskPollInterval: 100 }); // @ts-expect-error deliberately overriding error handler with mock protocol.onerror = errorHandler; @@ -5089,7 +4863,7 @@ describe('Error handling for missing resolvers', () => { }); // Set up the GetTaskPayloadRequest handler to process the message - const testProtocol = protocol as unknown as TestProtocol; + const testProtocol = protocol as unknown as TestProtocolInternals; // Simulate dequeuing and processing the response const queuedMessage = await taskMessageQueue.dequeue(task.taskId); @@ -5100,7 +4874,7 @@ describe('Error handling for missing resolvers', () => { if (queuedMessage && queuedMessage.type === 'response') { const responseMessage = queuedMessage.message as JSONRPCResultResponse; const requestId = responseMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol._modules[0]!._requestResolvers.get(requestId); if (!resolver) { // This simulates what happens in the actual handler @@ -5176,8 +4950,8 @@ describe('Error handling for missing resolvers', () => { }); // Clear the task queue (simulating cancellation) - const testProtocol = protocol as unknown as TestProtocol; - await testProtocol._clearTaskQueue(task.taskId); + const testProtocol = protocol as unknown as TestProtocolInternals; + await testProtocol._modules[0]!._clearTaskQueue(task.taskId); // Verify error was logged for missing resolver expect(errorHandler).toHaveBeenCalledWith( @@ -5197,8 +4971,8 @@ describe('Error handling for missing resolvers', () => { const resolverMock = vi.fn(); // Store a resolver - const testProtocol = protocol as unknown as TestProtocol; - testProtocol._requestResolvers.set(requestId, resolverMock); + const testProtocol = protocol as unknown as TestProtocolInternals; + testProtocol._modules[0]!._requestResolvers.set(requestId, resolverMock); // Enqueue a request await taskMessageQueue.enqueue(task.taskId, { @@ -5213,7 +4987,7 @@ describe('Error handling for missing resolvers', () => { }); // Clear the task queue - await testProtocol._clearTaskQueue(task.taskId); + await testProtocol._modules[0]!._clearTaskQueue(task.taskId); // Verify resolver was called with cancellation error expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); @@ -5224,7 +4998,7 @@ describe('Error handling for missing resolvers', () => { expect(calledError.message).toContain('Task cancelled or completed'); // Verify resolver was removed - expect(testProtocol._requestResolvers.has(requestId)).toBe(false); + expect(testProtocol._modules[0]!._requestResolvers.has(requestId)).toBe(false); }); it('should handle mixed messages during cleanup', async () => { @@ -5233,12 +5007,12 @@ describe('Error handling for missing resolvers', () => { // Create a task const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - const testProtocol = protocol as unknown as TestProtocol; + const testProtocol = protocol as unknown as TestProtocolInternals; // Enqueue multiple messages: request with resolver, request without, notification const requestId1 = 42; const resolverMock = vi.fn(); - testProtocol._requestResolvers.set(requestId1, resolverMock); + testProtocol._modules[0]!._requestResolvers.set(requestId1, resolverMock); await taskMessageQueue.enqueue(task.taskId, { type: 'request', @@ -5273,7 +5047,7 @@ describe('Error handling for missing resolvers', () => { }); // Clear the task queue - await testProtocol._clearTaskQueue(task.taskId); + await testProtocol._modules[0]!._clearTaskQueue(task.taskId); // Verify resolver was called for first request expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); @@ -5300,7 +5074,7 @@ describe('Error handling for missing resolvers', () => { it('should log error when response handler is missing for side-channeled request', async () => { await protocol.connect(transport); - const testProtocol = protocol as unknown as TestProtocol; + const testProtocol = protocol as unknown as TestProtocolInternals; const messageId = 123; // Create a response resolver without a corresponding response handler @@ -5351,10 +5125,10 @@ describe('Error handling for missing resolvers', () => { const processMessage = async () => { const msg = await taskMessageQueue.dequeue(task.taskId); if (msg && msg.type === 'response') { - const testProtocol = protocol as unknown as TestProtocol; + const testProtocol = protocol as unknown as TestProtocolInternals; const responseMessage = msg.message as JSONRPCResultResponse; const requestId = responseMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol._modules[0]!._requestResolvers.get(requestId); if (!resolver) { protocol.onerror?.(new Error(`Response handler missing for request ${requestId}`)); } @@ -5380,10 +5154,10 @@ describe('Error handling for missing resolvers', () => { timestamp: Date.now() }); - const testProtocol = protocol as unknown as TestProtocol; + const testProtocol = protocol as unknown as TestProtocolInternals; // This should not throw - await expect(testProtocol._clearTaskQueue(task.taskId)).resolves.not.toThrow(); + await expect(testProtocol._modules[0]!._clearTaskQueue(task.taskId)).resolves.not.toThrow(); }); }); @@ -5396,8 +5170,8 @@ describe('Error handling for missing resolvers', () => { const resolverMock = vi.fn(); // Store a resolver - const testProtocol = protocol as unknown as TestProtocol; - testProtocol._requestResolvers.set(requestId, resolverMock); + const testProtocol = protocol as unknown as TestProtocolInternals; + testProtocol._modules[0]!._requestResolvers.set(requestId, resolverMock); // Enqueue an error message await taskMessageQueue.enqueue(task.taskId, { @@ -5422,10 +5196,10 @@ describe('Error handling for missing resolvers', () => { if (queuedMessage && queuedMessage.type === 'error') { const errorMessage = queuedMessage.message as JSONRPCErrorResponse; const reqId = errorMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(reqId); + const resolver = testProtocol._modules[0]!._requestResolvers.get(reqId); if (resolver) { - testProtocol._requestResolvers.delete(reqId); + testProtocol._modules[0]!._requestResolvers.delete(reqId); const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } @@ -5438,7 +5212,7 @@ describe('Error handling for missing resolvers', () => { expect(calledError.message).toContain('Invalid request parameters'); // Verify resolver was removed from map - expect(testProtocol._requestResolvers.has(requestId)).toBe(false); + expect(testProtocol._modules[0]!._requestResolvers.has(requestId)).toBe(false); }); it('should log error for unknown request ID in error messages', async () => { @@ -5467,10 +5241,10 @@ describe('Error handling for missing resolvers', () => { // Manually trigger the error handling logic if (queuedMessage && queuedMessage.type === 'error') { - const testProtocol = protocol as unknown as TestProtocol; + const testProtocol = protocol as unknown as TestProtocolInternals; const errorMessage = queuedMessage.message as JSONRPCErrorResponse; const requestId = errorMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol._modules[0]!._requestResolvers.get(requestId); if (!resolver) { protocol.onerror?.(new Error(`Error handler missing for request ${requestId}`)); @@ -5493,8 +5267,8 @@ describe('Error handling for missing resolvers', () => { const resolverMock = vi.fn(); // Store a resolver - const testProtocol = protocol as unknown as TestProtocol; - testProtocol._requestResolvers.set(requestId, resolverMock); + const testProtocol = protocol as unknown as TestProtocolInternals; + testProtocol._modules[0]!._requestResolvers.set(requestId, resolverMock); // Enqueue an error message with data field await taskMessageQueue.enqueue(task.taskId, { @@ -5517,10 +5291,10 @@ describe('Error handling for missing resolvers', () => { if (queuedMessage && queuedMessage.type === 'error') { const errorMessage = queuedMessage.message as JSONRPCErrorResponse; const reqId = errorMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(reqId); + const resolver = testProtocol._modules[0]!._requestResolvers.get(reqId); if (resolver) { - testProtocol._requestResolvers.delete(reqId); + testProtocol._modules[0]!._requestResolvers.delete(reqId); const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } @@ -5556,10 +5330,10 @@ describe('Error handling for missing resolvers', () => { const processMessage = async () => { const msg = await taskMessageQueue.dequeue(task.taskId); if (msg && msg.type === 'error') { - const testProtocol = protocol as unknown as TestProtocol; + const testProtocol = protocol as unknown as TestProtocolInternals; const errorMessage = msg.message as JSONRPCErrorResponse; const requestId = errorMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol._modules[0]!._requestResolvers.get(requestId); if (!resolver) { protocol.onerror?.(new Error(`Error handler missing for request ${requestId}`)); } @@ -5575,16 +5349,16 @@ describe('Error handling for missing resolvers', () => { await protocol.connect(transport); const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - const testProtocol = protocol as unknown as TestProtocol; + const testProtocol = protocol as unknown as TestProtocolInternals; // Set up resolvers for multiple requests const resolver1 = vi.fn(); const resolver2 = vi.fn(); const resolver3 = vi.fn(); - testProtocol._requestResolvers.set(1, resolver1); - testProtocol._requestResolvers.set(2, resolver2); - testProtocol._requestResolvers.set(3, resolver3); + testProtocol._modules[0]!._requestResolvers.set(1, resolver1); + testProtocol._modules[0]!._requestResolvers.set(2, resolver2); + testProtocol._modules[0]!._requestResolvers.set(3, resolver3); // Enqueue mixed messages: response, error, response await taskMessageQueue.enqueue(task.taskId, { @@ -5626,17 +5400,17 @@ describe('Error handling for missing resolvers', () => { if (msg.type === 'response') { const responseMessage = msg.message as JSONRPCResultResponse; const requestId = responseMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol._modules[0]!._requestResolvers.get(requestId); if (resolver) { - testProtocol._requestResolvers.delete(requestId); + testProtocol._modules[0]!._requestResolvers.delete(requestId); resolver(responseMessage); } } else if (msg.type === 'error') { const errorMessage = msg.message as JSONRPCErrorResponse; const requestId = errorMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol._modules[0]!._requestResolvers.get(requestId); if (resolver) { - testProtocol._requestResolvers.delete(requestId); + testProtocol._modules[0]!._requestResolvers.delete(requestId); const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } @@ -5654,23 +5428,23 @@ describe('Error handling for missing resolvers', () => { expect(error.message).toContain('Request failed'); // Verify all resolvers were removed - expect(testProtocol._requestResolvers.size).toBe(0); + expect(testProtocol._modules[0]!._requestResolvers.size).toBe(0); }); it('should maintain FIFO order when processing responses and errors', async () => { await protocol.connect(transport); const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - const testProtocol = protocol as unknown as TestProtocol; + const testProtocol = protocol as unknown as TestProtocolInternals; const callOrder: number[] = []; const resolver1 = vi.fn(() => callOrder.push(1)); const resolver2 = vi.fn(() => callOrder.push(2)); const resolver3 = vi.fn(() => callOrder.push(3)); - testProtocol._requestResolvers.set(1, resolver1); - testProtocol._requestResolvers.set(2, resolver2); - testProtocol._requestResolvers.set(3, resolver3); + testProtocol._modules[0]!._requestResolvers.set(1, resolver1); + testProtocol._modules[0]!._requestResolvers.set(2, resolver2); + testProtocol._modules[0]!._requestResolvers.set(3, resolver3); // Enqueue in specific order await taskMessageQueue.enqueue(task.taskId, { @@ -5701,17 +5475,17 @@ describe('Error handling for missing resolvers', () => { if (msg.type === 'response') { const responseMessage = msg.message as JSONRPCResultResponse; const requestId = responseMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol._modules[0]!._requestResolvers.get(requestId); if (resolver) { - testProtocol._requestResolvers.delete(requestId); + testProtocol._modules[0]!._requestResolvers.delete(requestId); resolver(responseMessage); } } else if (msg.type === 'error') { const errorMessage = msg.message as JSONRPCErrorResponse; const requestId = errorMessage.id as RequestId; - const resolver = testProtocol._requestResolvers.get(requestId); + const resolver = testProtocol._modules[0]!._requestResolvers.get(requestId); if (resolver) { - testProtocol._requestResolvers.delete(requestId); + testProtocol._modules[0]!._requestResolvers.delete(requestId); const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); resolver(error); } @@ -5723,3 +5497,566 @@ describe('Error handling for missing resolvers', () => { }); }); }); + +// -- Helper: create a no-op mock module -- +function createMockModule( + overrides?: Partial +): ProtocolModule & { boundHost?: ProtocolModuleHost; onCloseCalled: boolean } { + const mock: ProtocolModule & { boundHost?: ProtocolModuleHost; onCloseCalled: boolean } = { + boundHost: undefined, + onCloseCalled: false, + bind(host: ProtocolModuleHost) { + mock.boundHost = host; + overrides?.bind?.(host); + }, + processInboundRequest(request: JSONRPCRequest, ctx: InboundContext): InboundResult { + if (overrides?.processInboundRequest) { + return overrides.processInboundRequest(request, ctx); + } + return { + sendNotification: (notification: Notification) => ctx.sendNotification(notification), + sendRequest: ctx.sendRequest, + routeResponse: async () => false, + hasTaskCreationParams: false + }; + }, + processOutboundRequest( + jsonrpcRequest: JSONRPCRequest, + options: import('../../src/shared/protocol.js').RequestOptions | undefined, + messageId: number, + responseHandler: (response: JSONRPCResultResponse | Error) => void, + onError: (error: unknown) => void + ): { queued: boolean } { + if (overrides?.processOutboundRequest) { + return overrides.processOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, onError); + } + return { queued: false }; + }, + processInboundResponse( + response: JSONRPCResponse | JSONRPCErrorResponse, + messageId: number + ): { consumed: boolean; preserveProgress: boolean } { + if (overrides?.processInboundResponse) { + return overrides.processInboundResponse(response, messageId); + } + return { consumed: false, preserveProgress: false }; + }, + async processOutboundNotification( + notification: Notification, + options?: import('../../src/shared/protocol.js').NotificationOptions + ): Promise<{ queued: boolean; jsonrpcNotification?: JSONRPCNotification }> { + if (overrides?.processOutboundNotification) { + return overrides.processOutboundNotification(notification, options); + } + return { queued: false, jsonrpcNotification: { ...notification, jsonrpc: '2.0' } }; + }, + onClose() { + mock.onCloseCalled = true; + overrides?.onClose?.(); + } + }; + return mock; +} + +describe('Protocol without modules', () => { + let protocol: TestProtocolImpl; + let transport: MockTransport; + let sendSpy: MockInstance; + + beforeEach(() => { + transport = new MockTransport(); + sendSpy = vi.spyOn(transport, 'send'); + protocol = createTestProtocol(); // no task options => no modules + }); + + test('request/response flow works normally without modules', async () => { + await protocol.connect(transport); + const mockSchema = z.object({ result: z.string() }); + + const requestPromise = testRequest(protocol, { method: 'example', params: {} }, mockSchema, { timeout: 5000 }); + + // Simulate response + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + result: { result: 'hello' } + }); + + const result = await requestPromise; + expect(result).toEqual({ result: 'hello' }); + }); + + test('notifications are sent with proper JSONRPC wrapping without modules', async () => { + await protocol.connect(transport); + + await protocol.notification({ method: 'notifications/cancelled', params: { requestId: '1', reason: 'test' } }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + method: 'notifications/cancelled', + params: { requestId: '1', reason: 'test' } + }), + undefined + ); + }); + + test('onClose does not error without modules', async () => { + await protocol.connect(transport); + await expect(protocol.close()).resolves.not.toThrow(); + }); + + test('inbound requests dispatch to handlers without modules', async () => { + const handler = vi.fn().mockResolvedValue({ content: 'ok' }); + protocol.setRequestHandler('ping', handler); + + await protocol.connect(transport); + transport.onmessage?.({ jsonrpc: '2.0', method: 'ping', id: 1 }); + + // Wait for async handler + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(handler).toHaveBeenCalled(); + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + id: 1, + result: { content: 'ok' } + }) + ); + }); +}); + +describe('ProtocolModule lifecycle', () => { + let protocol: TestProtocolImpl; + let transport: MockTransport; + + beforeEach(() => { + transport = new MockTransport(); + protocol = new TestProtocolImpl(); + }); + + test('bind() is called during registerModule()', () => { + const module = createMockModule(); + protocol.registerTestModule(module); + + expect(module.boundHost).toBeDefined(); + }); + + test('onClose() is called when transport closes', async () => { + const module = createMockModule(); + protocol.registerTestModule(module); + + await protocol.connect(transport); + await protocol.close(); + + expect(module.onCloseCalled).toBe(true); + }); + + test('onClose() is called for ALL registered modules', async () => { + const module1 = createMockModule(); + const module2 = createMockModule(); + protocol.registerTestModule(module1); + protocol.registerTestModule(module2); + + await protocol.connect(transport); + await protocol.close(); + + expect(module1.onCloseCalled).toBe(true); + expect(module2.onCloseCalled).toBe(true); + }); + + test('bind() is called for each module with its own host', () => { + const module1 = createMockModule(); + const module2 = createMockModule(); + protocol.registerTestModule(module1); + protocol.registerTestModule(module2); + + expect(module1.boundHost).toBeDefined(); + expect(module2.boundHost).toBeDefined(); + // Both hosts should be distinct objects + expect(module1.boundHost).not.toBe(module2.boundHost); + }); +}); + +describe('Multiple modules composition', () => { + let protocol: TestProtocolImpl; + let transport: MockTransport; + let sendSpy: MockInstance; + + beforeEach(() => { + transport = new MockTransport(); + sendSpy = vi.spyOn(transport, 'send'); + protocol = new TestProtocolImpl(); + }); + + test('both modules processInboundRequest are called', async () => { + const calls: string[] = []; + const module1 = createMockModule({ + processInboundRequest(request, ctx) { + calls.push('module1'); + return { + sendNotification: (notification: Notification) => ctx.sendNotification(notification), + sendRequest: ctx.sendRequest, + routeResponse: async () => false, + hasTaskCreationParams: false + }; + } + }); + const module2 = createMockModule({ + processInboundRequest(request, ctx) { + calls.push('module2'); + return { + sendNotification: (notification: Notification) => ctx.sendNotification(notification), + sendRequest: ctx.sendRequest, + routeResponse: async () => false, + hasTaskCreationParams: false + }; + } + }); + + protocol.registerTestModule(module1); + protocol.registerTestModule(module2); + protocol.setRequestHandler('ping', async () => ({})); + + await protocol.connect(transport); + transport.onmessage?.({ jsonrpc: '2.0', method: 'ping', id: 1 }); + + await new Promise(resolve => setTimeout(resolve, 10)); + expect(calls).toEqual(['module1', 'module2']); + }); + + test('processOutboundRequest: first module queues, second module not called', async () => { + const calls: string[] = []; + const module1 = createMockModule({ + processOutboundRequest() { + calls.push('module1'); + return { queued: true }; + } + }); + const module2 = createMockModule({ + processOutboundRequest() { + calls.push('module2'); + return { queued: false }; + } + }); + + protocol.registerTestModule(module1); + protocol.registerTestModule(module2); + + await protocol.connect(transport); + const mockSchema = z.object({ result: z.string() }); + + void testRequest(protocol, { method: 'example', params: {} }, mockSchema, { timeout: 5000 }).catch(() => {}); + + expect(calls).toEqual(['module1']); + // Transport send should NOT have been called since module1 queued + expect(sendSpy).not.toHaveBeenCalled(); + }); + + test('processInboundResponse: first module consumes, second module not called', async () => { + const calls: string[] = []; + const module1 = createMockModule({ + processInboundResponse() { + calls.push('module1'); + return { consumed: true, preserveProgress: false }; + } + }); + const module2 = createMockModule({ + processInboundResponse() { + calls.push('module2'); + return { consumed: false, preserveProgress: false }; + } + }); + + protocol.registerTestModule(module1); + protocol.registerTestModule(module2); + + await protocol.connect(transport); + + // Send a response - it should be consumed by module1 + transport.onmessage?.({ + jsonrpc: '2.0', + id: 999, + result: { data: 'test' } + }); + + expect(calls).toEqual(['module1']); + }); + + test('processInboundResponse: preserveProgress is ORed across non-consuming modules', async () => { + const module1 = createMockModule({ + processInboundResponse() { + return { consumed: false, preserveProgress: true }; + } + }); + const module2 = createMockModule({ + processInboundResponse() { + return { consumed: false, preserveProgress: false }; + } + }); + + protocol.registerTestModule(module1); + protocol.registerTestModule(module2); + + await protocol.connect(transport); + const mockSchema = z.object({ result: z.string() }); + const onProgress = vi.fn(); + + void testRequest(protocol, { method: 'example', params: {} }, mockSchema, { + timeout: 5000, + onprogress: onProgress + }).catch(() => {}); + + // Simulate response + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + result: { result: 'done' } + }); + + // Progress handler should still exist because module1 said preserveProgress=true + // Send a progress notification - it should still be handled + transport.onmessage?.({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken: 0, + progress: 100, + total: 100 + } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + expect(onProgress).toHaveBeenCalledWith({ progress: 100, total: 100 }); + }); + + test('processOutboundNotification: first module queues, second module not called', async () => { + const calls: string[] = []; + const module1 = createMockModule({ + async processOutboundNotification() { + calls.push('module1'); + return { queued: true }; + } + }); + const module2 = createMockModule({ + async processOutboundNotification() { + calls.push('module2'); + return { queued: false, jsonrpcNotification: undefined }; + } + }); + + protocol.registerTestModule(module1); + protocol.registerTestModule(module2); + + await protocol.connect(transport); + await protocol.notification({ method: 'notifications/cancelled', params: { requestId: '1', reason: 'test' } }); + + expect(calls).toEqual(['module1']); + // Transport send should NOT have been called since module1 queued + expect(sendSpy).not.toHaveBeenCalled(); + }); + + test('onClose: both modules onClose called', async () => { + const module1 = createMockModule(); + const module2 = createMockModule(); + + protocol.registerTestModule(module1); + protocol.registerTestModule(module2); + + await protocol.connect(transport); + await protocol.close(); + + expect(module1.onCloseCalled).toBe(true); + expect(module2.onCloseCalled).toBe(true); + }); + + test('sendNotification/sendRequest wrappers chain correctly through modules', async () => { + const wrappedCalls: string[] = []; + + const module1 = createMockModule({ + processInboundRequest(request, ctx) { + const wrappedSendNotification = async (notification: Notification) => { + wrappedCalls.push('module1-notify'); + await ctx.sendNotification(notification); + }; + return { + sendNotification: wrappedSendNotification, + sendRequest: ctx.sendRequest, + routeResponse: async () => false, + hasTaskCreationParams: false + }; + } + }); + const module2 = createMockModule({ + processInboundRequest(request, ctx) { + const wrappedSendNotification = async (notification: Notification) => { + wrappedCalls.push('module2-notify'); + await ctx.sendNotification(notification); + }; + return { + sendNotification: wrappedSendNotification, + sendRequest: ctx.sendRequest, + routeResponse: async () => false, + hasTaskCreationParams: false + }; + } + }); + + protocol.registerTestModule(module1); + protocol.registerTestModule(module2); + + // Register a handler that sends a notification + protocol.setRequestHandler('ping', async (_request, ctx) => { + await ctx.mcpReq.notify({ method: 'notifications/cancelled', params: { requestId: '1', reason: 'test' } }); + return {}; + }); + + await protocol.connect(transport); + transport.onmessage?.({ jsonrpc: '2.0', method: 'ping', id: 1 }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + // The last module's wrapper is what gets called by the handler + // module2's sendNotification is the one in ctx.mcpReq.notify + expect(wrappedCalls).toContain('module2-notify'); + }); + + test('routeResponse OR-chain: first module returning true wins', async () => { + const module1 = createMockModule({ + processInboundRequest(request, ctx) { + return { + sendNotification: (notification: Notification) => ctx.sendNotification(notification), + sendRequest: ctx.sendRequest, + routeResponse: async () => true, + hasTaskCreationParams: false + }; + } + }); + const routeResponseCalled = vi.fn().mockResolvedValue(false); + const module2 = createMockModule({ + processInboundRequest(request, ctx) { + return { + sendNotification: (notification: Notification) => ctx.sendNotification(notification), + sendRequest: ctx.sendRequest, + routeResponse: routeResponseCalled, + hasTaskCreationParams: false + }; + } + }); + + protocol.registerTestModule(module1); + protocol.registerTestModule(module2); + protocol.setRequestHandler('ping', async () => ({ content: 'ok' })); + + await protocol.connect(transport); + transport.onmessage?.({ jsonrpc: '2.0', method: 'ping', id: 1 }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + // module1 routes the response, so transport.send should NOT be called for the response + // (only ping handler internal sends may happen) + // module2's routeResponse should still be called because OR-chain tries previous first + // Actually: the OR-chain is: prevRouteResponse(msg) || moduleRouteResponse(msg) + // For module1: prev is default (false), module1 returns true => true + // For module2: prev is module1's composed (true), so module2's routeResponse is NOT called + expect(routeResponseCalled).not.toHaveBeenCalled(); + }); + + test('hasTaskCreationParams is ORed across modules', async () => { + // We verify this indirectly through the taskContext behavior + // Module1 returns hasTaskCreationParams=false, module2 returns true + const module1 = createMockModule({ + processInboundRequest(request, ctx) { + return { + sendNotification: (notification: Notification) => ctx.sendNotification(notification), + sendRequest: ctx.sendRequest, + routeResponse: async () => false, + hasTaskCreationParams: false, + taskContext: undefined + }; + } + }); + const module2 = createMockModule({ + processInboundRequest(request, ctx) { + return { + sendNotification: (notification: Notification) => ctx.sendNotification(notification), + sendRequest: ctx.sendRequest, + routeResponse: async () => false, + hasTaskCreationParams: true, + taskContext: { id: 'test-task', store: {} as any, requestedTtl: undefined } + }; + } + }); + + protocol.registerTestModule(module1); + protocol.registerTestModule(module2); + + let receivedCtx: BaseContext | undefined; + protocol.setRequestHandler('ping', async (_request, ctx) => { + receivedCtx = ctx; + return {}; + }); + + await protocol.connect(transport); + transport.onmessage?.({ jsonrpc: '2.0', method: 'ping', id: 1 }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Last non-undefined taskContext wins (module2) + expect(receivedCtx?.task).toBeDefined(); + expect(receivedCtx?.task?.id).toBe('test-task'); + }); +}); + +describe('ExperimentalTasks throw when not configured', () => { + // These tests verify the pattern used by ExperimentalClientTasks and ExperimentalServerTasks: + // accessing taskModule when it's undefined should throw a clear error. + + test('tasks accessor throws when taskModule is undefined (client pattern)', () => { + const mockClient = { taskModule: undefined } as any; + const tasksAccessor = { + get _module() { + const module = mockClient.taskModule; + if (!module) { + throw new Error('Tasks capability is not configured. Declare tasks in capabilities to use task features.'); + } + return module; + } + }; + + expect(() => tasksAccessor._module).toThrow('Tasks capability is not configured'); + }); + + test('tasks accessor throws when taskModule is undefined (server pattern)', () => { + const mockServer = { taskModule: undefined } as any; + const tasksAccessor = { + get _module() { + const module = mockServer.taskModule; + if (!module) { + throw new Error('Tasks capability is not configured. Declare tasks in capabilities to use task features.'); + } + return module; + } + }; + + expect(() => tasksAccessor._module).toThrow('Tasks capability is not configured'); + }); + + test('tasks accessor succeeds when taskModule is defined', () => { + const mockTaskModule = { getTask: vi.fn() }; + const mockClient = { taskModule: mockTaskModule } as any; + const tasksAccessor = { + get _module() { + const module = mockClient.taskModule; + if (!module) { + throw new Error('Tasks capability is not configured. Declare tasks in capabilities to use task features.'); + } + return module; + } + }; + + expect(() => tasksAccessor._module).not.toThrow(); + expect(tasksAccessor._module).toBe(mockTaskModule); + }); +}); diff --git a/packages/core/test/shared/protocolTransportHandling.test.ts b/packages/core/test/shared/protocolTransportHandling.test.ts index adc7e2234..c8f4c4b33 100644 --- a/packages/core/test/shared/protocolTransportHandling.test.ts +++ b/packages/core/test/shared/protocolTransportHandling.test.ts @@ -38,11 +38,9 @@ describe('Protocol transport handling bug', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } - protected assertTaskHandlerCapability(): void {} })(); transportA = new MockTransport('A'); diff --git a/packages/server/src/experimental/tasks/server.ts b/packages/server/src/experimental/tasks/server.ts index ec07dfaec..a826fac35 100644 --- a/packages/server/src/experimental/tasks/server.ts +++ b/packages/server/src/experimental/tasks/server.ts @@ -6,6 +6,7 @@ */ import type { + AnyObjectSchema, AnySchema, CancelTaskResult, CreateMessageRequestParams, @@ -15,12 +16,14 @@ import type { ElicitResult, GetTaskResult, ListTasksResult, + Request, RequestMethod, RequestOptions, ResponseMessage, ResultTypeMap, SchemaOutput } from '@modelcontextprotocol/core'; +import { getResultSchema } from '@modelcontextprotocol/core'; import type { Server } from '../../server/server.js'; @@ -39,6 +42,14 @@ import type { Server } from '../../server/server.js'; export class ExperimentalServerTasks { constructor(private readonly _server: Server) {} + private get _module() { + const module = this._server.taskModule; + if (!module) { + throw new Error('Tasks capability is not configured. Declare tasks in capabilities to use task features.'); + } + return module; + } + /** * Sends a request and returns an AsyncGenerator that yields response messages. * The generator is guaranteed to end with either a `'result'` or `'error'` message. @@ -56,14 +67,12 @@ export class ExperimentalServerTasks { request: { method: M; params?: Record }, options?: RequestOptions ): AsyncGenerator, void, void> { - // Delegate to the server's underlying Protocol method - type ServerWithRequestStream = { - requestStream( - request: { method: N; params?: Record }, - options?: RequestOptions - ): AsyncGenerator, void, void>; - }; - return (this._server as unknown as ServerWithRequestStream).requestStream(request, options); + const resultSchema = getResultSchema(request.method) as unknown as AnyObjectSchema; + return this._module.requestStream(request as Request, resultSchema, options) as AsyncGenerator< + ResponseMessage, + void, + void + >; } /** @@ -250,8 +259,7 @@ export class ExperimentalServerTasks { * @experimental */ async getTask(taskId: string, options?: RequestOptions): Promise { - type ServerWithGetTask = { getTask(params: { taskId: string }, options?: RequestOptions): Promise }; - return (this._server as unknown as ServerWithGetTask).getTask({ taskId }, options); + return this._module.getTask({ taskId }, options); } /** @@ -265,15 +273,7 @@ export class ExperimentalServerTasks { * @experimental */ async getTaskResult(taskId: string, resultSchema?: T, options?: RequestOptions): Promise> { - return ( - this._server as unknown as { - getTaskResult: ( - params: { taskId: string }, - resultSchema?: U, - options?: RequestOptions - ) => Promise>; - } - ).getTaskResult({ taskId }, resultSchema, options); + return this._module.getTaskResult({ taskId }, resultSchema!, options); } /** @@ -286,11 +286,7 @@ export class ExperimentalServerTasks { * @experimental */ async listTasks(cursor?: string, options?: RequestOptions): Promise { - return ( - this._server as unknown as { - listTasks: (params?: { cursor?: string }, options?: RequestOptions) => Promise; - } - ).listTasks(cursor ? { cursor } : undefined, options); + return this._module.listTasks(cursor ? { cursor } : undefined, options); } /** @@ -302,10 +298,6 @@ export class ExperimentalServerTasks { * @experimental */ async cancelTask(taskId: string, options?: RequestOptions): Promise { - return ( - this._server as unknown as { - cancelTask: (params: { taskId: string }, options?: RequestOptions) => Promise; - } - ).cancelTask({ taskId }, options); + return this._module.cancelTask({ taskId }, options); } } diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index 00d3e6f52..c0663adb2 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -29,6 +29,7 @@ import type { ServerCapabilities, ServerContext, ServerResult, + TaskManagerOptions, ToolResultContent, ToolUseContent } from '@modelcontextprotocol/core'; @@ -51,17 +52,27 @@ import { ProtocolError, ProtocolErrorCode, SdkError, - SdkErrorCode + SdkErrorCode, + TaskManager } from '@modelcontextprotocol/core'; import { DefaultJsonSchemaValidator } from '@modelcontextprotocol/server/_shims'; import { ExperimentalServerTasks } from '../experimental/tasks/server.js'; +/** + * Extended tasks capability that includes runtime configuration (store, messageQueue). + * The runtime-only fields are stripped before advertising capabilities to clients. + */ +export type ServerTasksCapabilityWithRuntime = NonNullable & + Pick; + export type ServerOptions = ProtocolOptions & { /** * Capabilities to advertise as being supported by this server. */ - capabilities?: ServerCapabilities; + capabilities?: Omit & { + tasks?: ServerTasksCapabilityWithRuntime; + }; /** * Optional instructions describing how to use the server and its features. @@ -93,6 +104,7 @@ export class Server extends Protocol { private _instructions?: string; private _jsonSchemaValidator: jsonSchemaValidator; private _experimental?: { tasks: ExperimentalServerTasks }; + private _taskModule?: TaskManager; /** * Callback for when initialization has fully completed (i.e., the client has sent an `notifications/initialized` notification). @@ -107,10 +119,26 @@ export class Server extends Protocol { options?: ServerOptions ) { super(options); - this._capabilities = options?.capabilities ?? {}; + this._capabilities = options?.capabilities ? { ...options.capabilities } : {}; this._instructions = options?.instructions; this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new DefaultJsonSchemaValidator(); + // If tasks capability is declared, create and register the task module + if (options?.capabilities?.tasks) { + const { taskStore, taskMessageQueue, ...wireCapabilities } = options.capabilities.tasks; + // Strip runtime-only config from advertised capabilities + this._capabilities.tasks = wireCapabilities; + this._taskModule = new TaskManager({ + taskStore, + taskMessageQueue, + enforceStrictCapabilities: options?.enforceStrictCapabilities, + assertTaskCapability: method => + assertClientRequestTaskCapability(this._clientCapabilities?.tasks?.requests, method, 'Client'), + assertTaskHandlerCapability: method => assertToolsCallTaskCapability(this._capabilities.tasks?.requests, method, 'Server') + }); + this.registerModule(this._taskModule); + } + this.setRequestHandler('initialize', request => this._oninitialize(request)); this.setNotificationHandler('notifications/initialized', () => this.oninitialized?.()); @@ -119,6 +147,13 @@ export class Server extends Protocol { } } + /** + * Access the task module, if tasks capability is configured. + */ + get taskModule(): TaskManager | undefined { + return this._taskModule; + } + private _registerLoggingHandler(): void { this.setRequestHandler('logging/setLevel', async (request, ctx) => { const transportSessionId: string | undefined = @@ -347,12 +382,6 @@ export class Server extends Protocol { } protected assertRequestHandlerCapability(method: string): void { - // Task handlers are registered in Protocol constructor before _capabilities is initialized - // Skip capability check for task methods during initialization - if (!this._capabilities) { - return; - } - switch (method) { case 'completion/complete': { if (!this._capabilities.completions) { @@ -393,19 +422,6 @@ export class Server extends Protocol { break; } - case 'tasks/get': - case 'tasks/list': - case 'tasks/result': - case 'tasks/cancel': { - if (!this._capabilities.tasks) { - throw new SdkError( - SdkErrorCode.CapabilityNotSupported, - `Server does not support tasks capability (required for ${method})` - ); - } - break; - } - case 'ping': case 'initialize': { // No specific capability required for these methods @@ -414,20 +430,6 @@ export class Server extends Protocol { } } - protected assertTaskCapability(method: string): void { - assertClientRequestTaskCapability(this._clientCapabilities?.tasks?.requests, method, 'Client'); - } - - protected assertTaskHandlerCapability(method: string): void { - // Task handlers are registered in Protocol constructor before _capabilities is initialized - // Skip capability check for task methods during initialization - if (!this._capabilities) { - return; - } - - assertToolsCallTaskCapability(this._capabilities.tasks?.requests, method, 'Server'); - } - private async _oninitialize(request: InitializeRequest): Promise { const requestedVersion = request.params.protocolVersion; diff --git a/test/integration/test/client/client.test.ts b/test/integration/test/client/client.test.ts index 948d16e17..1e742b1bb 100644 --- a/test/integration/test/client/client.test.ts +++ b/test/integration/test/client/client.test.ts @@ -2206,10 +2206,13 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { capabilities: { tasks: { requests: { tools: { call: {} } } } } } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -2250,10 +2253,11 @@ describe('Task-based execution', () => { tools: { call: {} } - } + }, + + taskStore: serverTaskStore } - }, - taskStore: serverTaskStore + } } ); @@ -2290,10 +2294,13 @@ describe('Task-based execution', () => { } ); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { capabilities: { tasks: { requests: { tools: { call: {} } } } } } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -2329,10 +2336,11 @@ describe('Task-based execution', () => { tools: { call: {} } - } + }, + + taskStore: serverTaskStore } - }, - taskStore: serverTaskStore + } } ); @@ -2369,10 +2377,13 @@ describe('Task-based execution', () => { } ); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { capabilities: { tasks: { requests: { tools: { call: {} } } } } } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -2409,10 +2420,11 @@ describe('Task-based execution', () => { call: {}, list: {} } - } + }, + + taskStore: serverTaskStore } - }, - taskStore: serverTaskStore + } } ); @@ -2449,10 +2461,13 @@ describe('Task-based execution', () => { } ); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { capabilities: { tasks: { requests: { tools: { call: {} } } } } } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -2493,10 +2508,11 @@ describe('Task-based execution', () => { tools: { call: {} } - } + }, + + taskStore: serverTaskStore } - }, - taskStore: serverTaskStore + } } ); @@ -2533,10 +2549,13 @@ describe('Task-based execution', () => { } ); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { capabilities: { tasks: { requests: { tools: { call: {} } } } } } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -2600,10 +2619,11 @@ describe('Task-based execution', () => { elicitation: { create: {} } - } + }, + + taskStore: clientTaskStore } - }, - taskStore: clientTaskStore + } } ); @@ -2692,10 +2712,11 @@ describe('Task-based execution', () => { elicitation: { create: {} } - } + }, + + taskStore: clientTaskStore } - }, - taskStore: clientTaskStore + } } ); @@ -2783,10 +2804,11 @@ describe('Task-based execution', () => { elicitation: { create: {} } - } + }, + + taskStore: clientTaskStore } - }, - taskStore: clientTaskStore + } } ); @@ -2873,10 +2895,11 @@ describe('Task-based execution', () => { elicitation: { create: {} } - } + }, + + taskStore: clientTaskStore } - }, - taskStore: clientTaskStore + } } ); @@ -2975,10 +2998,11 @@ describe('Task-based execution', () => { tools: { call: {} } - } + }, + + taskStore: serverTaskStore } - }, - taskStore: serverTaskStore + } } ); @@ -3100,10 +3124,11 @@ describe('Task-based execution', () => { tools: { call: {} } - } + }, + + taskStore: serverTaskStore } - }, - taskStore: serverTaskStore + } } ); @@ -3147,10 +3172,11 @@ describe('Task-based execution', () => { tools: { call: {} } - } + }, + + taskStore: serverTaskStore } - }, - taskStore: serverTaskStore + } } ); @@ -3194,10 +3220,11 @@ describe('Task-based execution', () => { elicitation: { create: {} } - } + }, + + taskStore: clientTaskStore } - }, - taskStore: clientTaskStore + } } ); @@ -3248,10 +3275,11 @@ test('should respect server task capabilities', async () => { tools: { call: {} } - } + }, + + taskStore: serverTaskStore } - }, - taskStore: serverTaskStore + } } ); @@ -3294,7 +3322,16 @@ test('should respect server task capabilities', async () => { version: '1.0.0' }, { - enforceStrictCapabilities: true + enforceStrictCapabilities: true, + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } } ); @@ -3367,7 +3404,7 @@ test('should expose requestStream() method for streaming responses', async () => version: '1.0.0' }, { - capabilities: {} + capabilities: { tasks: { requests: { tools: { call: {} } } } } } ); @@ -3429,7 +3466,7 @@ test('should expose callToolStream() method for streaming tool calls', async () version: '1.0.0' }, { - capabilities: {} + capabilities: { tasks: { requests: { tools: { call: {} } } } } } ); @@ -3507,7 +3544,7 @@ test('should validate structured output in callToolStream()', async () => { version: '1.0.0' }, { - capabilities: {} + capabilities: { tasks: { requests: { tools: { call: {} } } } } } ); @@ -3585,7 +3622,7 @@ test('callToolStream() should yield error when structuredContent does not match version: '1.0.0' }, { - capabilities: {} + capabilities: { tasks: { requests: { tools: { call: {} } } } } } ); @@ -3658,7 +3695,7 @@ test('callToolStream() should yield error when tool with outputSchema returns no version: '1.0.0' }, { - capabilities: {} + capabilities: { tasks: { requests: { tools: { call: {} } } } } } ); @@ -3723,7 +3760,7 @@ test('callToolStream() should handle tools without outputSchema normally', async version: '1.0.0' }, { - capabilities: {} + capabilities: { tasks: { requests: { tools: { call: {} } } } } } ); @@ -3818,7 +3855,7 @@ test('callToolStream() should handle complex JSON schema validation', async () = version: '1.0.0' }, { - capabilities: {} + capabilities: { tasks: { requests: { tools: { call: {} } } } } } ); @@ -3897,7 +3934,7 @@ test('callToolStream() should yield error with additional properties when not al version: '1.0.0' }, { - capabilities: {} + capabilities: { tasks: { requests: { tools: { call: {} } } } } } ); @@ -3971,7 +4008,7 @@ test('callToolStream() should not validate structuredContent when isError is tru version: '1.0.0' }, { - capabilities: {} + capabilities: { tasks: { requests: { tools: { call: {} } } } } } ); diff --git a/test/integration/test/helpers/mcp.ts b/test/integration/test/helpers/mcp.ts index 5c53c7a92..1fe0b3391 100644 --- a/test/integration/test/helpers/mcp.ts +++ b/test/integration/test/helpers/mcp.ts @@ -50,11 +50,11 @@ export async function createInMemoryTaskEnvironment(options?: { tools: { call: {} } - } + }, + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() } - }, - taskStore, - taskMessageQueue: new InMemoryTaskMessageQueue() + } } ); diff --git a/test/integration/test/server.test.ts b/test/integration/test/server.test.ts index fe7b4c187..4d1d51fc8 100644 --- a/test/integration/test/server.test.ts +++ b/test/integration/test/server.test.ts @@ -1875,182 +1875,195 @@ describe('createMessageStream', () => { }).toThrow('tool_result blocks are not matching any tool_use from the previous message'); }); - describe('terminal message guarantees', () => { - test('should yield exactly one terminal message for successful request', async () => { - const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); - const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); - - client.setRequestHandler('sampling/createMessage', async () => ({ - role: 'assistant', - content: { type: 'text', text: 'Response' }, - model: 'test-model' - })); + describe('with tasks', () => { + let server: Server; + let client: Client; + let clientTransport: ReturnType[0]; + let serverTransport: ReturnType[1]; + + beforeEach(async () => { + server = new Server( + { name: 'test server', version: '1.0' }, + { + capabilities: { + tasks: { + taskStore: new InMemoryTaskStore() + } + } + } + ); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); - const stream = server.experimental.tasks.createMessageStream({ - messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], - maxTokens: 100 - }); + [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + }); - const allMessages = await toArrayAsync(stream); + afterEach(async () => { + await server.close().catch(() => {}); + await client.close().catch(() => {}); + }); - expect(allMessages.length).toBe(1); - expect(allMessages[0].type).toBe('result'); + describe('terminal message guarantees', () => { + test('should yield exactly one terminal message for successful request', async () => { + client.setRequestHandler('sampling/createMessage', async () => ({ + role: 'assistant', + content: { type: 'text', text: 'Response' }, + model: 'test-model' + })); - const taskMessages = allMessages.filter(m => m.type === 'taskCreated' || m.type === 'taskStatus'); - expect(taskMessages.length).toBe(0); - }); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - test('should yield error as terminal message when client returns error', async () => { - const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); - const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); + const stream = server.experimental.tasks.createMessageStream({ + messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], + maxTokens: 100 + }); - client.setRequestHandler('sampling/createMessage', async () => { - throw new Error('Simulated client error'); - }); + const allMessages = await toArrayAsync(stream); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + expect(allMessages.length).toBe(1); + expect(allMessages[0].type).toBe('result'); - const stream = server.experimental.tasks.createMessageStream({ - messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], - maxTokens: 100 + const taskMessages = allMessages.filter(m => m.type === 'taskCreated' || m.type === 'taskStatus'); + expect(taskMessages.length).toBe(0); }); - const allMessages = await toArrayAsync(stream); - - expect(allMessages.length).toBe(1); - expect(allMessages[0].type).toBe('error'); - }); + test('should yield error as terminal message when client returns error', async () => { + client.setRequestHandler('sampling/createMessage', async () => { + throw new Error('Simulated client error'); + }); - test('should yield exactly one terminal message with result', async () => { - const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); - const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - client.setRequestHandler('sampling/createMessage', () => ({ - model: 'test-model', - role: 'assistant' as const, - content: { type: 'text' as const, text: 'Response' } - })); + const stream = server.experimental.tasks.createMessageStream({ + messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], + maxTokens: 100 + }); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + const allMessages = await toArrayAsync(stream); - const stream = server.experimental.tasks.createMessageStream({ - messages: [{ role: 'user', content: { type: 'text', text: 'Message' } }], - maxTokens: 100 + expect(allMessages.length).toBe(1); + expect(allMessages[0].type).toBe('error'); }); - const messages = await toArrayAsync(stream); - const terminalMessages = messages.filter(m => m.type === 'result' || m.type === 'error'); + test('should yield exactly one terminal message with result', async () => { + client.setRequestHandler('sampling/createMessage', () => ({ + model: 'test-model', + role: 'assistant' as const, + content: { type: 'text' as const, text: 'Response' } + })); - expect(terminalMessages.length).toBe(1); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - const lastMessage = messages.at(-1); - expect(lastMessage.type === 'result' || lastMessage.type === 'error').toBe(true); + const stream = server.experimental.tasks.createMessageStream({ + messages: [{ role: 'user', content: { type: 'text', text: 'Message' } }], + maxTokens: 100 + }); - if (lastMessage.type === 'result') { - expect((lastMessage.result as CreateMessageResult).content).toBeDefined(); - } - }); - }); + const messages = await toArrayAsync(stream); + const terminalMessages = messages.filter(m => m.type === 'result' || m.type === 'error'); - describe('non-task request minimality', () => { - test('should yield only result message for non-task request', async () => { - const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); - const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); - - client.setRequestHandler('sampling/createMessage', () => ({ - model: 'test-model', - role: 'assistant' as const, - content: { type: 'text' as const, text: 'Response' } - })); + expect(terminalMessages.length).toBe(1); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + const lastMessage = messages.at(-1); + expect(lastMessage.type === 'result' || lastMessage.type === 'error').toBe(true); - const stream = server.experimental.tasks.createMessageStream({ - messages: [{ role: 'user', content: { type: 'text', text: 'Message' } }], - maxTokens: 100 + if (lastMessage.type === 'result') { + expect((lastMessage.result as CreateMessageResult).content).toBeDefined(); + } }); + }); - const messages = await toArrayAsync(stream); + describe('non-task request minimality', () => { + test('should yield only result message for non-task request', async () => { + client.setRequestHandler('sampling/createMessage', () => ({ + model: 'test-model', + role: 'assistant' as const, + content: { type: 'text' as const, text: 'Response' } + })); - const taskMessages = messages.filter(m => m.type === 'taskCreated' || m.type === 'taskStatus'); - expect(taskMessages.length).toBe(0); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - const resultMessages = messages.filter(m => m.type === 'result'); - expect(resultMessages.length).toBe(1); + const stream = server.experimental.tasks.createMessageStream({ + messages: [{ role: 'user', content: { type: 'text', text: 'Message' } }], + maxTokens: 100 + }); - expect(messages.length).toBe(1); + const messages = await toArrayAsync(stream); + + const taskMessages = messages.filter(m => m.type === 'taskCreated' || m.type === 'taskStatus'); + expect(taskMessages.length).toBe(0); + + const resultMessages = messages.filter(m => m.type === 'result'); + expect(resultMessages.length).toBe(1); + + expect(messages.length).toBe(1); + }); }); - }); - describe('task-augmented request handling', () => { - test('should yield taskCreated and result for task-augmented request', async () => { - const clientTaskStore = new InMemoryTaskStore(); - const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); - const client = new Client( - { name: 'test client', version: '1.0' }, - { - capabilities: { - sampling: {}, - tasks: { - requests: { - sampling: { createMessage: {} } + describe('task-augmented request handling', () => { + test('should yield taskCreated and result for task-augmented request', async () => { + const clientTaskStore = new InMemoryTaskStore(); + const taskClient = new Client( + { name: 'test client', version: '1.0' }, + { + capabilities: { + sampling: {}, + tasks: { + taskStore: clientTaskStore, + requests: { + sampling: { createMessage: {} } + } } } - }, - taskStore: clientTaskStore - } - ); - - client.setRequestHandler('sampling/createMessage', async (request, extra) => { - const result = { - model: 'test-model', - role: 'assistant' as const, - content: { type: 'text' as const, text: 'Task response' } - }; + } + ); - if (request.params.task && extra.task?.store) { - const task = await extra.task.store.createTask({ ttl: extra.task.requestedTtl }); - await extra.task.store.storeTaskResult(task.taskId, 'completed', result); - return { task }; - } - return result; - }); + taskClient.setRequestHandler('sampling/createMessage', async (request, extra) => { + const result = { + model: 'test-model', + role: 'assistant' as const, + content: { type: 'text' as const, text: 'Task response' } + }; + + if (request.params.task && extra.task?.store) { + const task = await extra.task.store.createTask({ ttl: extra.task.requestedTtl }); + await extra.task.store.storeTaskResult(task.taskId, 'completed', result); + return { task }; + } + return result; + }); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + const [taskClientTransport, taskServerTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([taskClient.connect(taskClientTransport), server.connect(taskServerTransport)]); - const stream = server.experimental.tasks.createMessageStream( - { - messages: [{ role: 'user', content: { type: 'text', text: 'Task-augmented message' } }], - maxTokens: 100 - }, - { task: { ttl: 60_000 } } - ); + const stream = server.experimental.tasks.createMessageStream( + { + messages: [{ role: 'user', content: { type: 'text', text: 'Task-augmented message' } }], + maxTokens: 100 + }, + { task: { ttl: 60_000 } } + ); - const messages = await toArrayAsync(stream); + const messages = await toArrayAsync(stream); - // Should have taskCreated and result - expect(messages.length).toBeGreaterThanOrEqual(2); + // Should have taskCreated and result + expect(messages.length).toBeGreaterThanOrEqual(2); - // First message should be taskCreated - expect(messages[0].type).toBe('taskCreated'); - const taskCreated = messages[0] as { type: 'taskCreated'; task: Task }; - expect(taskCreated.task.taskId).toBeDefined(); + // First message should be taskCreated + expect(messages[0].type).toBe('taskCreated'); + const taskCreated = messages[0] as { type: 'taskCreated'; task: Task }; + expect(taskCreated.task.taskId).toBeDefined(); - // Last message should be result - const lastMessage = messages.at(-1); - expect(lastMessage.type).toBe('result'); - if (lastMessage.type === 'result') { - expect((lastMessage.result as CreateMessageResult).model).toBe('test-model'); - } + // Last message should be result + const lastMessage = messages.at(-1); + expect(lastMessage.type).toBe('result'); + if (lastMessage.type === 'result') { + expect((lastMessage.result as CreateMessageResult).model).toBe('test-model'); + } - clientTaskStore.cleanup(); + clientTaskStore.cleanup(); + await taskClient.close().catch(() => {}); + }); }); }); }); @@ -2362,10 +2375,11 @@ describe('Task-based execution', () => { tools: { call: {} } - } + }, + + taskStore } - }, - taskStore + } } ); @@ -2542,10 +2556,11 @@ describe('Task-based execution', () => { tools: { call: {} } - } + }, + + taskStore } - }, - taskStore + } } ); @@ -2719,10 +2734,11 @@ describe('Task-based execution', () => { elicitation: { create: {} } - } + }, + + taskStore: clientTaskStore } - }, - taskStore: clientTaskStore + } } ); @@ -2746,10 +2762,23 @@ describe('Task-based execution', () => { return result; }); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -2799,10 +2828,11 @@ describe('Task-based execution', () => { elicitation: { create: {} } - } + }, + + taskStore: clientTaskStore } - }, - taskStore: clientTaskStore + } } ); @@ -2826,10 +2856,21 @@ describe('Task-based execution', () => { return result; }); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { create: {} } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -2877,10 +2918,11 @@ describe('Task-based execution', () => { elicitation: { create: {} } - } + }, + + taskStore: clientTaskStore } - }, - taskStore: clientTaskStore + } } ); @@ -2904,10 +2946,21 @@ describe('Task-based execution', () => { return result; }); - const server = new Server({ - name: 'test-server', - version: '1.0.0' - }); + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { create: {} } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -2957,10 +3010,11 @@ describe('Task-based execution', () => { elicitation: { create: {} } - } + }, + + taskStore: clientTaskStore } - }, - taskStore: clientTaskStore + } } ); @@ -3059,10 +3113,11 @@ describe('Task-based execution', () => { tools: { call: {} } - } + }, + + taskStore } - }, - taskStore + } } ); @@ -3198,10 +3253,11 @@ describe('Task-based execution', () => { tools: { call: {} } - } + }, + + taskStore } - }, - taskStore + } } ); @@ -3245,10 +3301,11 @@ describe('Task-based execution', () => { elicitation: { create: {} } - } + }, + + taskStore: clientTaskStore } - }, - taskStore: clientTaskStore + } } ); @@ -3302,10 +3359,11 @@ test('should respect client task capabilities', async () => { elicitation: { create: {} } - } + }, + + taskStore: clientTaskStore } - }, - taskStore: clientTaskStore + } } ); @@ -3414,7 +3472,16 @@ describe('elicitInputStream', () => { let serverTransport: ReturnType[1]; beforeEach(async () => { - server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + server = new Server( + { name: 'test server', version: '1.0' }, + { + capabilities: { + tasks: { + taskStore: new InMemoryTaskStore() + } + } + } + ); client = new Client( { name: 'test client', version: '1.0' }, @@ -3643,12 +3710,12 @@ describe('elicitInputStream', () => { capabilities: { elicitation: { form: {} }, tasks: { + taskStore: clientTaskStore, requests: { elicitation: { create: {} } } } - }, - taskStore: clientTaskStore + } } ); diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index 416f05102..64e0732e1 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -1981,10 +1981,11 @@ describe('Zod v4', () => { tools: { call: {} } - } + }, + + taskStore } - }, - taskStore + } } ); @@ -2050,10 +2051,11 @@ describe('Zod v4', () => { tools: { call: {} } - } + }, + + taskStore } - }, - taskStore + } } ); @@ -6313,10 +6315,11 @@ describe('Zod v4', () => { tools: { call: {} } - } + }, + + taskStore } - }, - taskStore + } } ); @@ -6415,10 +6418,11 @@ describe('Zod v4', () => { tools: { call: {} } - } + }, + + taskStore } - }, - taskStore + } } ); @@ -6520,10 +6524,11 @@ describe('Zod v4', () => { tools: { call: {} } - } + }, + + taskStore } - }, - taskStore + } } ); @@ -6640,10 +6645,11 @@ describe('Zod v4', () => { tools: { call: {} } - } + }, + + taskStore } - }, - taskStore + } } ); @@ -6743,10 +6749,11 @@ describe('Zod v4', () => { tools: { call: {} } - } + }, + + taskStore } - }, - taskStore + } } ); @@ -6841,10 +6848,11 @@ describe('Zod v4', () => { tools: { call: {} } - } + }, + + taskStore } - }, - taskStore + } } ); diff --git a/test/integration/test/taskLifecycle.test.ts b/test/integration/test/taskLifecycle.test.ts index 5a2cd7ca0..6281e833d 100644 --- a/test/integration/test/taskLifecycle.test.ts +++ b/test/integration/test/taskLifecycle.test.ts @@ -39,11 +39,11 @@ describe('Task Lifecycle Integration Tests', () => { } }, list: {}, - cancel: {} + cancel: {}, + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() } - }, - taskStore, - taskMessageQueue: new InMemoryTaskMessageQueue() + } } ); @@ -300,10 +300,15 @@ describe('Task Lifecycle Integration Tests', () => { describe('Task Cancellation', () => { it('should cancel a working task and return the cancelled task', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { tasks: {} } + } + ); const transport = new StreamableHTTPClientTransport(baseUrl); await client.connect(transport); @@ -346,10 +351,15 @@ describe('Task Lifecycle Integration Tests', () => { }); it('should reject cancellation of completed task with error code -32602', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { tasks: {} } + } + ); const transport = new StreamableHTTPClientTransport(baseUrl); await client.connect(transport); @@ -734,10 +744,15 @@ describe('Task Lifecycle Integration Tests', () => { describe('Error Handling', () => { it('should return error code -32602 for non-existent task in tasks/get', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { tasks: {} } + } + ); const transport = new StreamableHTTPClientTransport(baseUrl); await client.connect(transport); @@ -754,10 +769,15 @@ describe('Task Lifecycle Integration Tests', () => { }); it('should return error code -32602 for non-existent task in tasks/cancel', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { tasks: {} } + } + ); const transport = new StreamableHTTPClientTransport(baseUrl); await client.connect(transport); @@ -1491,7 +1511,8 @@ describe('Task Lifecycle Integration Tests', () => { }, { capabilities: { - elicitation: {} + elicitation: {}, + tasks: {} } } );