diff --git a/go/types.go b/go/types.go index f568d1325..9a53f223c 100644 --- a/go/types.go +++ b/go/types.go @@ -489,6 +489,9 @@ func (MCPStdioServerConfig) mcpServerConfig() {} // MarshalJSON implements json.Marshaler, injecting the "type" discriminator. func (c MCPStdioServerConfig) MarshalJSON() ([]byte, error) { type alias MCPStdioServerConfig + if c.Args == nil { + c.Args = []string{} + } return json.Marshal(struct { Type string `json:"type"` alias diff --git a/go/types_test.go b/go/types_test.go index 2d80d206c..989de2f92 100644 --- a/go/types_test.go +++ b/go/types_test.go @@ -123,6 +123,31 @@ func TestProviderConfig_JSONIncludesHeaders(t *testing.T) { } } +func TestMCPStdioServerConfig_JSONDefaultsNilArgsToEmptyArray(t *testing.T) { + config := MCPStdioServerConfig{ + Command: "mcp-server", + Tools: []string{"*"}, + } + + data, err := json.Marshal(config) + if err != nil { + t.Fatalf("failed to marshal MCP stdio server config: %v", err) + } + + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal MCP stdio server config: %v", err) + } + + args, ok := decoded["args"].([]any) + if !ok { + t.Fatalf("expected args to be an array, got %T in %s", decoded["args"], string(data)) + } + if len(args) != 0 { + t.Fatalf("expected empty args array, got %v", args) + } +} + func TestSessionSendRequest_JSONIncludesRequestHeaders(t *testing.T) { req := sessionSendRequest{ SessionID: "session-1", diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index b7f474d1d..c891e5ec3 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -45,6 +45,7 @@ import type { ForegroundSessionInfo, GetAuthStatusResponse, GetStatusResponse, + MCPServerConfig, ModelInfo, ProviderConfig, ResumeSessionConfig, @@ -80,6 +81,26 @@ function toWireProviderConfig(provider: ProviderConfig): Record return { ...rest, maxPromptTokens: maxInputTokens }; } +function normalizeMcpServers( + mcpServers: Record | undefined +): Record | undefined { + if (!mcpServers) return undefined; + + let normalized: Record | undefined; + for (const [name, server] of Object.entries(mcpServers)) { + if ( + "command" in server && + server.args === undefined && + (server.type === undefined || server.type === "stdio" || server.type === "local") + ) { + normalized ??= { ...mcpServers }; + normalized[name] = { ...server, args: [] }; + } + } + + return normalized ?? mcpServers; +} + /** * Minimum protocol version this SDK can communicate with. * Servers reporting a version below this are rejected. @@ -829,9 +850,13 @@ export class CopilotClient { workingDirectory: config.workingDirectory, streaming: config.streaming, includeSubAgentStreamingEvents: config.includeSubAgentStreamingEvents ?? true, - mcpServers: config.mcpServers, + mcpServers: normalizeMcpServers(config.mcpServers), envValueMode: "direct", - customAgents: config.customAgents, + customAgents: config.customAgents?.map((agent) => + agent.mcpServers + ? { ...agent, mcpServers: normalizeMcpServers(agent.mcpServers) } + : agent + ), defaultAgent: config.defaultAgent, agent: config.agent, configDir: config.configDir, @@ -970,9 +995,13 @@ export class CopilotClient { enableConfigDiscovery: config.enableConfigDiscovery, streaming: config.streaming, includeSubAgentStreamingEvents: config.includeSubAgentStreamingEvents ?? true, - mcpServers: config.mcpServers, + mcpServers: normalizeMcpServers(config.mcpServers), envValueMode: "direct", - customAgents: config.customAgents, + customAgents: config.customAgents?.map((agent) => + agent.mcpServers + ? { ...agent, mcpServers: normalizeMcpServers(agent.mcpServers) } + : agent + ), defaultAgent: config.defaultAgent, agent: config.agent, skillDirectories: config.skillDirectories, diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index a8e3bdfe5..00cb177a6 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -1165,7 +1165,7 @@ interface MCPServerConfigBase { export interface MCPStdioServerConfig extends MCPServerConfigBase { type?: "local" | "stdio"; command: string; - args: string[]; + args?: string[]; /** * Environment variables to pass to the server. */ diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index a92f54253..8bd27d616 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { describe, expect, it, onTestFinished, vi } from "vitest"; -import { approveAll, CopilotClient, type ModelInfo } from "../src/index.js"; +import { approveAll, CopilotClient, type MCPServerConfig, type ModelInfo } from "../src/index.js"; import { CopilotSession } from "../src/session.js"; import { defaultJoinSessionPermissionHandler } from "../src/types.js"; @@ -82,6 +82,28 @@ describe("CopilotClient", () => { ); }); + it("defaults omitted stdio MCP server args to an empty array in session.create", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockResolvedValue({ sessionId: "mcp-session" }); + const mcpServers: Record = { + local: { command: "mcp-server", tools: ["*"] }, + }; + + await client.createSession({ + onPermissionRequest: approveAll, + mcpServers, + }); + + const payload = spy.mock.calls.find((c) => c[0] === "session.create")![1] as any; + expect(payload.mcpServers.local.args).toEqual([]); + expect(mcpServers.local).not.toHaveProperty("args"); + }); + it("forwards clientName in session.resume request", async () => { const client = new CopilotClient(); await client.start(); @@ -107,6 +129,33 @@ describe("CopilotClient", () => { spy.mockRestore(); }); + it("defaults omitted stdio MCP server args to an empty array in session.resume", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const session = await client.createSession({ onPermissionRequest: approveAll }); + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.resume") return { sessionId: params.sessionId }; + throw new Error(`Unexpected method: ${method}`); + }); + const mcpServers: Record = { + local: { command: "mcp-server", tools: ["*"] }, + }; + + await client.resumeSession(session.sessionId, { + onPermissionRequest: approveAll, + mcpServers, + }); + + const payload = spy.mock.calls.find((c) => c[0] === "session.resume")![1] as any; + expect(payload.mcpServers.local.args).toEqual([]); + expect(mcpServers.local).not.toHaveProperty("args"); + spy.mockRestore(); + }); + it("forwards enableSessionTelemetry in session.create request", async () => { const client = new CopilotClient(); await client.start(); diff --git a/python/copilot/client.py b/python/copilot/client.py index cb5c98c90..1f20c6ea8 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -1548,7 +1548,7 @@ async def create_session( # Add MCP servers configuration if provided if mcp_servers: - payload["mcpServers"] = mcp_servers + payload["mcpServers"] = self._normalize_mcp_servers(mcp_servers) payload["envValueMode"] = "direct" # Add custom agents configuration if provided @@ -1921,7 +1921,7 @@ async def resume_session( # TODO: disable_resume is not a keyword arg yet; keeping for future use if mcp_servers: - payload["mcpServers"] = mcp_servers + payload["mcpServers"] = self._normalize_mcp_servers(mcp_servers) payload["envValueMode"] = "direct" if custom_agents: @@ -2556,7 +2556,7 @@ def _convert_custom_agent_to_wire_format( if "tools" in agent: wire_agent["tools"] = agent["tools"] if "mcp_servers" in agent: - wire_agent["mcpServers"] = agent["mcp_servers"] + wire_agent["mcpServers"] = self._normalize_mcp_servers(agent["mcp_servers"]) if "infer" in agent: wire_agent["infer"] = agent["infer"] if "skills" in agent: @@ -2582,6 +2582,26 @@ def _convert_default_agent_to_wire_format( wire["excludedTools"] = config["excluded_tools"] return wire + def _normalize_mcp_servers( + self, mcp_servers: dict[str, MCPServerConfig] | None + ) -> dict[str, MCPServerConfig] | None: + if not mcp_servers: + return mcp_servers + + normalized: dict[str, MCPServerConfig] | None = None + for name, server in mcp_servers.items(): + server_type = server.get("type") + if ( + "command" in server + and "args" not in server + and (server_type is None or server_type in ("local", "stdio")) + ): + if normalized is None: + normalized = dict(mcp_servers) + normalized[name] = {**server, "args": []} + + return normalized if normalized is not None else mcp_servers + async def _start_cli_server(self) -> None: """ Start the CLI server process. diff --git a/python/test_client.py b/python/test_client.py index c03968c55..af27f98ea 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -114,6 +114,63 @@ async def mock_request(method, params): finally: await client.force_stop() + @pytest.mark.asyncio + async def test_create_session_defaults_omitted_stdio_mcp_args_to_empty_list(self): + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + try: + captured = {} + + async def mock_request(method, params): + captured[method] = params + if method == "session.create": + return {"sessionId": params["sessionId"], "workspacePath": None} + return {} + + mcp_servers = {"local": {"command": "mcp-server", "tools": ["*"]}} + client._client.request = mock_request + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + mcp_servers=mcp_servers, + ) + + assert captured["session.create"]["mcpServers"]["local"]["args"] == [] + assert "args" not in mcp_servers["local"] + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_create_session_defaults_custom_agent_stdio_mcp_args_to_empty_list(self): + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + try: + captured = {} + + async def mock_request(method, params): + captured[method] = params + if method == "session.create": + return {"sessionId": params["sessionId"], "workspacePath": None} + return {} + + custom_agents = [ + { + "name": "agent", + "prompt": "You are helpful.", + "mcp_servers": {"local": {"command": "mcp-server", "tools": ["*"]}}, + } + ] + client._client.request = mock_request + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + custom_agents=custom_agents, + ) + + agent = captured["session.create"]["customAgents"][0] + assert agent["mcpServers"]["local"]["args"] == [] + assert "args" not in custom_agents[0]["mcp_servers"]["local"] + finally: + await client.force_stop() + class TestURLParsing: def test_parse_port_only_url(self):