Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/const.ts
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,7 @@ export const HTTP_NOT_FOUND = 404;

// Modes that allow long running task tool executions
export const ALLOWED_TASK_TOOL_EXECUTION_MODES = ['optional', 'required'] as const;

// MCP _meta key for associating messages with a task.
// TODO: replace with RELATED_TASK_META_KEY from @modelcontextprotocol/sdk once the installed SDK exports it.
export const RELATED_TASK_META_KEY = 'io.modelcontextprotocol/related-task';
2 changes: 1 addition & 1 deletion src/mcp/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ export class ActorsMcpServer {
}

// Only set up notification handlers if progressToken is provided by the client
if (progressToken) {
if (progressToken !== undefined && progressToken !== null) {
// Set up notification handlers for the client
for (const schema of ServerNotificationSchema.options) {
const method = schema.shape.method.value;
Expand Down
9 changes: 5 additions & 4 deletions src/utils/progress.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { ProgressNotification } from '@modelcontextprotocol/sdk/types.js';

import type { ApifyClient } from '../apify_client.js';
import { PROGRESS_NOTIFICATION_INTERVAL_MS } from '../const.js';
import { PROGRESS_NOTIFICATION_INTERVAL_MS, RELATED_TASK_META_KEY } from '../const.js';

export class ProgressTracker {
private progressToken?: string | number;
Expand All @@ -27,7 +27,7 @@ export class ProgressTracker {
this.currentProgress += 1;

// Send progress notification only if progressToken and sendNotification are available
if (this.progressToken && this.sendNotification) {
if (this.progressToken !== undefined && this.progressToken !== null && this.sendNotification) {
try {
const notification: ProgressNotification = {
method: 'notifications/progress' as const,
Expand All @@ -39,7 +39,7 @@ export class ProgressTracker {
// Per MCP spec: progress notifications during task execution should include related-task metadata
...(this.taskId && {
_meta: {
'io.modelcontextprotocol/related-task': {
[RELATED_TASK_META_KEY]: {
taskId: this.taskId,
},
},
Expand Down Expand Up @@ -111,7 +111,8 @@ export function createProgressTracker(
onStatusMessage?: (message: string) => Promise<void>,
): ProgressTracker | null {
// Create tracker if we have either progress notification support or a status message callback
if ((!progressToken || !sendNotification) && !onStatusMessage) {
const hasProgressNotificationSupport = progressToken !== undefined && progressToken !== null && !!sendNotification;
if (!hasProgressNotificationSupport && !onStatusMessage) {
return null;
}

Expand Down
37 changes: 36 additions & 1 deletion tests/unit/mcp.utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import type { TaskStore } from '@modelcontextprotocol/sdk/experimental/tasks/interfaces.js';
import { describe, expect, it, vi } from 'vitest';

import { SKYFIRE_README_CONTENT } from '../../src/const.js';
import { parseInputParamsFromUrl } from '../../src/mcp/utils.js';
import { isTaskCancelled, parseInputParamsFromUrl } from '../../src/mcp/utils.js';
import { resolvePaymentProvider } from '../../src/payments/index.js';
import { createResourceService } from '../../src/resources/resource_service.js';
import type { AvailableWidget } from '../../src/resources/widgets.js';
Expand Down Expand Up @@ -61,6 +62,40 @@ describe('parseInputParamsFromUrl', () => {
});
});

describe('isTaskCancelled', () => {
const makeTaskStore = (getTaskReturn: unknown) => ({
getTask: vi.fn().mockResolvedValue(getTaskReturn),
} as unknown as TaskStore);

it('should return true when task status is cancelled', async () => {
const taskStore = makeTaskStore({ status: 'cancelled' });
const result = await isTaskCancelled('task-1', 'session-1', taskStore);

expect(result).toBe(true);
});

it('should return false when task status is not cancelled', async () => {
const taskStore = makeTaskStore({ status: 'working' });
const result = await isTaskCancelled('task-1', 'session-1', taskStore);

expect(result).toBe(false);
});

it('should return false when task is not found (getTask returns undefined)', async () => {
const taskStore = makeTaskStore(undefined);
const result = await isTaskCancelled('task-1', 'session-1', taskStore);

expect(result).toBe(false);
});

it('should pass taskId and mcpSessionId through to taskStore.getTask', async () => {
const taskStore = makeTaskStore({ status: 'working' });
await isTaskCancelled('task-42', 'session-xyz', taskStore);

expect(taskStore.getTask).toHaveBeenCalledWith('task-42', 'session-xyz');
});
});

describe('MCP resources', () => {
const buildAvailableWidget = (uri: string, exists: boolean): AvailableWidget => ({
...WIDGET_REGISTRY[uri],
Expand Down
36 changes: 35 additions & 1 deletion tests/unit/tools.mode_contract.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
*/
import { describe, expect, it } from 'vitest';

import { HelperTools } from '../../src/const.js';
import { ALLOWED_TASK_TOOL_EXECUTION_MODES, HelperTools } from '../../src/const.js';
import { searchApifyDocsTool } from '../../src/tools/common/search_apify_docs.js';
import { CATEGORY_NAMES, getCategoryTools } from '../../src/tools/index.js';
import type { ToolBase, ToolEntry } from '../../src/types.js';
Expand Down Expand Up @@ -174,6 +174,40 @@ describe('getCategoryTools mode contract (tool-mode separation)', () => {
});
});

describe('taskSupport contract across tool categories', () => {
it('should declare taskSupport only on call-actor in default mode, with an allowed value', () => {
const defaultCategories = getCategoryTools('default');
const toolsWithTaskSupport: { name: string; value: unknown }[] = [];

for (const categoryName of CATEGORY_NAMES) {
for (const tool of defaultCategories[categoryName]) {
if (tool.execution?.taskSupport !== undefined) {
toolsWithTaskSupport.push({ name: tool.name, value: tool.execution.taskSupport });
}
}
}

// Only default-mode call-actor is expected to declare taskSupport among static internal tools.
// (Dynamically-created Actor tools from actor_tools_factory also declare it, but those are not
// returned by getCategoryTools.)
expect(toolsWithTaskSupport.map((t) => t.name)).toEqual([HelperTools.ACTOR_CALL]);

for (const { value } of toolsWithTaskSupport) {
expect(ALLOWED_TASK_TOOL_EXECUTION_MODES).toContain(value);
}
});

it('should not declare taskSupport on any tool in openai mode', () => {
const openaiCategories = getCategoryTools('openai');

for (const categoryName of CATEGORY_NAMES) {
for (const tool of openaiCategories[categoryName]) {
expect(tool.execution?.taskSupport).toBeUndefined();
}
}
});
});

describe('getToolPublicFieldOnly _meta filtering', () => {
const toolWithOpenAiMeta = {
name: 'test-tool',
Expand Down
69 changes: 68 additions & 1 deletion tests/unit/utils.progress.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { describe, expect, it, vi } from 'vitest';

import { ProgressTracker } from '../../src/utils/progress.js';
import { RELATED_TASK_META_KEY } from '../../src/const.js';
import { createProgressTracker, ProgressTracker } from '../../src/utils/progress.js';

describe('ProgressTracker', () => {
it('should send progress notifications correctly', async () => {
Expand Down Expand Up @@ -83,4 +84,70 @@ describe('ProgressTracker', () => {
await expect(tracker.updateProgress('Test')).resolves.toBeUndefined();
expect(mockOnStatusMessage).toHaveBeenCalledWith('Test');
});

it('should include related-task metadata with taskId in progress notifications', async () => {
const mockSendNotification = vi.fn();
const tracker = new ProgressTracker({
progressToken: 'tok',
sendNotification: mockSendNotification,
taskId: 'task-abc',
});

await tracker.updateProgress('running');

expect(mockSendNotification).toHaveBeenCalledWith({
method: 'notifications/progress',
params: {
progressToken: 'tok',
progress: 1,
message: 'running',
},
_meta: {
[RELATED_TASK_META_KEY]: {
taskId: 'task-abc',
},
},
});
});

it('should not include _meta when taskId is not provided', async () => {
const mockSendNotification = vi.fn();
const tracker = new ProgressTracker({
progressToken: 'tok',
sendNotification: mockSendNotification,
});

await tracker.updateProgress('running');

const notification = mockSendNotification.mock.calls[0][0];
expect(notification).not.toHaveProperty('_meta');
});
});

describe('createProgressTracker', () => {
it('should return null when no progressToken, no sendNotification, and no onStatusMessage', () => {
expect(createProgressTracker(undefined, undefined)).toBeNull();
});

it('should return ProgressTracker when only onStatusMessage is provided', () => {
const tracker = createProgressTracker(undefined, undefined, undefined, vi.fn());
Comment thread
jirispilka marked this conversation as resolved.
expect(tracker).toBeInstanceOf(ProgressTracker);
});

it('should return ProgressTracker and send notifications for progressToken = 0', async () => {
const mockSendNotification = vi.fn();
const tracker = createProgressTracker(0, mockSendNotification);

expect(tracker).toBeInstanceOf(ProgressTracker);
await tracker?.updateProgress('Started');

expect(mockSendNotification).toHaveBeenCalledWith({
method: 'notifications/progress',
params: {
progressToken: 0,
progress: 1,
message: 'Started',
},
});
});
});
Loading