Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions go/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,9 @@ func (MCPStdioServerConfig) mcpServerConfig() {}
// MarshalJSON implements json.Marshaler, injecting the "type" discriminator.
func (c MCPStdioServerConfig) MarshalJSON() ([]byte, error) {
type alias MCPStdioServerConfig
if c.Args == nil {
c.Args = []string{}
}
return json.Marshal(struct {
Type string `json:"type"`
alias
Expand Down
25 changes: 25 additions & 0 deletions go/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,31 @@ func TestProviderConfig_JSONIncludesHeaders(t *testing.T) {
}
}

func TestMCPStdioServerConfig_JSONDefaultsNilArgsToEmptyArray(t *testing.T) {
config := MCPStdioServerConfig{
Command: "mcp-server",
Tools: []string{"*"},
}

data, err := json.Marshal(config)
if err != nil {
t.Fatalf("failed to marshal MCP stdio server config: %v", err)
}

var decoded map[string]any
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("failed to unmarshal MCP stdio server config: %v", err)
}

args, ok := decoded["args"].([]any)
if !ok {
t.Fatalf("expected args to be an array, got %T in %s", decoded["args"], string(data))
}
if len(args) != 0 {
t.Fatalf("expected empty args array, got %v", args)
}
}

func TestSessionSendRequest_JSONIncludesRequestHeaders(t *testing.T) {
req := sessionSendRequest{
SessionID: "session-1",
Expand Down
37 changes: 33 additions & 4 deletions nodejs/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import type {
ForegroundSessionInfo,
GetAuthStatusResponse,
GetStatusResponse,
MCPServerConfig,
ModelInfo,
ProviderConfig,
ResumeSessionConfig,
Expand Down Expand Up @@ -80,6 +81,26 @@ function toWireProviderConfig(provider: ProviderConfig): Record<string, unknown>
return { ...rest, maxPromptTokens: maxInputTokens };
}

function normalizeMcpServers(
mcpServers: Record<string, MCPServerConfig> | undefined
): Record<string, MCPServerConfig> | undefined {
if (!mcpServers) return undefined;

let normalized: Record<string, MCPServerConfig> | undefined;
for (const [name, server] of Object.entries(mcpServers)) {
if (
"command" in server &&
server.args === undefined &&
(server.type === undefined || server.type === "stdio" || server.type === "local")
) {
normalized ??= { ...mcpServers };
normalized[name] = { ...server, args: [] };
}
}

return normalized ?? mcpServers;
}

/**
* Minimum protocol version this SDK can communicate with.
* Servers reporting a version below this are rejected.
Expand Down Expand Up @@ -829,9 +850,13 @@ export class CopilotClient {
workingDirectory: config.workingDirectory,
streaming: config.streaming,
includeSubAgentStreamingEvents: config.includeSubAgentStreamingEvents ?? true,
mcpServers: config.mcpServers,
mcpServers: normalizeMcpServers(config.mcpServers),
envValueMode: "direct",
customAgents: config.customAgents,
customAgents: config.customAgents?.map((agent) =>
agent.mcpServers
? { ...agent, mcpServers: normalizeMcpServers(agent.mcpServers) }
: agent
),
defaultAgent: config.defaultAgent,
agent: config.agent,
configDir: config.configDir,
Expand Down Expand Up @@ -970,9 +995,13 @@ export class CopilotClient {
enableConfigDiscovery: config.enableConfigDiscovery,
streaming: config.streaming,
includeSubAgentStreamingEvents: config.includeSubAgentStreamingEvents ?? true,
mcpServers: config.mcpServers,
mcpServers: normalizeMcpServers(config.mcpServers),
envValueMode: "direct",
customAgents: config.customAgents,
customAgents: config.customAgents?.map((agent) =>
agent.mcpServers
? { ...agent, mcpServers: normalizeMcpServers(agent.mcpServers) }
: agent
),
defaultAgent: config.defaultAgent,
agent: config.agent,
skillDirectories: config.skillDirectories,
Expand Down
2 changes: 1 addition & 1 deletion nodejs/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,7 @@ interface MCPServerConfigBase {
export interface MCPStdioServerConfig extends MCPServerConfigBase {
type?: "local" | "stdio";
command: string;
args: string[];
args?: string[];
/**
* Environment variables to pass to the server.
*/
Expand Down
51 changes: 50 additions & 1 deletion nodejs/test/client.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, expect, it, onTestFinished, vi } from "vitest";
import { approveAll, CopilotClient, type ModelInfo } from "../src/index.js";
import { approveAll, CopilotClient, type MCPServerConfig, type ModelInfo } from "../src/index.js";
import { CopilotSession } from "../src/session.js";
import { defaultJoinSessionPermissionHandler } from "../src/types.js";

Expand Down Expand Up @@ -82,6 +82,28 @@ describe("CopilotClient", () => {
);
});

it("defaults omitted stdio MCP server args to an empty array in session.create", async () => {
const client = new CopilotClient();
await client.start();
onTestFinished(() => client.forceStop());

const spy = vi
.spyOn((client as any).connection!, "sendRequest")
.mockResolvedValue({ sessionId: "mcp-session" });
const mcpServers: Record<string, MCPServerConfig> = {
local: { command: "mcp-server", tools: ["*"] },
};

await client.createSession({
onPermissionRequest: approveAll,
mcpServers,
});

const payload = spy.mock.calls.find((c) => c[0] === "session.create")![1] as any;
expect(payload.mcpServers.local.args).toEqual([]);
expect(mcpServers.local).not.toHaveProperty("args");
});

it("forwards clientName in session.resume request", async () => {
const client = new CopilotClient();
await client.start();
Expand All @@ -107,6 +129,33 @@ describe("CopilotClient", () => {
spy.mockRestore();
});

it("defaults omitted stdio MCP server args to an empty array in session.resume", async () => {
const client = new CopilotClient();
await client.start();
onTestFinished(() => client.forceStop());

const session = await client.createSession({ onPermissionRequest: approveAll });
const spy = vi
.spyOn((client as any).connection!, "sendRequest")
.mockImplementation(async (method: string, params: any) => {
if (method === "session.resume") return { sessionId: params.sessionId };
throw new Error(`Unexpected method: ${method}`);
});
const mcpServers: Record<string, MCPServerConfig> = {
local: { command: "mcp-server", tools: ["*"] },
};

await client.resumeSession(session.sessionId, {
onPermissionRequest: approveAll,
mcpServers,
});

const payload = spy.mock.calls.find((c) => c[0] === "session.resume")![1] as any;
expect(payload.mcpServers.local.args).toEqual([]);
expect(mcpServers.local).not.toHaveProperty("args");
spy.mockRestore();
});

it("forwards enableSessionTelemetry in session.create request", async () => {
const client = new CopilotClient();
await client.start();
Expand Down
26 changes: 23 additions & 3 deletions python/copilot/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,7 +1548,7 @@ async def create_session(

# Add MCP servers configuration if provided
if mcp_servers:
payload["mcpServers"] = mcp_servers
payload["mcpServers"] = self._normalize_mcp_servers(mcp_servers)
payload["envValueMode"] = "direct"

# Add custom agents configuration if provided
Expand Down Expand Up @@ -1921,7 +1921,7 @@ async def resume_session(

# TODO: disable_resume is not a keyword arg yet; keeping for future use
if mcp_servers:
payload["mcpServers"] = mcp_servers
payload["mcpServers"] = self._normalize_mcp_servers(mcp_servers)
payload["envValueMode"] = "direct"

if custom_agents:
Expand Down Expand Up @@ -2556,7 +2556,7 @@ def _convert_custom_agent_to_wire_format(
if "tools" in agent:
wire_agent["tools"] = agent["tools"]
if "mcp_servers" in agent:
wire_agent["mcpServers"] = agent["mcp_servers"]
wire_agent["mcpServers"] = self._normalize_mcp_servers(agent["mcp_servers"])
if "infer" in agent:
wire_agent["infer"] = agent["infer"]
if "skills" in agent:
Expand All @@ -2582,6 +2582,26 @@ def _convert_default_agent_to_wire_format(
wire["excludedTools"] = config["excluded_tools"]
return wire

def _normalize_mcp_servers(
self, mcp_servers: dict[str, MCPServerConfig] | None
) -> dict[str, MCPServerConfig] | None:
if not mcp_servers:
return mcp_servers

normalized: dict[str, MCPServerConfig] | None = None
for name, server in mcp_servers.items():
server_type = server.get("type")
if (
"command" in server
and "args" not in server
and (server_type is None or server_type in ("local", "stdio"))
):
if normalized is None:
normalized = dict(mcp_servers)
normalized[name] = {**server, "args": []}

return normalized if normalized is not None else mcp_servers

async def _start_cli_server(self) -> None:
"""
Start the CLI server process.
Expand Down
57 changes: 57 additions & 0 deletions python/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,63 @@ async def mock_request(method, params):
finally:
await client.force_stop()

@pytest.mark.asyncio
async def test_create_session_defaults_omitted_stdio_mcp_args_to_empty_list(self):
client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH))
await client.start()
try:
captured = {}

async def mock_request(method, params):
captured[method] = params
if method == "session.create":
return {"sessionId": params["sessionId"], "workspacePath": None}
return {}

mcp_servers = {"local": {"command": "mcp-server", "tools": ["*"]}}
client._client.request = mock_request
await client.create_session(
on_permission_request=PermissionHandler.approve_all,
mcp_servers=mcp_servers,
)

assert captured["session.create"]["mcpServers"]["local"]["args"] == []
assert "args" not in mcp_servers["local"]
finally:
await client.force_stop()

@pytest.mark.asyncio
async def test_create_session_defaults_custom_agent_stdio_mcp_args_to_empty_list(self):
client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH))
await client.start()
try:
captured = {}

async def mock_request(method, params):
captured[method] = params
if method == "session.create":
return {"sessionId": params["sessionId"], "workspacePath": None}
return {}

custom_agents = [
{
"name": "agent",
"prompt": "You are helpful.",
"mcp_servers": {"local": {"command": "mcp-server", "tools": ["*"]}},
}
]
client._client.request = mock_request
await client.create_session(
on_permission_request=PermissionHandler.approve_all,
custom_agents=custom_agents,
)

agent = captured["session.create"]["customAgents"][0]
assert agent["mcpServers"]["local"]["args"] == []
assert "args" not in custom_agents[0]["mcp_servers"]["local"]
finally:
await client.force_stop()


class TestURLParsing:
def test_parse_port_only_url(self):
Expand Down