Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .changeset/extract-task-manager.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
"@modelcontextprotocol/core": minor
"@modelcontextprotocol/client": minor
"@modelcontextprotocol/server": minor
---

refactor: extract task orchestration from Protocol into TaskManager

**Breaking changes:**
- `extra.taskId` → `extra.task?.taskId`
- `extra.taskStore` → `extra.task?.taskStore`
- `extra.taskRequestedTtl` → `extra.task?.requestedTtl`
- `ProtocolOptions` no longer accepts `taskStore`/`taskMessageQueue` — pass via `TaskManagerOptions` in `ClientOptions`/`ServerOptions`
- Abstract methods `assertTaskCapability`/`assertTaskHandlerCapability` removed from Protocol
4 changes: 2 additions & 2 deletions examples/client/src/simpleStreamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,14 @@ async function connect(url?: string): Promise<void> {
form: {}
},
tasks: {
taskStore: clientTaskStore,
requests: {
elicitation: {
create: {}
}
}
}
},
taskStore: clientTaskStore
}
}
);
client.onerror = error => {
Expand Down
11 changes: 8 additions & 3 deletions examples/server/src/simpleStreamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,14 @@ const getServer = () => {
websiteUrl: 'https://github.com/modelcontextprotocol/typescript-sdk'
},
{
capabilities: { logging: {}, tasks: { requests: { tools: { call: {} } } } },
taskStore, // Enable task support
taskMessageQueue: new InMemoryTaskMessageQueue()
capabilities: {
logging: {},
tasks: {
requests: { tools: { call: {} } },
taskStore,
taskMessageQueue: new InMemoryTaskMessageQueue()
}
}
}
);

Expand Down
75 changes: 38 additions & 37 deletions packages/client/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import type {
ResultTypeMap,
ServerCapabilities,
SubscribeRequest,
TaskManagerOptions,
Tool,
Transport,
UnsubscribeRequest
Expand Down Expand Up @@ -61,7 +62,8 @@ import {
ProtocolErrorCode,
ReadResourceResultSchema,
SdkError,
SdkErrorCode
SdkErrorCode,
TaskManager
} from '@modelcontextprotocol/core';

import { ExperimentalClientTasks } from '../experimental/tasks/client.js';
Expand Down Expand Up @@ -140,11 +142,20 @@ export function getSupportedElicitationModes(capabilities: ClientCapabilities['e
return { supportsFormMode, supportsUrlMode };
}

/**
* Extended tasks capability that includes runtime configuration (store, messageQueue).
* The runtime-only fields are stripped before advertising capabilities to servers.
*/
export type ClientTasksCapabilityWithRuntime = NonNullable<ClientCapabilities['tasks']> &
Pick<TaskManagerOptions, 'taskStore' | 'taskMessageQueue'>;

export type ClientOptions = ProtocolOptions & {
/**
* Capabilities to advertise as being supported by this client.
*/
capabilities?: ClientCapabilities;
capabilities?: Omit<ClientCapabilities, 'tasks'> & {
tasks?: ClientTasksCapabilityWithRuntime;
};

/**
* JSON Schema validator for tool output validation.
Expand Down Expand Up @@ -204,6 +215,7 @@ export class Client extends Protocol<ClientContext> {
private _listChangedDebounceTimers: Map<string, ReturnType<typeof setTimeout>> = new Map();
private _pendingListChangedConfig?: ListChangedHandlers;
private _enforceStrictCapabilities: boolean;
private _taskModule?: TaskManager;

/**
* Initializes this client with the given name and version information.
Expand All @@ -213,16 +225,39 @@ export class Client extends Protocol<ClientContext> {
options?: ClientOptions
) {
super(options);
this._capabilities = options?.capabilities ?? {};
this._capabilities = options?.capabilities ? { ...options.capabilities } : {};
this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new DefaultJsonSchemaValidator();
this._enforceStrictCapabilities = options?.enforceStrictCapabilities ?? false;

// If tasks capability is declared, create and register the task module
if (options?.capabilities?.tasks) {
const { taskStore, taskMessageQueue, ...wireCapabilities } = options.capabilities.tasks;
// Strip runtime-only config from advertised capabilities
this._capabilities.tasks = wireCapabilities;
this._taskModule = new TaskManager({
taskStore,
taskMessageQueue,
enforceStrictCapabilities: options?.enforceStrictCapabilities,
assertTaskCapability: method => assertToolsCallTaskCapability(this._serverCapabilities?.tasks?.requests, method, 'Server'),
assertTaskHandlerCapability: method =>
assertClientRequestTaskCapability(this._capabilities.tasks?.requests, method, 'Client')
});
this.registerModule(this._taskModule);
}

// Store list changed config for setup after connection (when we know server capabilities)
if (options?.listChanged) {
this._pendingListChangedConfig = options.listChanged;
}
}

/**
* Access the task module, if tasks capability is configured.
*/
get taskModule(): TaskManager | undefined {
return this._taskModule;
}

protected override buildContext(ctx: BaseContext, _transportInfo?: MessageExtraInfo): ClientContext {
return ctx;
}
Expand Down Expand Up @@ -635,12 +670,6 @@ export class Client extends Protocol<ClientContext> {
}

protected assertRequestHandlerCapability(method: string): void {
// Task handlers are registered in Protocol constructor before _capabilities is initialized
// Skip capability check for task methods during initialization
if (!this._capabilities) {
return;
}

switch (method) {
case 'sampling/createMessage': {
if (!this._capabilities.sampling) {
Expand Down Expand Up @@ -672,41 +701,13 @@ export class Client extends Protocol<ClientContext> {
break;
}

case 'tasks/get':
case 'tasks/list':
case 'tasks/result':
case 'tasks/cancel': {
if (!this._capabilities.tasks) {
throw new SdkError(
SdkErrorCode.CapabilityNotSupported,
`Client does not support tasks capability (required for ${method})`
);
}
break;
}

case 'ping': {
// No specific capability required for ping
break;
}
}
}

protected assertTaskCapability(method: string): void {
assertToolsCallTaskCapability(this._serverCapabilities?.tasks?.requests, method, 'Server');
}

protected assertTaskHandlerCapability(method: string): void {
// Task handlers are registered in Protocol constructor before _capabilities is initialized
// Skip capability check for task methods during initialization
if (!this._capabilities) {
return;
}

assertClientRequestTaskCapability(this._capabilities.tasks?.requests, method, 'Client');
}

/** Sends a ping to the server to check connectivity. */
async ping(options?: RequestOptions) {
return this._requestWithSchema({ method: 'ping' }, EmptyResultSchema, options);
}
Expand Down
54 changes: 21 additions & 33 deletions packages/client/src/experimental/tasks/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ import type {
CreateTaskResult,
GetTaskResult,
ListTasksResult,
Request,
RequestMethod,
RequestOptions,
ResponseMessage,
ResultTypeMap,
SchemaOutput
} from '@modelcontextprotocol/core';
import { ProtocolError, ProtocolErrorCode } from '@modelcontextprotocol/core';
import { CallToolResultSchema, getResultSchema, ProtocolError, ProtocolErrorCode } from '@modelcontextprotocol/core';

import type { Client } from '../../client/client.js';

Expand All @@ -28,10 +29,6 @@ import type { Client } from '../../client/client.js';
* @internal
*/
interface ClientInternal {
requestStream<M extends RequestMethod>(
request: { method: M; params?: Record<string, unknown> },
options?: RequestOptions
): AsyncGenerator<ResponseMessage<ResultTypeMap[M]>, void, void>;
isToolTask(toolName: string): boolean;
getToolOutputValidator(toolName: string): ((data: unknown) => { valid: boolean; errorMessage?: string }) | undefined;
}
Expand All @@ -50,6 +47,14 @@ interface ClientInternal {
export class ExperimentalClientTasks {
constructor(private readonly _client: Client) {}

private get _module() {
const module = this._client.taskModule;
if (!module) {
throw new Error('Tasks capability is not configured. Declare tasks in capabilities to use task features.');
}
return module;
}

/**
* Calls a tool and returns an AsyncGenerator that yields response messages.
* The generator is guaranteed to end with either a `'result'` or `'error'` message.
Expand Down Expand Up @@ -104,7 +109,7 @@ export class ExperimentalClientTasks {
task: options?.task ?? (clientInternal.isToolTask(params.name) ? {} : undefined)
};

const stream = clientInternal.requestStream({ method: 'tools/call', params }, optionsWithTask);
const stream = this._module.requestStream({ method: 'tools/call', params }, CallToolResultSchema, optionsWithTask);

// Get the validator for this tool (if it has an output schema)
const validator = clientInternal.getToolOutputValidator(params.name);
Expand Down Expand Up @@ -176,9 +181,7 @@ export class ExperimentalClientTasks {
* @experimental
*/
async getTask(taskId: string, options?: RequestOptions): Promise<GetTaskResult> {
// Delegate to the client's underlying Protocol method
type ClientWithGetTask = { getTask(params: { taskId: string }, options?: RequestOptions): Promise<GetTaskResult> };
return (this._client as unknown as ClientWithGetTask).getTask({ taskId }, options);
return this._module.getTask({ taskId }, options);
}

/**
Expand All @@ -192,16 +195,7 @@ export class ExperimentalClientTasks {
* @experimental
*/
async getTaskResult<T extends AnyObjectSchema>(taskId: string, resultSchema?: T, options?: RequestOptions): Promise<SchemaOutput<T>> {
// Delegate to the client's underlying Protocol method
return (
this._client as unknown as {
getTaskResult: <U extends AnyObjectSchema>(
params: { taskId: string },
resultSchema?: U,
options?: RequestOptions
) => Promise<SchemaOutput<U>>;
}
).getTaskResult({ taskId }, resultSchema, options);
return this._module.getTaskResult({ taskId }, resultSchema!, options);
}

/**
Expand All @@ -214,12 +208,7 @@ export class ExperimentalClientTasks {
* @experimental
*/
async listTasks(cursor?: string, options?: RequestOptions): Promise<ListTasksResult> {
// Delegate to the client's underlying Protocol method
return (
this._client as unknown as {
listTasks: (params?: { cursor?: string }, options?: RequestOptions) => Promise<ListTasksResult>;
}
).listTasks(cursor ? { cursor } : undefined, options);
return this._module.listTasks(cursor ? { cursor } : undefined, options);
}

/**
Expand All @@ -231,12 +220,7 @@ export class ExperimentalClientTasks {
* @experimental
*/
async cancelTask(taskId: string, options?: RequestOptions): Promise<CancelTaskResult> {
// Delegate to the client's underlying Protocol method
return (
this._client as unknown as {
cancelTask: (params: { taskId: string }, options?: RequestOptions) => Promise<CancelTaskResult>;
}
).cancelTask({ taskId }, options);
return this._module.cancelTask({ taskId }, options);
}

/**
Expand Down Expand Up @@ -281,7 +265,11 @@ export class ExperimentalClientTasks {
request: { method: M; params?: Record<string, unknown> },
options?: RequestOptions
): AsyncGenerator<ResponseMessage<ResultTypeMap[M]>, void, void> {
// Delegate to the client's underlying Protocol method
return (this._client as unknown as ClientInternal).requestStream(request, options);
const resultSchema = getResultSchema(request.method) as unknown as AnyObjectSchema;
return this._module.requestStream(request as Request, resultSchema, options) as AsyncGenerator<
ResponseMessage<ResultTypeMap[M]>,
void,
void
>;
}
}
4 changes: 2 additions & 2 deletions packages/core/src/experimental/tasks/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ interface TaskRequestsCapability {

/**
* Asserts that task creation is supported for `tools/call`.
* Used by {@linkcode @modelcontextprotocol/client!client/client.Client.assertTaskCapability | Client.assertTaskCapability} and {@linkcode @modelcontextprotocol/server!server/server.Server.assertTaskHandlerCapability | Server.assertTaskHandlerCapability}.
* Used as the `assertTaskCapability` or `assertTaskHandlerCapability` callback in `TaskManagerOptions`.
*
* @param requests - The task requests capability object
* @param method - The method being checked
Expand Down Expand Up @@ -52,7 +52,7 @@ export function assertToolsCallTaskCapability(

/**
* Asserts that task creation is supported for `sampling/createMessage` or `elicitation/create`.
* Used by {@linkcode @modelcontextprotocol/server!server/server.Server.assertTaskCapability | Server.assertTaskCapability} and {@linkcode @modelcontextprotocol/client!client/client.Client.assertTaskHandlerCapability | Client.assertTaskHandlerCapability}.
* Used as the `assertTaskCapability` or `assertTaskHandlerCapability` callback in `TaskManagerOptions`.
*
* @param requests - The task requests capability object
* @param method - The method being checked
Expand Down
3 changes: 2 additions & 1 deletion packages/core/src/experimental/tasks/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
* WARNING: These APIs are experimental and may change without notice.
*/

import type { RequestTaskStore, ServerContext } from '../../shared/protocol.js';
import type { ServerContext } from '../../shared/protocol.js';
import type { RequestTaskStore } from '../../shared/taskManager.js';
import type {
JSONRPCErrorResponse,
JSONRPCNotification,
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ export * from './shared/metadataUtils.js';
export * from './shared/protocol.js';
export * from './shared/responseMessage.js';
export * from './shared/stdio.js';
export type { RequestTaskStore, TaskContext, TaskManagerOptions, TaskRequestOptions } from './shared/taskManager.js';
export { TaskManager } from './shared/taskManager.js';
export * from './shared/toolNameValidation.js';
export * from './shared/transport.js';
export * from './shared/uriTemplate.js';
Expand Down
Loading
Loading