diff --git a/src/const.ts b/src/const.ts index 5e0715b7..732457ef 100644 --- a/src/const.ts +++ b/src/const.ts @@ -208,3 +208,7 @@ export const HTTP_NOT_FOUND = 404; // Modes that allow long running task tool executions export const ALLOWED_TASK_TOOL_EXECUTION_MODES = ['optional', 'required'] as const; + +// MCP _meta key for associating messages with a task. +// TODO: replace with RELATED_TASK_META_KEY from @modelcontextprotocol/sdk once the installed SDK exports it. +export const RELATED_TASK_META_KEY = 'io.modelcontextprotocol/related-task'; diff --git a/src/mcp/server.ts b/src/mcp/server.ts index 051b3f08..e1ec65a4 100644 --- a/src/mcp/server.ts +++ b/src/mcp/server.ts @@ -871,7 +871,7 @@ export class ActorsMcpServer { } // Only set up notification handlers if progressToken is provided by the client - if (progressToken) { + if (progressToken !== undefined && progressToken !== null) { // Set up notification handlers for the client for (const schema of ServerNotificationSchema.options) { const method = schema.shape.method.value; diff --git a/src/utils/progress.ts b/src/utils/progress.ts index 5d0f7268..dde27600 100644 --- a/src/utils/progress.ts +++ b/src/utils/progress.ts @@ -1,7 +1,7 @@ import type { ProgressNotification } from '@modelcontextprotocol/sdk/types.js'; import type { ApifyClient } from '../apify_client.js'; -import { PROGRESS_NOTIFICATION_INTERVAL_MS } from '../const.js'; +import { PROGRESS_NOTIFICATION_INTERVAL_MS, RELATED_TASK_META_KEY } from '../const.js'; export class ProgressTracker { private progressToken?: string | number; @@ -27,7 +27,7 @@ export class ProgressTracker { this.currentProgress += 1; // Send progress notification only if progressToken and sendNotification are available - if (this.progressToken && this.sendNotification) { + if (this.progressToken !== undefined && this.progressToken !== null && this.sendNotification) { try { const notification: ProgressNotification = { method: 'notifications/progress' as const, @@ -39,7 +39,7 @@ export class ProgressTracker { // Per MCP spec: progress notifications during task execution should include related-task metadata ...(this.taskId && { _meta: { - 'io.modelcontextprotocol/related-task': { + [RELATED_TASK_META_KEY]: { taskId: this.taskId, }, }, @@ -111,7 +111,8 @@ export function createProgressTracker( onStatusMessage?: (message: string) => Promise, ): ProgressTracker | null { // Create tracker if we have either progress notification support or a status message callback - if ((!progressToken || !sendNotification) && !onStatusMessage) { + const hasProgressNotificationSupport = progressToken !== undefined && progressToken !== null && !!sendNotification; + if (!hasProgressNotificationSupport && !onStatusMessage) { return null; } diff --git a/tests/unit/mcp.utils.test.ts b/tests/unit/mcp.utils.test.ts index 1227847e..194068e1 100644 --- a/tests/unit/mcp.utils.test.ts +++ b/tests/unit/mcp.utils.test.ts @@ -1,7 +1,8 @@ +import type { TaskStore } from '@modelcontextprotocol/sdk/experimental/tasks/interfaces.js'; import { describe, expect, it, vi } from 'vitest'; import { SKYFIRE_README_CONTENT } from '../../src/const.js'; -import { parseInputParamsFromUrl } from '../../src/mcp/utils.js'; +import { isTaskCancelled, parseInputParamsFromUrl } from '../../src/mcp/utils.js'; import { resolvePaymentProvider } from '../../src/payments/index.js'; import { createResourceService } from '../../src/resources/resource_service.js'; import type { AvailableWidget } from '../../src/resources/widgets.js'; @@ -61,6 +62,40 @@ describe('parseInputParamsFromUrl', () => { }); }); +describe('isTaskCancelled', () => { + const makeTaskStore = (getTaskReturn: unknown) => ({ + getTask: vi.fn().mockResolvedValue(getTaskReturn), + } as unknown as TaskStore); + + it('should return true when task status is cancelled', async () => { + const taskStore = makeTaskStore({ status: 'cancelled' }); + const result = await isTaskCancelled('task-1', 'session-1', taskStore); + + expect(result).toBe(true); + }); + + it('should return false when task status is not cancelled', async () => { + const taskStore = makeTaskStore({ status: 'working' }); + const result = await isTaskCancelled('task-1', 'session-1', taskStore); + + expect(result).toBe(false); + }); + + it('should return false when task is not found (getTask returns undefined)', async () => { + const taskStore = makeTaskStore(undefined); + const result = await isTaskCancelled('task-1', 'session-1', taskStore); + + expect(result).toBe(false); + }); + + it('should pass taskId and mcpSessionId through to taskStore.getTask', async () => { + const taskStore = makeTaskStore({ status: 'working' }); + await isTaskCancelled('task-42', 'session-xyz', taskStore); + + expect(taskStore.getTask).toHaveBeenCalledWith('task-42', 'session-xyz'); + }); +}); + describe('MCP resources', () => { const buildAvailableWidget = (uri: string, exists: boolean): AvailableWidget => ({ ...WIDGET_REGISTRY[uri], diff --git a/tests/unit/tools.mode_contract.test.ts b/tests/unit/tools.mode_contract.test.ts index a0a85fdf..cefb16e3 100644 --- a/tests/unit/tools.mode_contract.test.ts +++ b/tests/unit/tools.mode_contract.test.ts @@ -9,7 +9,7 @@ */ import { describe, expect, it } from 'vitest'; -import { HelperTools } from '../../src/const.js'; +import { ALLOWED_TASK_TOOL_EXECUTION_MODES, HelperTools } from '../../src/const.js'; import { searchApifyDocsTool } from '../../src/tools/common/search_apify_docs.js'; import { CATEGORY_NAMES, getCategoryTools } from '../../src/tools/index.js'; import type { ToolBase, ToolEntry } from '../../src/types.js'; @@ -174,6 +174,40 @@ describe('getCategoryTools mode contract (tool-mode separation)', () => { }); }); +describe('taskSupport contract across tool categories', () => { + it('should declare taskSupport only on call-actor in default mode, with an allowed value', () => { + const defaultCategories = getCategoryTools('default'); + const toolsWithTaskSupport: { name: string; value: unknown }[] = []; + + for (const categoryName of CATEGORY_NAMES) { + for (const tool of defaultCategories[categoryName]) { + if (tool.execution?.taskSupport !== undefined) { + toolsWithTaskSupport.push({ name: tool.name, value: tool.execution.taskSupport }); + } + } + } + + // Only default-mode call-actor is expected to declare taskSupport among static internal tools. + // (Dynamically-created Actor tools from actor_tools_factory also declare it, but those are not + // returned by getCategoryTools.) + expect(toolsWithTaskSupport.map((t) => t.name)).toEqual([HelperTools.ACTOR_CALL]); + + for (const { value } of toolsWithTaskSupport) { + expect(ALLOWED_TASK_TOOL_EXECUTION_MODES).toContain(value); + } + }); + + it('should not declare taskSupport on any tool in openai mode', () => { + const openaiCategories = getCategoryTools('openai'); + + for (const categoryName of CATEGORY_NAMES) { + for (const tool of openaiCategories[categoryName]) { + expect(tool.execution?.taskSupport).toBeUndefined(); + } + } + }); +}); + describe('getToolPublicFieldOnly _meta filtering', () => { const toolWithOpenAiMeta = { name: 'test-tool', diff --git a/tests/unit/utils.progress.test.ts b/tests/unit/utils.progress.test.ts index e5c3fb89..78f4f569 100644 --- a/tests/unit/utils.progress.test.ts +++ b/tests/unit/utils.progress.test.ts @@ -1,6 +1,7 @@ import { describe, expect, it, vi } from 'vitest'; -import { ProgressTracker } from '../../src/utils/progress.js'; +import { RELATED_TASK_META_KEY } from '../../src/const.js'; +import { createProgressTracker, ProgressTracker } from '../../src/utils/progress.js'; describe('ProgressTracker', () => { it('should send progress notifications correctly', async () => { @@ -83,4 +84,70 @@ describe('ProgressTracker', () => { await expect(tracker.updateProgress('Test')).resolves.toBeUndefined(); expect(mockOnStatusMessage).toHaveBeenCalledWith('Test'); }); + + it('should include related-task metadata with taskId in progress notifications', async () => { + const mockSendNotification = vi.fn(); + const tracker = new ProgressTracker({ + progressToken: 'tok', + sendNotification: mockSendNotification, + taskId: 'task-abc', + }); + + await tracker.updateProgress('running'); + + expect(mockSendNotification).toHaveBeenCalledWith({ + method: 'notifications/progress', + params: { + progressToken: 'tok', + progress: 1, + message: 'running', + }, + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: 'task-abc', + }, + }, + }); + }); + + it('should not include _meta when taskId is not provided', async () => { + const mockSendNotification = vi.fn(); + const tracker = new ProgressTracker({ + progressToken: 'tok', + sendNotification: mockSendNotification, + }); + + await tracker.updateProgress('running'); + + const notification = mockSendNotification.mock.calls[0][0]; + expect(notification).not.toHaveProperty('_meta'); + }); +}); + +describe('createProgressTracker', () => { + it('should return null when no progressToken, no sendNotification, and no onStatusMessage', () => { + expect(createProgressTracker(undefined, undefined)).toBeNull(); + }); + + it('should return ProgressTracker when only onStatusMessage is provided', () => { + const tracker = createProgressTracker(undefined, undefined, undefined, vi.fn()); + expect(tracker).toBeInstanceOf(ProgressTracker); + }); + + it('should return ProgressTracker and send notifications for progressToken = 0', async () => { + const mockSendNotification = vi.fn(); + const tracker = createProgressTracker(0, mockSendNotification); + + expect(tracker).toBeInstanceOf(ProgressTracker); + await tracker?.updateProgress('Started'); + + expect(mockSendNotification).toHaveBeenCalledWith({ + method: 'notifications/progress', + params: { + progressToken: 0, + progress: 1, + message: 'Started', + }, + }); + }); });