diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index b058535ba..a18f7d6f1 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -39,6 +39,8 @@ ## Project-specific conventions & patterns ✅ +- **Cross-SDK features ship in a single PR.** When adding or changing a feature that lives in multiple SDKs (client/session API, wire protocol, etc.), put all the language implementations in **one** pull request — not one PR per language. The `sdk-consistency-review` workflow (`.github/workflows/sdk-consistency-review.md`) auto-reviews every PR touching `nodejs/`, `python/`, `go/`, or `dotnet/` and will flag per-language PRs as "missing in N other languages" even when follow-ups are planned. Exception: an early, isolated reference implementation (e.g. the Rust spike at PR #1394) can ship alone; the remaining languages then follow up as **one** consolidated PR. + - Tools: each SDK has helper APIs to expose functions as tools; prefer the language's `DefineTool`/`@define_tool`/`CopilotTool.DefineTool` patterns (see language READMEs). - Infinite sessions are enabled by default and persist workspace state to `~/.copilot/session-state/{sessionId}`; compaction events are emitted (`session.compaction_start`, `session.compaction_complete`). See language READMEs for usage. - Streaming: when `streaming`/`Streaming=true` you receive delta events (`assistant.message_delta`, `assistant.reasoning_delta`) and final events (`assistant.message`, `assistant.reasoning`) — tests expect this behavior. diff --git a/dotnet/README.md b/dotnet/README.md index a9527f447..b7bea362a 100644 --- a/dotnet/README.md +++ b/dotnet/README.md @@ -133,6 +133,47 @@ Resume an existing session. Returns the session with `WorkspacePath` populated i - `OnPermissionRequest` - Optional handler called before each tool execution to approve or deny it. See [Permission Handling](#permission-handling) section. 
+##### `CreateCloudSessionAsync(SessionConfig config): Task<CopilotSession>` + +Create a session that is hosted by and routed to a remote Copilot cloud server rather than the local CLI. The returned `CopilotSession` behaves identically to a locally-created session: send messages, receive events, and use tools through the same API. + +A `Cloud` configuration is required; passing `Cloud = null` throws `InvalidOperationException`. Conversely, passing a `Cloud` config to the regular `CreateSessionAsync` also throws, preventing accidental misconfiguration. + +The session ID is assigned by the cloud server. After `CreateCloudSessionAsync` returns, `session.SessionId` contains the server-assigned ID and `session.RemoteUrl` contains the URL of the cloud server that owns the session. Session events and inbound RPC requests (user-input, hooks, etc.) that arrive from the server before the session is fully registered are buffered (up to 128 per session ID, oldest dropped first) and replayed automatically once the session is registered. + +**SessionConfig (cloud-specific options):** + +- `Cloud` *(required)* — Cloud routing options: + - `Repository` *(required)* — The repository to route the session to: + - `Owner`, `Name` — Repository owner and repository name (e.g. `"github"`, `"my-repo"`) + - `Branch` *(optional)* — Branch to route the session to (e.g. 
`"main"`) + - `SdkOptions` *(optional)* — Pass-through options forwarded to the cloud provider +- `Model`, `ReasoningEffort`, `Tools`, `SystemMessage`, `AvailableTools`, `ExcludedTools`, `Streaming`, `InfiniteSessions`, `OnPermissionRequest`, `OnUserInputRequest`, `Hooks` — Same as for `CreateSessionAsync`. `SessionId` and `Provider` must not be set; either one throws `InvalidOperationException`. + +**Returned session properties:** + +- `SessionId` — Server-assigned session ID +- `RemoteUrl` — URL of the cloud server that owns this session + +```csharp +await using var session = await client.CreateCloudSessionAsync(new SessionConfig +{ + Cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "github", Name = "my-repo" } + }, + OnUserInputRequest = async (request, _) => + { + Console.Write($"{request.Question} "); + return new UserInputResponse { Answer = Console.ReadLine()! }; + } +}); + +Console.WriteLine($"Cloud session: {session.SessionId} at {session.RemoteUrl}"); +var reply = await session.SendAsync(new MessageOptions { Content = "Hello from the cloud!" }); +Console.WriteLine(reply); +``` + ##### `PingAsync(string? message = null): Task` Ping the server to check connectivity. @@ -193,6 +234,7 @@ Represents a single conversation session. - `SessionId` - The unique identifier for this session - `WorkspacePath` - Path to the session workspace directory when infinite sessions are enabled. Contains `checkpoints/`, `plan.md`, and `files/` subdirectories. Null if infinite sessions are disabled. +- `RemoteUrl` - URL of the cloud server that owns this session. Non-null only for sessions created via `CreateCloudSessionAsync`. #### Methods diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index a5cc62354..6d90a668e 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -83,6 +83,16 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable private readonly object _lifecycleHandlersLock = new(); private ServerRpc? 
_serverRpc; + // Pending-routing state for cloud session.create in-flight windows. + // While _pendingRoutingCount > 0, notifications and inbound requests + // addressed to not-yet-registered session ids are buffered and replayed + // once the runtime-assigned id is registered. + private int _pendingRoutingCount; + private readonly Dictionary> _pendingSessionEvents = new(); + private readonly Dictionary>> _pendingSessionWaiters = new(); + private readonly object _pendingLock = new(); + private const int PendingSessionBufferLimit = 128; + private sealed record LifecycleSubscription(Type EventType, Action Handler); /// @@ -520,6 +530,12 @@ public async Task CreateSessionAsync(SessionConfig config, Cance { ArgumentNullException.ThrowIfNull(config); + if (config.Cloud != null) + { + throw new InvalidOperationException( + "CopilotClient.CreateSessionAsync does not support cloud sessions; use CreateCloudSessionAsync instead."); + } + var connection = await EnsureConnectedAsync(cancellationToken); var totalTimestamp = Stopwatch.GetTimestamp(); @@ -651,6 +667,359 @@ public async Task CreateSessionAsync(SessionConfig config, Cance return session; } + /// + /// Creates a Mission Control–backed cloud session. + /// + /// + /// + /// The runtime owns the session ID for cloud sessions: do not set + /// or on the + /// config (the SDK rejects both). The SDK omits sessionId from the + /// session.create wire payload and registers the resulting session under the id + /// the runtime returns. + /// + /// + /// Any session.event notifications or inbound JSON-RPC requests that arrive between + /// sending session.create and receiving its response are buffered (bounded, + /// drop-oldest, limit 128 per id) and replayed once the returned id is registered, + /// so early events are not lost. + /// + /// + /// Known limitation: inbound sessionFs.* requests are not pending-buffered. 
+ /// In practice the runtime does not initiate sessionFs.* before the + /// session.create response, so this is theoretical. + /// + /// + /// Configuration for the cloud session. is required. + /// A that can be used to cancel the operation. + /// A task that resolves to the created . + /// Thrown when is null. + /// Thrown when is null, or when + /// or is set on the config. + /// + /// + /// var session = await client.CreateCloudSessionAsync(new SessionConfig + /// { + /// OnPermissionRequest = PermissionHandler.ApproveAll, + /// Cloud = new CloudSessionOptions + /// { + /// Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk", Branch = "main" } + /// } + /// }); + /// Console.WriteLine($"Cloud session id: {session.SessionId}"); + /// + /// + public async Task CreateCloudSessionAsync(SessionConfig config, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(config); + + if (config.Cloud == null) + { + throw new InvalidOperationException( + "CopilotClient.CreateCloudSessionAsync requires config.Cloud to be set."); + } + if (!string.IsNullOrEmpty(config.SessionId)) + { + throw new InvalidOperationException( + "CopilotClient.CreateCloudSessionAsync does not support a caller-provided SessionId; the runtime assigns one."); + } + if (config.Provider != null) + { + throw new InvalidOperationException( + "CopilotClient.CreateCloudSessionAsync does not support config.Provider; cloud sessions use the runtime's provider."); + } + + var connection = await EnsureConnectedAsync(cancellationToken); + var totalTimestamp = Stopwatch.GetTimestamp(); + + var hasHooks = config.Hooks != null && ( + config.Hooks.OnPreToolUse != null || + config.Hooks.OnPreMcpToolCall != null || + config.Hooks.OnPostToolUse != null || + config.Hooks.OnUserPromptSubmitted != null || + config.Hooks.OnSessionStart != null || + config.Hooks.OnSessionEnd != null || + config.Hooks.OnErrorOccurred != null); + + var (wireSystemMessage, 
transformCallbacks) = ExtractTransformCallbacks(config.SystemMessage); + + // Begin pending-routing mode so notifications/requests that arrive + // before the runtime returns the session id are buffered. + var guard = BeginPendingSessionRouting(); + + CreateSessionResponse response; + try + { + var (traceparent, tracestate) = TelemetryHelpers.GetTraceContext(); + + // sessionId is intentionally omitted (null) on the cloud path: + // the runtime assigns the Mission Control session id. + var request = new CreateSessionRequest( + config.Model, + null /* sessionId omitted */, + config.ClientName, + config.ReasoningEffort, + config.Tools?.Select(ToolDefinition.FromAIFunction).ToList(), + wireSystemMessage, + config.AvailableTools, + config.ExcludedTools, + null /* Provider not allowed on cloud path */, + config.EnableSessionTelemetry, + config.OnPermissionRequest != null ? true : null, + config.OnUserInputRequest != null ? true : null, + config.OnExitPlanModeRequest != null ? true : null, + config.OnAutoModeSwitchRequest != null ? true : null, + hasHooks ? true : null, + config.WorkingDirectory, + config.Streaming is true ? 
true : null, + config.IncludeSubAgentStreamingEvents, + config.McpServers, + "direct", + config.CustomAgents, + config.DefaultAgent, + config.Agent, + config.ConfigDir, + config.EnableConfigDiscovery, + config.SkillDirectories, + config.DisabledSkills, + config.InfiniteSessions, + Commands: config.Commands?.Select(c => new CommandWireDefinition(c.Name, c.Description)).ToList(), + RequestElicitation: config.OnElicitationRequest != null, + Traceparent: traceparent, + Tracestate: tracestate, + ModelCapabilities: config.ModelCapabilities, + GitHubToken: config.GitHubToken, + RemoteSession: config.RemoteSession, + Cloud: config.Cloud, + InstructionDirectories: config.InstructionDirectories); + + response = await InvokeRpcAsync( + connection.Rpc, "session.create", [request], cancellationToken); + } + catch (Exception ex) + { + guard.Dispose(); + if (ex is not OperationCanceledException) + { + LoggingHelpers.LogTiming(_logger, LogLevel.Warning, ex, + "CopilotClient.CreateCloudSessionAsync failed during session.create RPC. Elapsed={Elapsed}", + totalTimestamp); + } + throw; + } + + if (string.IsNullOrEmpty(response.SessionId)) + { + // No id to issue session.destroy against; release the guard and surface the error. + // Any runtime session created on the other side may leak. + _logger.LogWarning("Cloud session.create response missing sessionId; runtime session may leak."); + guard.Dispose(); + throw new InvalidOperationException( + "Cloud session.create response did not include a sessionId; cannot register session."); + } + + var sessionId = response.SessionId; + var setupTimestamp = Stopwatch.GetTimestamp(); + var session = new CopilotSession(sessionId, connection.Rpc, _logger, this); + session.RegisterTools(config.Tools ?? 
[]); + session.RegisterPermissionHandler(config.OnPermissionRequest); + session.RegisterCommands(config.Commands); + session.RegisterElicitationHandler(config.OnElicitationRequest); + session.RegisterExitPlanModeHandler(config.OnExitPlanModeRequest); + session.RegisterAutoModeSwitchHandler(config.OnAutoModeSwitchRequest); + if (config.OnUserInputRequest != null) + { + session.RegisterUserInputHandler(config.OnUserInputRequest); + } + if (config.Hooks != null) + { + session.RegisterHooks(config.Hooks); + } + if (transformCallbacks != null) + { + session.RegisterTransformCallbacks(transformCallbacks); + } + if (config.OnEvent != null) + { + session.On(config.OnEvent); + } + LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, + "CopilotClient.CreateCloudSessionAsync local setup complete. Elapsed={Elapsed}, SessionId={SessionId}", + setupTimestamp, + sessionId); + + try + { + RegisterSession(session); + ConfigureSessionFsHandlers(session, config.CreateSessionFsProvider); + session.StartProcessingEvents(); + session.WorkspacePath = response.WorkspacePath; + session.SetCapabilities(response.Capabilities); + session.RemoteUrl = response.RemoteUrl; + + // Flush buffered notifications and unblock parked request waiters + // now that the session is registered. Must happen before the guard + // is released so nothing races into a still-pending buffer. + FlushPendingForSession(sessionId, session); + } + catch (Exception ex) + { + session.RemoveFromClient(); + guard.Dispose(); + // Destroy the cloud session on the server — we already have the sessionId from the + // session.create response, so we can send session.destroy even though setup failed. + try { await session.DisposeAsync().ConfigureAwait(false); } catch { /* best effort */ } + LoggingHelpers.LogTiming(_logger, LogLevel.Warning, ex, + "CopilotClient.CreateCloudSessionAsync failed during post-response setup. 
Elapsed={Elapsed}, SessionId={SessionId}", + totalTimestamp, + sessionId); + throw; + } + + guard.Dispose(); + + LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, + "CopilotClient.CreateCloudSessionAsync complete. Elapsed={Elapsed}, SessionId={SessionId}", + totalTimestamp, + sessionId); + return session; + } + + /// + /// Enter pending-routing mode. While the returned guard is undisposed, + /// notifications and inbound requests addressed to session ids that are + /// not yet registered are buffered (up to + /// per id, drop-oldest) and replayed on registration. When the last guard is + /// disposed without registration, buffered events are dropped and parked request + /// waiters are faulted with an . + /// + private PendingSessionRoutingGuard BeginPendingSessionRouting() + { + lock (_pendingLock) + { + _pendingRoutingCount++; + } + return new PendingSessionRoutingGuard(this); + } + + private void EndPendingSessionRouting() + { + List>? waiters = null; + lock (_pendingLock) + { + _pendingRoutingCount--; + if (_pendingRoutingCount == 0) + { + _pendingSessionEvents.Clear(); + waiters = new List>(); + foreach (var list in _pendingSessionWaiters.Values) + waiters.AddRange(list); + _pendingSessionWaiters.Clear(); + } + } + + // Fault pending waiters outside the lock so TCS continuations don't run under it. + // Distinct phrasing from the overflow-eviction path so the runtime / debugging can tell + // the two cases apart. Matches the Rust SDK message in PR #1394 (commit e0ff254f). + if (waiters != null) + { + foreach (var tcs in waiters) + { + tcs.TrySetException(new InvalidOperationException( + "pending session routing ended before session was registered")); + } + } + } + + private void FlushPendingForSession(string sessionId, CopilotSession session) + { + List events; + List> waiters; + + lock (_pendingLock) + { + _pendingSessionEvents.TryGetValue(sessionId, out var rawEvents); + _pendingSessionEvents.Remove(sessionId); + events = rawEvents ?? 
[]; + + _pendingSessionWaiters.TryGetValue(sessionId, out var rawWaiters); + _pendingSessionWaiters.Remove(sessionId); + waiters = rawWaiters ?? []; + } + + foreach (var evt in events) + { + session.DispatchEvent(evt); + } + + // Resolve waiters outside the lock so TCS continuations don't run under it. + foreach (var tcs in waiters) + { + tcs.TrySetResult(session); + } + } + + private Task ResolveSessionAsync(string sessionId) + { + var session = GetSession(sessionId); + if (session != null) + { + return Task.FromResult(session); + } + + lock (_pendingLock) + { + // Re-check inside lock to avoid the race where the session was registered + // between the unlock-free GetSession call and acquiring _pendingLock. + session = GetSession(sessionId); + if (session != null) + { + return Task.FromResult(session); + } + + if (_pendingRoutingCount > 0) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + if (!_pendingSessionWaiters.TryGetValue(sessionId, out var list)) + { + list = []; + _pendingSessionWaiters[sessionId] = list; + } + + // Cap parked waiters per session id. When exceeded, reject the oldest with a + // distinct message so the runtime isn't left waiting on a hung request id. + // RunContinuationsAsynchronously ensures the continuation won't run under this lock. + // Matches the Rust SDK fix in PR #1394 (commit 491b4427). 
+ if (list.Count >= PendingSessionBufferLimit) + { + var oldest = list[0]; + list.RemoveAt(0); + oldest.TrySetException(new InvalidOperationException("pending session buffer overflow")); + } + + list.Add(tcs); + return tcs.Task; + } + } + + return Task.FromException(new ArgumentException($"Unknown session {sessionId}")); + } + + private sealed class PendingSessionRoutingGuard : IDisposable + { + private readonly CopilotClient _client; + private bool _disposed; + + internal PendingSessionRoutingGuard(CopilotClient client) => _client = client; + + public void Dispose() + { + if (_disposed) return; + _disposed = true; + _client.EndPendingSessionRouting(); + } + } + /// /// Resumes an existing Copilot session with the specified configuration. /// @@ -1710,14 +2079,45 @@ private class RpcHandler(CopilotClient client) { public void OnSessionEvent(string sessionId, JsonElement? @event) { + if (@event == null) return; + var session = client.GetSession(sessionId); - if (session != null && @event != null) + if (session != null) { var evt = SessionEvent.FromJson(@event.Value.GetRawText()); if (evt != null) { session.DispatchEvent(evt); } + return; + } + + // Session not yet registered — buffer if pending routing is active. + lock (client._pendingLock) + { + if (client._pendingRoutingCount == 0) + { + return; + } + + var parsed = SessionEvent.FromJson(@event.Value.GetRawText()); + if (parsed == null) return; + + if (!client._pendingSessionEvents.TryGetValue(sessionId, out var buf)) + { + buf = []; + client._pendingSessionEvents[sessionId] = buf; + } + + if (buf.Count >= PendingSessionBufferLimit) + { + buf.RemoveAt(0); + client._logger.LogWarning( + "Pending session event buffer full for session {SessionId}; dropping oldest event.", + sessionId); + } + + buf.Add(parsed); } } @@ -1747,7 +2147,7 @@ public void OnSessionLifecycle(string type, string sessionId, JsonElement? metad public async ValueTask OnUserInputRequest(string sessionId, string question, IList? 
choices = null, bool? allowFreeform = null) { - var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + var session = await client.ResolveSessionAsync(sessionId); var request = new UserInputRequest { Question = question, @@ -1766,7 +2166,7 @@ public async ValueTask OnExitPlanModeRequest( IList? actions = null, string? recommendedAction = null) { - var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + var session = await client.ResolveSessionAsync(sessionId); var request = new ExitPlanModeRequest { Summary = summary, @@ -1783,7 +2183,7 @@ public async ValueTask OnAutoModeSwitchRequest( string? errorCode = null, double? retryAfterSeconds = null) { - var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + var session = await client.ResolveSessionAsync(sessionId); var response = await session.HandleAutoModeSwitchRequestAsync(new AutoModeSwitchRequest { ErrorCode = errorCode, @@ -1794,14 +2194,14 @@ public async ValueTask OnAutoModeSwitchRequest( public async ValueTask OnHooksInvoke(string sessionId, string hookType, JsonElement input) { - var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + var session = await client.ResolveSessionAsync(sessionId); var output = await session.HandleHooksInvokeAsync(hookType, input); return new HooksInvokeResponse(output); } public async ValueTask OnSystemMessageTransform(string sessionId, JsonElement sections) { - var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + var session = await client.ResolveSessionAsync(sessionId); return await session.HandleSystemMessageTransformAsync(sections); } } @@ -1888,7 +2288,8 @@ public static ToolDefinition FromAIFunction(AIFunctionDeclaration function) internal record CreateSessionResponse( string SessionId, string? 
WorkspacePath, - SessionCapabilities? Capabilities = null); + SessionCapabilities? Capabilities = null, + string? RemoteUrl = null); internal record ResumeSessionRequest( string SessionId, diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 6ad8e14d9..646cc256f 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -107,6 +107,15 @@ private sealed record EventSubscription(Type EventType, Action Han /// public string? WorkspacePath { get; internal set; } + /// + /// Gets the remote URL of this cloud session, if available. + /// + /// + /// The Mission Control remote URL returned by the runtime when creating a cloud session, + /// or null for local sessions. + /// + public string? RemoteUrl { get; internal set; } + /// /// Gets the capabilities reported by the host for this session. /// diff --git a/dotnet/test/Unit/CloudSessionTests.cs b/dotnet/test/Unit/CloudSessionTests.cs new file mode 100644 index 000000000..97dfe9f02 --- /dev/null +++ b/dotnet/test/Unit/CloudSessionTests.cs @@ -0,0 +1,766 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +#if NET8_0_OR_GREATER +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Text.Json; +using Xunit; + +namespace GitHub.Copilot.Test.Unit; + +/// +/// Unit tests for and the rejection guard +/// on for cloud configs. +/// +public sealed class CloudSessionTests +{ + // ------------------------------------------------------------------------- + // 1. 
CreateSessionAsync rejects cloud config + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateSessionAsync_Rejects_CloudConfig() + { + await using var server = await FakeCloudServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + var ex = await Assert.ThrowsAsync(() => + client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions { Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } } + })); + + Assert.Contains("CreateCloudSessionAsync", ex.Message); + } + + // ------------------------------------------------------------------------- + // 2. CreateCloudSessionAsync sends session.create with cloud and without sessionId + // (wire-shape correctness: assert sessionId absent from serialized JSON) + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateCloudSessionAsync_Sends_Create_With_Cloud_And_Without_SessionId() + { + await using var server = await FakeCloudServer.StartAsync(cloudSessionId: "remote-cloud-session"); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + await using var session = await client.CreateCloudSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk", Branch = "main" } + } + }); + + // Verify the session id is the runtime-assigned one. + Assert.Equal("remote-cloud-session", session.SessionId); + + // Verify the wire payload captured by the server had no sessionId and had cloud. 
+ var payload = server.LastCreatePayload; + Assert.NotNull(payload); + Assert.False(payload!.Value.TryGetProperty("sessionId", out _), + "session.create payload must not contain 'sessionId' on the cloud path."); + Assert.True(payload.Value.TryGetProperty("cloud", out var cloud)); + Assert.Equal("github", cloud.GetProperty("repository").GetProperty("owner").GetString()); + Assert.Equal("copilot-sdk", cloud.GetProperty("repository").GetProperty("name").GetString()); + Assert.Equal("main", cloud.GetProperty("repository").GetProperty("branch").GetString()); + } + + // Supplementary serialization-layer assertion: JsonSerializer.Serialize on a + // CreateSessionRequest with SessionId=null must not emit the key. + [Fact] + public void CreateSessionRequest_WithNullSessionId_DoesNotEmitSessionIdKey() + { + // Retrieve the private serializer options the SDK uses (same approach as SerializationTests). + var prop = typeof(CopilotClient) + .GetProperty("SerializerOptionsForMessageFormatter", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static); + var options = (JsonSerializerOptions?)prop?.GetValue(null); + Assert.NotNull(options); + + var requestType = typeof(CopilotClient) + .GetNestedType("CreateSessionRequest", System.Reflection.BindingFlags.NonPublic); + Assert.NotNull(requestType); + + // Build a request with SessionId = null (the cloud path). 
+ var instance = System.Runtime.CompilerServices.RuntimeHelpers.GetUninitializedObject(requestType!); + var cloudField = requestType!.GetField("k__BackingField", + System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic); + cloudField?.SetValue(instance, new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "o", Name = "r" } + }); + + var json = JsonSerializer.Serialize(instance, requestType, options!); + + Assert.DoesNotContain("\"sessionId\"", json); + Assert.Contains("\"cloud\"", json); + } + + // ------------------------------------------------------------------------- + // 3. Rejects caller-provided SessionId + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateCloudSessionAsync_Rejects_CallerSessionId() + { + await using var server = await FakeCloudServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + var ex = await Assert.ThrowsAsync(() => + client.CreateCloudSessionAsync(new SessionConfig + { + SessionId = "caller-id", + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions { Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } } + })); + + Assert.Contains("SessionId", ex.Message); + } + + // ------------------------------------------------------------------------- + // 4. 
Rejects caller-provided Provider + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateCloudSessionAsync_Rejects_CallerProvider() + { + await using var server = await FakeCloudServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + var ex = await Assert.ThrowsAsync(() => + client.CreateCloudSessionAsync(new SessionConfig + { + Provider = new ProviderConfig { BaseUrl = "https://api.example.com/v1" }, + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions { Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } } + })); + + Assert.Contains("Provider", ex.Message); + } + + // ------------------------------------------------------------------------- + // 5. Requires Cloud option + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateCloudSessionAsync_Requires_Cloud() + { + await using var server = await FakeCloudServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + var ex = await Assert.ThrowsAsync(() => + client.CreateCloudSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll + // Cloud deliberately absent + })); + + Assert.Contains("Cloud", ex.Message); + } + + // ------------------------------------------------------------------------- + // 6. Buffers early session.event notifications until session id is registered + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateCloudSessionAsync_Buffers_Early_Notifications_Until_Registration() + { + const string cloudId = "remote-cloud-session"; + + // Server is configured to send a session.event notification before + // responding to session.create. 
+ await using var server = await FakeCloudServer.StartAsync( + cloudSessionId: cloudId, + earlyNotification: new Dictionary + { + ["method"] = "session.event", + ["params"] = new Dictionary + { + ["sessionId"] = cloudId, + ["event"] = new Dictionary + { + ["type"] = "capabilities.changed", + ["data"] = new Dictionary + { + ["ui"] = new Dictionary { ["elicitation"] = true } + } + } + } + }); + + var receivedEvents = new List(); + + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + await using var session = await client.CreateCloudSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } + }, + OnEvent = evt => + { + if (evt is CapabilitiesChangedEvent capEvt) + { + receivedEvents.Add(capEvt); + } + } + }); + + // Allow the event channel to drain. + var deadline = DateTime.UtcNow.AddSeconds(5); + while (receivedEvents.Count == 0 && DateTime.UtcNow < deadline) + { + await Task.Delay(20); + } + + Assert.Single(receivedEvents); + Assert.True(receivedEvents[0].Data?.Ui?.Elicitation == true); + } + + // ------------------------------------------------------------------------- + // 7. Parks inbound requests until session id is registered + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateCloudSessionAsync_Parks_Inbound_Requests_Until_Registration() + { + const string cloudId = "remote-cloud-session"; + + // Server sends a userInput.request before responding to session.create. 
+ await using var server = await FakeCloudServer.StartAsync( + cloudSessionId: cloudId, + earlyInboundRequest: new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = 301, + ["method"] = "userInput.request", + ["params"] = new Dictionary + { + ["sessionId"] = cloudId, + ["question"] = "Pick a color", + ["choices"] = new object?[] { "red", "blue" }, + ["allowFreeform"] = true + } + }); + + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + await using var session = await client.CreateCloudSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } + }, + OnUserInputRequest = (_, _) => Task.FromResult(new UserInputResponse { Answer = "blue", WasFreeform = true }) + }); + + // Wait for the server to receive the userInput response. + var response = await server.WaitForUserInputResponse(TimeSpan.FromSeconds(5)); + Assert.NotNull(response); + Assert.Equal("blue", response!["answer"]?.ToString()); + } + + // ------------------------------------------------------------------------- + // 8. 
Pending-waiter overflow: oldest is rejected, remaining 128 succeed + // ------------------------------------------------------------------------- + + [Fact] + public async Task PendingRequestWaiterOverflow_RejectsOldestWithOverflowMessage() + { + const string cloudId = "overflow-session"; + const int requestCount = 129; // one beyond the 128-waiter cap + + await using var server = await FakeCloudServer.StartAsync( + cloudSessionId: cloudId, + earlyInboundRequestCount: requestCount); + + await using var client = new CopilotClient(new CopilotClientOptions + { Connection = RuntimeConnection.ForUri(server.Url) }); + + await using var _ = await client.CreateCloudSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } + }, + OnUserInputRequest = (_, _) => Task.FromResult(new UserInputResponse { Answer = "yes", WasFreeform = false }) + }); + + var responses = await server.WaitForAllInboundResponses(requestCount, TimeSpan.FromSeconds(10)); + + // Exactly one overflow eviction, 128 successful completions. + Assert.Equal(1, responses.Count(r => r.IsError)); + Assert.Equal(128, responses.Count(r => !r.IsError)); + + var err = responses.Single(r => r.IsError); + Assert.Contains("pending session buffer overflow", err.ErrorMessage ?? ""); + } + + // ------------------------------------------------------------------------- + // 9. 
Guard-drop path: parked requests are rejected with distinct message + // ------------------------------------------------------------------------- + + [Fact] + public async Task PendingSessionGuardDrop_RejectsParkedRequestWithDistinctMessage() + { + const string cloudId = "guard-drop-session"; + const int inboundRequestId = 500; + + await using var server = await FakeCloudServer.StartAsync( + cloudSessionId: cloudId, + failSessionCreate: true, + earlyInboundRequest: new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = inboundRequestId, + ["method"] = "userInput.request", + ["params"] = new Dictionary + { + ["sessionId"] = cloudId, + ["question"] = "Color?", + ["choices"] = new object?[] { "red", "blue" }, + ["allowFreeform"] = false + } + }); + + await using var client = new CopilotClient(new CopilotClientOptions + { Connection = RuntimeConnection.ForUri(server.Url) }); + + await Assert.ThrowsAnyAsync(() => + client.CreateCloudSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } + } + })); + + // The parked request must have been rejected with the guard-drop message (not the overflow message). + var responses = await server.WaitForAllInboundResponses(1, TimeSpan.FromSeconds(5)); + + Assert.Single(responses); + Assert.True(responses[0].IsError); + Assert.Contains( + "pending session routing ended before session was registered", + responses[0].ErrorMessage ?? 
""); + } + + // ========================================================================= + // Fake server infrastructure + // ========================================================================= + + private sealed class FakeCloudServer : IAsyncDisposable + { + private readonly TcpListener _listener; + private readonly CancellationTokenSource _cts = new(); + private readonly SemaphoreSlim _writeLock = new(1, 1); + private readonly Task _serverTask; + private readonly string _cloudSessionId; + private readonly Dictionary? _earlyNotification; + private readonly Dictionary? _earlyInboundRequest; + private readonly int _earlyInboundRequestCount; + private readonly bool _failSessionCreate; + private readonly TaskCompletionSource?> _userInputResponseTcs = + new(TaskCreationOptions.RunContinuationsAsynchronously); + + // Response tracking for overflow / guard-drop tests. + private readonly object _inboundResponsesLock = new(); + private readonly List _collectedInboundResponses = []; + private int _waitForInboundResponseCount; + private TaskCompletionSource>? _allInboundResponsesTcs; + + public JsonElement? LastCreatePayload { get; private set; } + + public record InboundResponse(int RequestId, bool IsError, string? ErrorMessage); + + private FakeCloudServer( + TcpListener listener, + string cloudSessionId, + Dictionary? earlyNotification, + Dictionary? earlyInboundRequest, + int earlyInboundRequestCount, + bool failSessionCreate) + { + _listener = listener; + _cloudSessionId = cloudSessionId; + _earlyNotification = earlyNotification; + _earlyInboundRequest = earlyInboundRequest; + _earlyInboundRequestCount = earlyInboundRequestCount; + _failSessionCreate = failSessionCreate; + _serverTask = RunAsync(); + } + + public string Url + { + get + { + var endpoint = (IPEndPoint)_listener.LocalEndpoint; + return $"http://127.0.0.1:{endpoint.Port}"; + } + } + + public static Task StartAsync( + string cloudSessionId = "cloud-session-id", + Dictionary? 
earlyNotification = null, + Dictionary? earlyInboundRequest = null, + int earlyInboundRequestCount = 0, + bool failSessionCreate = false) + { + var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + return Task.FromResult(new FakeCloudServer( + listener, cloudSessionId, earlyNotification, earlyInboundRequest, + earlyInboundRequestCount, failSessionCreate)); + } + + public Task?> WaitForUserInputResponse(TimeSpan timeout) + => _userInputResponseTcs.Task.WaitAsync(timeout); + + /// + /// Waits until the server has collected responses (error or success) + /// from the client for inbound requests. Used by overflow and guard-drop tests. + /// + public Task> WaitForAllInboundResponses(int count, TimeSpan timeout) + { + var tcs = new TaskCompletionSource>( + TaskCreationOptions.RunContinuationsAsynchronously); + lock (_inboundResponsesLock) + { + _waitForInboundResponseCount = count; + _allInboundResponsesTcs = tcs; + if (_collectedInboundResponses.Count >= count) + tcs.TrySetResult(new List(_collectedInboundResponses)); + } + return tcs.Task.WaitAsync(timeout); + } + + private void RecordInboundResponse(InboundResponse response) + { + TaskCompletionSource>? tcs = null; + IReadOnlyList? 
snapshot = null; + lock (_inboundResponsesLock) + { + _collectedInboundResponses.Add(response); + if (_allInboundResponsesTcs != null && + _collectedInboundResponses.Count >= _waitForInboundResponseCount) + { + tcs = _allInboundResponsesTcs; + snapshot = new List(_collectedInboundResponses); + } + } + tcs?.TrySetResult(snapshot!); + } + + public async ValueTask DisposeAsync() + { + _cts.Cancel(); + _listener.Stop(); + + try { await _serverTask; } + catch (Exception ex) when (ex is OperationCanceledException or ObjectDisposedException or IOException or SocketException) { } + + _cts.Dispose(); + _writeLock.Dispose(); + } + + private async Task RunAsync() + { + using var tcpClient = await _listener.AcceptTcpClientAsync(_cts.Token); + using var stream = tcpClient.GetStream(); + + while (!_cts.Token.IsCancellationRequested) + { + using var request = await ReadMessageAsync(stream, _cts.Token); + if (request is null) return; + await HandleRequestAsync(stream, request.RootElement, _cts.Token); + } + } + + private async Task HandleRequestAsync(Stream stream, JsonElement request, CancellationToken cancellationToken) + { + // Identify the message type: + // - Response from SDK: has "id" and "result"/"error" but no "method" + // - Request from SDK: has "id" and "method" + // - Notification from SDK: has "method" but no "id" + + bool hasId = request.TryGetProperty("id", out var idElement); + bool hasMethod = request.TryGetProperty("method", out var methodEl); + bool hasResult = request.TryGetProperty("result", out var resultEl); + + if (hasId && !hasMethod) + { + // This is a response from the SDK (e.g. userInput reply). 
+ if (hasResult) + { + var dict = new Dictionary(); + foreach (var prop in resultEl.EnumerateObject()) + { + dict[prop.Name] = prop.Value.ValueKind switch + { + JsonValueKind.String => prop.Value.GetString(), + JsonValueKind.True => true, + JsonValueKind.False => false, + _ => prop.Value.GetRawText() + }; + } + _userInputResponseTcs.TrySetResult(dict); + + if (idElement.ValueKind == JsonValueKind.Number && idElement.TryGetInt32(out var successId)) + RecordInboundResponse(new InboundResponse(successId, IsError: false, null)); + } + else if (request.TryGetProperty("error", out var errorEl)) + { + var requestId = idElement.ValueKind == JsonValueKind.Number && idElement.TryGetInt32(out var errId) + ? errId : -1; + var msg = errorEl.TryGetProperty("message", out var msgEl) ? msgEl.GetString() : null; + RecordInboundResponse(new InboundResponse(requestId, IsError: true, msg)); + } + return; + } + + if (!hasId) + { + // Notification — nothing to respond to. + return; + } + + var id = idElement.Clone(); + var method = methodEl.GetString(); + + if (method == "connect") + { + await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = id, + ["result"] = new Dictionary + { + ["ok"] = true, + ["protocolVersion"] = 3, + ["version"] = "test" + } + }, cancellationToken); + return; + } + + if (method == "session.create") + { + // Capture the params for assertions. + if (request.TryGetProperty("params", out var paramsEl)) + { + LastCreatePayload = paramsEl.Clone(); + } + + // Optionally send an early notification before responding. + if (_earlyNotification != null) + { + await WriteMessageAsync(stream, _earlyNotification, cancellationToken); + } + + // Optionally send an early inbound request before responding. + if (_earlyInboundRequest != null) + { + await WriteMessageAsync(stream, _earlyInboundRequest, cancellationToken); + // Give the SDK a moment to park the request before we unblock create. 
+ await Task.Delay(50, cancellationToken); + } + + // For overflow tests: send N inbound requests to exercise the buffer cap. + if (_earlyInboundRequestCount > 0) + { + for (var i = 1; i <= _earlyInboundRequestCount; i++) + { + await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = i, + ["method"] = "userInput.request", + ["params"] = new Dictionary + { + ["sessionId"] = _cloudSessionId, + ["question"] = $"Question {i}", + ["choices"] = new object?[] { "yes", "no" }, + ["allowFreeform"] = false + } + }, cancellationToken); + } + + // Give the client time to park/overflow all requests before responding. + await Task.Delay(100, cancellationToken); + } + + if (_failSessionCreate) + { + await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = id, + ["error"] = new Dictionary + { + ["code"] = -32603, + ["message"] = "session.create failed (test-induced failure)" + } + }, cancellationToken); + return; + } + + await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = id, + ["result"] = new Dictionary + { + ["sessionId"] = _cloudSessionId, + ["workspacePath"] = null, + ["capabilities"] = null + } + }, cancellationToken); + return; + } + + if (method == "session.destroy") + { + await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = id, + ["result"] = new Dictionary() + }, cancellationToken); + return; + } + + // Default: return an empty success result. 
+ await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = id, + ["result"] = new Dictionary() + }, cancellationToken); + } + + private async Task WriteMessageAsync(Stream stream, object payload, CancellationToken cancellationToken) + { + using var bodyStream = new MemoryStream(); + using (var writer = new Utf8JsonWriter(bodyStream)) + { + WriteJsonValue(writer, payload); + } + + var body = bodyStream.ToArray(); + var header = Encoding.ASCII.GetBytes($"Content-Length: {body.Length}\r\n\r\n"); + + await _writeLock.WaitAsync(cancellationToken); + try + { + await stream.WriteAsync(header, cancellationToken); + await stream.WriteAsync(body, cancellationToken); + await stream.FlushAsync(cancellationToken); + } + finally + { + _writeLock.Release(); + } + } + + private static void WriteJsonValue(Utf8JsonWriter writer, object? value) + { + switch (value) + { + case null: + writer.WriteNullValue(); + break; + case string s: + writer.WriteStringValue(s); + break; + case bool b: + writer.WriteBooleanValue(b); + break; + case int i: + writer.WriteNumberValue(i); + break; + case long l: + writer.WriteNumberValue(l); + break; + case JsonElement je: + je.WriteTo(writer); + break; + case Dictionary dict: + writer.WriteStartObject(); + foreach (var (k, v) in dict) + { + writer.WritePropertyName(k); + WriteJsonValue(writer, v); + } + writer.WriteEndObject(); + break; + case object?[] arr: + writer.WriteStartArray(); + foreach (var item in arr) + { + WriteJsonValue(writer, item); + } + writer.WriteEndArray(); + break; + default: + throw new InvalidOperationException($"Unexpected JSON value type '{value.GetType().Name}'."); + } + } + + private static async Task ReadMessageAsync(Stream stream, CancellationToken cancellationToken) + { + var headerBytes = new List(); + while (true) + { + var value = await ReadByteAsync(stream, cancellationToken); + if (value < 0) return null; + + headerBytes.Add((byte)value); + var count = headerBytes.Count; + if (count >= 4 
&& + headerBytes[count - 4] == '\r' && + headerBytes[count - 3] == '\n' && + headerBytes[count - 2] == '\r' && + headerBytes[count - 1] == '\n') + { + break; + } + } + + var header = Encoding.ASCII.GetString([.. headerBytes]); + var contentLength = header + .Split(["\r\n"], StringSplitOptions.RemoveEmptyEntries) + .Select(line => line.Split(':', 2)) + .Where(parts => parts.Length == 2 && parts[0].Equals("Content-Length", StringComparison.OrdinalIgnoreCase)) + .Select(parts => int.Parse(parts[1].Trim(), System.Globalization.CultureInfo.InvariantCulture)) + .Single(); + + var body = new byte[contentLength]; + var offset = 0; + while (offset < body.Length) + { + var read = await stream.ReadAsync(body.AsMemory(offset, body.Length - offset), cancellationToken); + if (read == 0) return null; + offset += read; + } + + return JsonDocument.Parse(body); + } + + private static async Task ReadByteAsync(Stream stream, CancellationToken cancellationToken) + { + var buffer = new byte[1]; + var read = await stream.ReadAsync(buffer, cancellationToken); + return read == 0 ? -1 : buffer[0]; + } + } +} +#endif diff --git a/go/README.md b/go/README.md index da77033f8..62f3b9198 100644 --- a/go/README.md +++ b/go/README.md @@ -105,6 +105,7 @@ That's it! 
When your application calls `copilot.NewClient` without a `Connection - `Stop() error` - Stop the CLI server - `ForceStop()` - Forcefully stop without graceful cleanup - `CreateSession(config *SessionConfig) (*Session, error)` - Create a new session +- `CreateCloudSession(ctx context.Context, config *SessionConfig) (*Session, error)` - Create a Mission Control–backed cloud session - `ResumeSession(sessionID string, config *ResumeSessionConfig) (*Session, error)` - Resume an existing session - `ResumeSessionWithOptions(sessionID string, config *ResumeSessionConfig) (*Session, error)` - Resume with additional configuration - `ListSessions(filter *SessionListFilter) ([]SessionMetadata, error)` - List sessions (with optional filter) @@ -170,6 +171,8 @@ Event types: `SessionLifecycleCreated`, `SessionLifecycleDeleted`, `SessionLifec - `Commands` ([]CommandDefinition): Slash-commands registered for this session. See [Commands](#commands) section. - `OnElicitationRequest` (ElicitationHandler): Handler for elicitation requests from the server. See [Elicitation Requests](#elicitation-requests-serverclient) section. +- `Cloud` (\*CloudSessionOptions): Cloud session configuration. When set, `CreateSession` rejects the config; use `CreateCloudSession` instead. Do not set `SessionID` or `Provider` when using cloud sessions. + **ResumeSessionConfig:** - `OnPermissionRequest` (PermissionHandlerFunc): Optional handler called before each tool execution to approve or deny it. See [Permission Handling](#permission-handling) section. @@ -487,6 +490,27 @@ When enabled, sessions emit compaction events: - `session.compaction_start` - Background compaction started - `session.compaction_complete` - Compaction finished (includes token counts) +## Cloud Sessions + +`CreateCloudSession` creates a Mission Control–backed cloud session. The runtime assigns the session ID; do not set `SessionID` or `Provider` on the config (the SDK rejects both). 
`CreateSession` also rejects any config that has `Cloud` set. + +Any `session.event` notifications or inbound JSON-RPC requests that arrive between sending `session.create` and receiving its response are buffered (bounded, drop-oldest, limit 128 per id) and replayed once the runtime-assigned session ID is registered. + +```go +session, err := client.CreateCloudSession(context.Background(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Cloud: &copilot.CloudSessionOptions{ + Repository: &copilot.CloudSessionRepository{ + Owner: "github", Name: "copilot-sdk", + }, + }, +}) +if err != nil { + log.Fatal(err) +} +fmt.Println("cloud session id:", session.SessionID) +``` + ## Custom Providers The SDK supports custom OpenAI-compatible API providers (BYOK - Bring Your Own Key), including local providers like Ollama. When using a custom provider, you must specify the `Model` explicitly. diff --git a/go/client.go b/go/client.go index 9491eb199..a34269cdb 100644 --- a/go/client.go +++ b/go/client.go @@ -87,6 +87,34 @@ func validateSessionFsConfig(config *SessionFsConfig) error { // log.Fatal(err) // } // defer client.Stop() +// +// Sentinel errors for the two pending-routing rejection paths. Using distinct +// values lets callers (and debugging) tell overflow eviction from guard-drop. +var ( + errPendingSessionBufferOverflow = errors.New("pending session buffer overflow") + errPendingSessionRoutingEnded = errors.New("pending session routing ended before session was registered") +) + +// pendingResult carries the outcome of a parked inbound-request session lookup. +type pendingResult struct { + session *Session + err error +} + +// pendingRouting buffers session.event notifications and parks inbound request +// handlers that arrive before a cloud session.create response is received. +// A refcount tracks how many cloud creates are in flight; when it reaches zero +// the buffers are cleared and parked handlers are rejected. 
+type pendingRouting struct { + mu sync.Mutex + count int + events map[string][]sessionEventRequest + waiters map[string][]chan pendingResult +} + +// pendingSessionBufferLimit caps buffered notifications per in-flight session id. +const pendingSessionBufferLimit = 128 + type Client struct { options ClientOptions process *exec.Cmd @@ -121,6 +149,10 @@ type Client struct { effectiveConnectionToken string onListModels func(ctx context.Context) ([]ModelInfo, error) + // pending buffers traffic that arrives between session.create being sent and + // the response for cloud sessions. + pending pendingRouting + // RPC provides typed server-scoped RPC methods. // This field is nil until the client is connected via Start(). RPC *rpc.ServerRpc @@ -162,6 +194,10 @@ func NewClient(options *ClientOptions) *Client { actualHost: "localhost", isExternalServer: false, useStdio: true, + pending: pendingRouting{ + events: make(map[string][]sessionEventRequest), + waiters: make(map[string][]chan pendingResult), + }, } if options != nil { @@ -593,6 +629,10 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses config = &SessionConfig{} } + if config.Cloud != nil { + return nil, fmt.Errorf("CreateSession does not support cloud sessions; use CreateCloudSession instead") + } + if err := c.ensureConnected(ctx); err != nil { return nil, err } @@ -754,6 +794,310 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses return session, nil } +// CreateCloudSession creates a Mission Control–backed cloud session. +// +// The runtime owns the session ID for cloud sessions: do not set SessionID or +// Provider on the config (the SDK rejects both). Build the config with Cloud +// set to a [CloudSessionOptions] value; [Client.CreateSession] rejects any +// config that has Cloud set. 
+// +// Any session.event notifications or inbound JSON-RPC requests that arrive +// between sending session.create and receiving its response are buffered +// (bounded, drop-oldest, limit 128 per id) and replayed once the +// runtime-assigned session ID is registered. +// +// Known limitation: inbound sessionFs.* requests from the generated +// client-session API handlers are not pending-buffered. In practice the +// runtime does not initiate sessionFs.* calls before the session.create +// response, so this is theoretical. +// +// Example: +// +// session, err := client.CreateCloudSession(context.Background(), &copilot.SessionConfig{ +// OnPermissionRequest: copilot.PermissionHandler.ApproveAll, +// Cloud: &copilot.CloudSessionOptions{ +// Repository: &copilot.CloudSessionRepository{ +// Owner: "github", Name: "copilot-sdk", +// }, +// }, +// }) +func (c *Client) CreateCloudSession(ctx context.Context, config *SessionConfig) (*Session, error) { + if config == nil { + config = &SessionConfig{} + } + + if config.Cloud == nil { + return nil, fmt.Errorf("CreateCloudSession requires config.Cloud to be set") + } + if config.SessionID != "" { + return nil, fmt.Errorf("CreateCloudSession does not accept a caller-provided SessionID; the runtime assigns one") + } + if config.Provider != nil { + return nil, fmt.Errorf("CreateCloudSession does not accept config.Provider; cloud sessions use the runtime's provider") + } + + if err := c.ensureConnected(ctx); err != nil { + return nil, err + } + + req := createSessionRequest{} + req.Model = config.Model + req.ClientName = config.ClientName + req.ReasoningEffort = config.ReasoningEffort + req.ConfigDir = config.ConfigDir + if config.EnableConfigDiscovery { + req.EnableConfigDiscovery = Bool(true) + } + req.Tools = config.Tools + wireSystemMessage, transformCallbacks := extractTransformCallbacks(config.SystemMessage) + req.SystemMessage = wireSystemMessage + req.AvailableTools = config.AvailableTools + req.ExcludedTools = 
config.ExcludedTools + req.EnableSessionTelemetry = config.EnableSessionTelemetry + req.ModelCapabilities = config.ModelCapabilities + req.WorkingDirectory = config.WorkingDirectory + req.MCPServers = config.MCPServers + req.EnvValueMode = "direct" + req.CustomAgents = config.CustomAgents + req.DefaultAgent = config.DefaultAgent + req.Agent = config.Agent + req.SkillDirectories = config.SkillDirectories + req.InstructionDirectories = config.InstructionDirectories + req.DisabledSkills = config.DisabledSkills + req.InfiniteSessions = config.InfiniteSessions + req.GitHubToken = config.GitHubToken + req.RemoteSession = config.RemoteSession + req.Cloud = config.Cloud + // SessionID intentionally omitted: the runtime assigns the id for cloud sessions. + + if len(config.Commands) > 0 { + cmds := make([]wireCommand, 0, len(config.Commands)) + for _, cmd := range config.Commands { + cmds = append(cmds, wireCommand{Name: cmd.Name, Description: cmd.Description}) + } + req.Commands = cmds + } + if config.OnElicitationRequest != nil { + req.RequestElicitation = Bool(true) + } + if config.OnExitPlanModeRequest != nil { + req.RequestExitPlanMode = Bool(true) + } + if config.OnAutoModeSwitchRequest != nil { + req.RequestAutoModeSwitch = Bool(true) + } + if config.Streaming != nil { + req.Streaming = config.Streaming + } + if config.IncludeSubAgentStreamingEvents != nil { + req.IncludeSubAgentStreamingEvents = config.IncludeSubAgentStreamingEvents + } else { + req.IncludeSubAgentStreamingEvents = Bool(true) + } + if config.OnUserInputRequest != nil { + req.RequestUserInput = Bool(true) + } + if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || + config.Hooks.OnPreMcpToolCall != nil || + config.Hooks.OnPostToolUse != nil || + config.Hooks.OnUserPromptSubmitted != nil || + config.Hooks.OnSessionStart != nil || + config.Hooks.OnSessionEnd != nil || + config.Hooks.OnErrorOccurred != nil) { + req.Hooks = Bool(true) + } + if config.OnPermissionRequest != nil { + 
req.RequestPermission = Bool(true) + } + + traceparent, tracestate := getTraceContext(ctx) + req.Traceparent = traceparent + req.Tracestate = tracestate + + dispose := c.beginPendingSessionRouting() + + result, err := c.client.Request("session.create", req) + if err != nil { + dispose() + return nil, fmt.Errorf("failed to create cloud session: %w", err) + } + + var response createSessionResponse + if err := json.Unmarshal(result, &response); err != nil { + dispose() + return nil, fmt.Errorf("failed to unmarshal cloud session response: %w", err) + } + + if response.SessionID == "" { + fmt.Println("warning: cloud session.create response missing sessionId; runtime session may leak") + dispose() + return nil, fmt.Errorf("cloud session.create response did not include a sessionId; cannot register session") + } + + sessionID := response.SessionID + session := newSession(sessionID, c.client, response.WorkspacePath) + session.remoteURL = response.RemoteURL + + session.registerTools(config.Tools) + session.registerPermissionHandler(config.OnPermissionRequest) + if config.OnUserInputRequest != nil { + session.registerUserInputHandler(config.OnUserInputRequest) + } + if config.Hooks != nil { + session.registerHooks(config.Hooks) + } + if transformCallbacks != nil { + session.registerTransformCallbacks(transformCallbacks) + } + if config.OnEvent != nil { + session.On(config.OnEvent) + } + if len(config.Commands) > 0 { + session.registerCommands(config.Commands) + } + if config.OnElicitationRequest != nil { + session.registerElicitationHandler(config.OnElicitationRequest) + } + if config.OnExitPlanModeRequest != nil { + session.registerExitPlanModeHandler(config.OnExitPlanModeRequest) + } + if config.OnAutoModeSwitchRequest != nil { + session.registerAutoModeSwitchHandler(config.OnAutoModeSwitchRequest) + } + session.setCapabilities(response.Capabilities) + + c.sessionsMux.Lock() + c.sessions[sessionID] = session + c.sessionsMux.Unlock() + + if c.options.SessionFs != nil { + if 
config.CreateSessionFsProvider == nil { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + dispose() + return nil, fmt.Errorf("CreateSessionFsProvider is required in session config when SessionFs is enabled in client options") + } + provider := config.CreateSessionFsProvider(session) + if c.options.SessionFs.Capabilities != nil && c.options.SessionFs.Capabilities.Sqlite { + if _, ok := provider.(SessionFsSqliteProvider); !ok { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + dispose() + return nil, fmt.Errorf("SessionFs capabilities declare SQLite support but the provider does not implement SessionFsSqliteProvider") + } + } + session.clientSessionApis.SessionFs = newSessionFsAdapter(provider) + } + + // Drain buffered events and unblock parked request handlers before + // releasing the guard, so they see a fully-wired session. + c.flushPendingForSession(sessionID, session) + dispose() + + return session, nil +} + +// beginPendingSessionRouting increments the pending-routing refcount and +// returns a disposer. While any disposer is undisposed, session.event +// notifications and inbound JSON-RPC requests addressed to unknown session ids +// are buffered/parked rather than dropped. When the last disposer fires, any +// remaining buffers are cleared and parked handlers receive an error so they +// don't block forever. +func (c *Client) beginPendingSessionRouting() func() { + c.pending.mu.Lock() + c.pending.count++ + c.pending.mu.Unlock() + + var once sync.Once + return func() { + once.Do(func() { + c.pending.mu.Lock() + c.pending.count-- + if c.pending.count > 0 { + c.pending.mu.Unlock() + return + } + // Last guard: swap out the maps so we can signal waiters without + // holding the lock (buffered channels make sends non-blocking, but + // releasing the lock is cleaner). 
+ c.pending.events = make(map[string][]sessionEventRequest) + waiters := c.pending.waiters + c.pending.waiters = make(map[string][]chan pendingResult) + c.pending.mu.Unlock() + + for _, chs := range waiters { + for _, ch := range chs { + ch <- pendingResult{err: errPendingSessionRoutingEnded} + } + } + }) + } +} + +// flushPendingForSession drains buffered events and resolves parked request +// handlers for sessionID into the freshly-registered session. Called from +// CreateCloudSession after the session is in c.sessions and before the pending +// guard is released. +func (c *Client) flushPendingForSession(sessionID string, session *Session) { + c.pending.mu.Lock() + events := c.pending.events[sessionID] + delete(c.pending.events, sessionID) + waiters := c.pending.waiters[sessionID] + delete(c.pending.waiters, sessionID) + c.pending.mu.Unlock() + + for _, req := range events { + session.dispatchEvent(req.Event) + } + for _, ch := range waiters { + ch <- pendingResult{session: session} + } +} + +// waitForSession looks up the session by id. If the session is not yet +// registered but pending routing is active, the call parks until the session +// is registered (or pending routing ends without registration). +func (c *Client) waitForSession(sessionID string) (*Session, error) { + c.sessionsMux.Lock() + session, ok := c.sessions[sessionID] + c.sessionsMux.Unlock() + if ok { + return session, nil + } + + c.pending.mu.Lock() + if c.pending.count == 0 { + c.pending.mu.Unlock() + return nil, fmt.Errorf("unknown session %s", sessionID) + } + // Re-check under pending.mu: the session may have been registered and + // flushed between the first lookup and acquiring this lock. 
+ c.sessionsMux.Lock() + session, ok = c.sessions[sessionID] + c.sessionsMux.Unlock() + if ok { + c.pending.mu.Unlock() + return session, nil + } + ch := make(chan pendingResult, 1) + waiters := c.pending.waiters[sessionID] + if len(waiters) >= pendingSessionBufferLimit { + // Reject the oldest waiter to keep the queue bounded. Send a JSON-RPC + // error response via the handler return so the runtime doesn't hang on + // the request id waiting for a reply that would never come. + oldest := waiters[0] + waiters = waiters[1:] + oldest <- pendingResult{err: errPendingSessionBufferOverflow} + } + c.pending.waiters[sessionID] = append(waiters, ch) + c.pending.mu.Unlock() + + result := <-ch + return result.session, result.err +} + // ResumeSession resumes an existing conversation session by its ID. // // This is a convenience method that calls [Client.ResumeSessionWithOptions]. @@ -1761,14 +2105,47 @@ func (c *Client) handleSessionEvent(req sessionEventRequest) { if req.SessionID == "" { return } - // Dispatch to session c.sessionsMux.Lock() session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if ok { session.dispatchEvent(req.Event) + return + } + + // Buffer if a cloud session.create is in flight for this id. + c.pending.mu.Lock() + if c.pending.count > 0 { + // Re-check under pending.mu: the session may have been registered and + // flushed between the first lookup and acquiring this lock. + c.sessionsMux.Lock() + session, ok = c.sessions[req.SessionID] + c.sessionsMux.Unlock() + if ok { + c.pending.mu.Unlock() + session.dispatchEvent(req.Event) + return + } + buf := c.pending.events[req.SessionID] + if len(buf) >= pendingSessionBufferLimit { + buf = buf[1:] // drop oldest + } + c.pending.events[req.SessionID] = append(buf, req) + } + c.pending.mu.Unlock() +} + +// pendingRoutingRPCError maps an error from waitForSession to the appropriate +// JSON-RPC error. 
Overflow and guard-drop rejections use -32603 (internal +// error) so the runtime gets a proper error response instead of hanging on the +// request id. All other waitForSession errors (e.g. unknown session) keep the +// existing -32602 (invalid params) code. +func pendingRoutingRPCError(err error) *jsonrpc2.Error { + if errors.Is(err, errPendingSessionBufferOverflow) || errors.Is(err, errPendingSessionRoutingEnded) { + return &jsonrpc2.Error{Code: -32603, Message: err.Error()} } + return &jsonrpc2.Error{Code: -32602, Message: err.Error()} } // handleUserInputRequest handles a user input request from the CLI server. @@ -1777,11 +2154,9 @@ func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputRespons return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid user input request payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return nil, pendingRoutingRPCError(err) } response, err := session.handleUserInputRequest(UserInputRequest{ @@ -1806,11 +2181,9 @@ func (c *Client) handleExitPlanModeRequest(req exitPlanModeRequest) (*ExitPlanMo recommendedAction = "autopilot" } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return nil, pendingRoutingRPCError(err) } response, err := session.handleExitPlanModeRequest(ExitPlanModeRequest{ @@ -1832,11 +2205,9 @@ func (c *Client) handleAutoModeSwitchRequest(req autoModeSwitchRequest) (*autoMo return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid auto mode switch request payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - 
c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return nil, pendingRoutingRPCError(err) } response, err := session.handleAutoModeSwitchRequest(AutoModeSwitchRequest{ @@ -1856,11 +2227,9 @@ func (c *Client) handleHooksInvoke(req hooksInvokeRequest) (map[string]any, *jso return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid hooks invoke payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return nil, pendingRoutingRPCError(err) } output, err := session.handleHooksInvoke(req.Type, req.Input) @@ -1881,11 +2250,9 @@ func (c *Client) handleSystemMessageTransform(req systemMessageTransformRequest) return systemMessageTransformResponse{}, &jsonrpc2.Error{Code: -32602, Message: "invalid system message transform payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return systemMessageTransformResponse{}, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return systemMessageTransformResponse{}, pendingRoutingRPCError(err) } resp, err := session.handleSystemMessageTransform(req.Sections) diff --git a/go/cloud_session_test.go b/go/cloud_session_test.go new file mode 100644 index 000000000..63c66ab78 --- /dev/null +++ b/go/cloud_session_test.go @@ -0,0 +1,454 @@ +package copilot + +import ( + "encoding/json" + "strings" + "sync" + "testing" + "time" + + "github.com/github/copilot-sdk/go/internal/jsonrpc2" +) + +// newCloudTestClient returns a Client with pending routing initialized and a +// pre-populated sessions map, 
suitable for unit-testing cloud session logic +// without a real network connection. +func newCloudTestClient() *Client { + return &Client{ + sessions: make(map[string]*Session), + pending: pendingRouting{ + events: make(map[string][]sessionEventRequest), + waiters: make(map[string][]chan pendingResult), + }, + } +} + +// TestCreateSession_RejectsCloudConfig verifies that CreateSession returns a +// clear error when config.Cloud is set. +func TestCreateSession_RejectsCloudConfig(t *testing.T) { + client := NewClient(&ClientOptions{Connection: StdioConnection{Path: "/__nonexistent__"}}) + _, err := client.CreateSession(t.Context(), &SessionConfig{ + Cloud: &CloudSessionOptions{}, + }) + if err == nil { + t.Fatal("expected error when cloud config is set") + } + if !strings.Contains(err.Error(), "CreateCloudSession") { + t.Errorf("error should mention CreateCloudSession, got: %v", err) + } +} + +// TestCreateCloudSession_RejectsCallerSessionID verifies the SDK rejects a +// caller-supplied SessionID. +func TestCreateCloudSession_RejectsCallerSessionID(t *testing.T) { + client := NewClient(&ClientOptions{Connection: StdioConnection{Path: "/__nonexistent__"}}) + _, err := client.CreateCloudSession(t.Context(), &SessionConfig{ + Cloud: &CloudSessionOptions{}, + SessionID: "caller-supplied-id", + }) + if err == nil { + t.Fatal("expected error when SessionID is set") + } + if !strings.Contains(err.Error(), "SessionID") { + t.Errorf("error should mention SessionID, got: %v", err) + } +} + +// TestCreateCloudSession_RejectsCallerProvider verifies the SDK rejects a +// caller-supplied Provider. 
+func TestCreateCloudSession_RejectsCallerProvider(t *testing.T) { + client := NewClient(&ClientOptions{Connection: StdioConnection{Path: "/__nonexistent__"}}) + _, err := client.CreateCloudSession(t.Context(), &SessionConfig{ + Cloud: &CloudSessionOptions{}, + Provider: &ProviderConfig{ModelID: "gpt-4"}, + }) + if err == nil { + t.Fatal("expected error when Provider is set") + } + if !strings.Contains(err.Error(), "Provider") { + t.Errorf("error should mention Provider, got: %v", err) + } +} + +// TestCreateCloudSession_RequiresCloud verifies the SDK rejects configs without +// Cloud set. +func TestCreateCloudSession_RequiresCloud(t *testing.T) { + client := NewClient(&ClientOptions{Connection: StdioConnection{Path: "/__nonexistent__"}}) + _, err := client.CreateCloudSession(t.Context(), &SessionConfig{}) + if err == nil { + t.Fatal("expected error when Cloud is nil") + } + if !strings.Contains(err.Error(), "Cloud") { + t.Errorf("error should mention Cloud, got: %v", err) + } +} + +// TestCreateCloudSession_WirePayload verifies that the session.create wire +// payload includes the cloud field and omits sessionId when built by the cloud +// path. 
+func TestCreateCloudSession_WirePayload(t *testing.T) { + req := createSessionRequest{ + Cloud: &CloudSessionOptions{ + Repository: &CloudSessionRepository{Owner: "github", Name: "copilot-sdk"}, + }, + // SessionID intentionally left empty + } + + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("marshal error: %v", err) + } + + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + if _, ok := m["sessionId"]; ok { + t.Error("sessionId must be omitted from the cloud session.create wire payload") + } + + cloud, ok := m["cloud"] + if !ok { + t.Fatal("cloud field must be present in the wire payload") + } + cloudMap, ok := cloud.(map[string]any) + if !ok { + t.Fatalf("cloud field should be a map, got %T", cloud) + } + repo, ok := cloudMap["repository"].(map[string]any) + if !ok { + t.Fatal("cloud.repository should be a map") + } + if repo["owner"] != "github" || repo["name"] != "copilot-sdk" { + t.Errorf("unexpected cloud.repository: %v", repo) + } +} + +// TestPendingRouting_BuffersEarlyNotifications verifies that session.event +// notifications arriving before the session is registered are buffered and +// replayed when flushPendingForSession is called. +func TestPendingRouting_BuffersEarlyNotifications(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + defer dispose() + + const pendingID = "runtime-assigned-id" + + // Simulate two session.event notifications arriving before the session is + // registered. + client.handleSessionEvent(sessionEventRequest{ + SessionID: pendingID, + Event: SessionEvent{Data: &SessionIdleData{}}, + }) + client.handleSessionEvent(sessionEventRequest{ + SessionID: pendingID, + Event: SessionEvent{Data: &SessionIdleData{}}, + }) + + // Verify they are buffered. 
+ client.pending.mu.Lock() + bufLen := len(client.pending.events[pendingID]) + client.pending.mu.Unlock() + if bufLen != 2 { + t.Fatalf("expected 2 buffered events, got %d", bufLen) + } + + // Now register the session and flush. + session, cleanup := newTestSession() + defer cleanup() + session.SessionID = pendingID + + var received []SessionEvent + var mu sync.Mutex + var wg sync.WaitGroup + wg.Add(2) + session.On(func(event SessionEvent) { + mu.Lock() + received = append(received, event) + mu.Unlock() + wg.Done() + }) + + client.sessionsMux.Lock() + client.sessions[pendingID] = session + client.sessionsMux.Unlock() + + client.flushPendingForSession(pendingID, session) + + // Wait for the event handler goroutine to process. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for buffered events to be dispatched") + } + + mu.Lock() + got := len(received) + mu.Unlock() + if got != 2 { + t.Errorf("expected 2 events replayed, got %d", got) + } + + // Buffer should be cleared after flush. + client.pending.mu.Lock() + remaining := len(client.pending.events[pendingID]) + client.pending.mu.Unlock() + if remaining != 0 { + t.Errorf("buffer should be empty after flush, got %d", remaining) + } +} + +// TestPendingRouting_ParksInboundRequests verifies that inbound request handlers +// (e.g. userInput.request) park until the session is registered when pending +// routing is active. +func TestPendingRouting_ParksInboundRequests(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + + const pendingID = "runtime-assigned-id-2" + + // Launch a goroutine that simulates an inbound userInput.request arriving + // before the session is registered. 
+ type result struct { + resp *userInputResponse + err *jsonrpcError + } + resultCh := make(chan result, 1) + go func() { + resp, rpcErr := client.handleUserInputRequest(userInputRequest{ + SessionID: pendingID, + Question: "Proceed?", + }) + resultCh <- result{resp, rpcErr} + }() + + // Give the goroutine time to park. + time.Sleep(20 * time.Millisecond) + + // Register the session. + session, cleanup := newTestSession() + defer cleanup() + session.SessionID = pendingID + session.registerUserInputHandler(func(req UserInputRequest, _ UserInputInvocation) (UserInputResponse, error) { + return UserInputResponse{Answer: "yes"}, nil + }) + + client.sessionsMux.Lock() + client.sessions[pendingID] = session + client.sessionsMux.Unlock() + + client.flushPendingForSession(pendingID, session) + dispose() + + select { + case r := <-resultCh: + if r.err != nil { + t.Fatalf("expected success, got rpc error: %v", r.err) + } + if r.resp == nil || r.resp.Answer != "yes" { + t.Errorf("expected answer 'yes', got %+v", r.resp) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for parked request to be resolved") + } +} + +// TestPendingRouting_DropOldestWhenBufferFull verifies drop-oldest behaviour +// when the notification buffer is full. +func TestPendingRouting_DropOldestWhenBufferFull(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + defer dispose() + + const pendingID = "overflow-session" + + // Fill buffer beyond the limit. + for i := range pendingSessionBufferLimit + 5 { + client.handleSessionEvent(sessionEventRequest{ + SessionID: pendingID, + Event: SessionEvent{ + // Embed the index so we can verify drop-oldest. 
+ Data: &SessionIdleData{}, + }, + }) + _ = i + } + + client.pending.mu.Lock() + bufLen := len(client.pending.events[pendingID]) + client.pending.mu.Unlock() + + if bufLen != pendingSessionBufferLimit { + t.Errorf("expected buffer capped at %d, got %d", pendingSessionBufferLimit, bufLen) + } +} + +// TestPendingRouting_RejectsWaitersOnDispose verifies that waiters are +// rejected with an error when pending mode ends without registration. +func TestPendingRouting_RejectsWaitersOnDispose(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + + const pendingID = "never-registered" + + resultCh := make(chan *jsonrpcError, 1) + go func() { + _, rpcErr := client.handleUserInputRequest(userInputRequest{ + SessionID: pendingID, + Question: "Proceed?", + }) + resultCh <- rpcErr + }() + + // Give the goroutine time to park. + time.Sleep(20 * time.Millisecond) + + // Dispose without registering the session. + dispose() + + select { + case rpcErr := <-resultCh: + if rpcErr == nil { + t.Fatal("expected an rpc error after dispose without registration") + } + if !strings.Contains(rpcErr.Message, "routing ended before session was registered") { + t.Errorf("expected routing-ended message, got: %s", rpcErr.Message) + } + if rpcErr.Code != -32603 { + t.Errorf("expected code -32603, got: %d", rpcErr.Code) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for rejected waiter") + } +} + +// TestPendingRouting_OverflowEmitsError verifies that when the parked-waiter +// buffer reaches its cap, the oldest waiter receives the overflow error response +// and the remaining 128 waiters resolve normally after registration. 
+func TestPendingRouting_OverflowEmitsError(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + + const pendingID = "overflow-request-session" + const total = pendingSessionBufferLimit + 1 // 129 + + type result struct { + resp *userInputResponse + err *jsonrpcError + } + + // Register a user-input handler so the session resolves successfully. + session, cleanup := newTestSession() + defer cleanup() + session.SessionID = pendingID + session.registerUserInputHandler(func(req UserInputRequest, _ UserInputInvocation) (UserInputResponse, error) { + return UserInputResponse{Answer: "yes"}, nil + }) + + results := make([]chan result, total) + for i := range total { + results[i] = make(chan result, 1) + go func(ch chan result) { + resp, rpcErr := client.handleUserInputRequest(userInputRequest{ + SessionID: pendingID, + Question: "Proceed?", + }) + ch <- result{resp, rpcErr} + }(results[i]) + } + + // Give goroutines time to park. + time.Sleep(50 * time.Millisecond) + + // Register the session and flush — this resolves the 128 remaining waiters. + client.sessionsMux.Lock() + client.sessions[pendingID] = session + client.sessionsMux.Unlock() + client.flushPendingForSession(pendingID, session) + dispose() + + // Collect all results with a timeout. 
+ var gotOverflow int + var gotSuccess int + deadline := time.After(2 * time.Second) + for _, ch := range results { + select { + case r := <-ch: + if r.err != nil { + if !strings.Contains(r.err.Message, "pending session buffer overflow") { + t.Errorf("unexpected error message: %s", r.err.Message) + } + if r.err.Code != -32603 { + t.Errorf("expected code -32603 for overflow, got: %d", r.err.Code) + } + gotOverflow++ + } else { + gotSuccess++ + } + case <-deadline: + t.Fatalf("timed out: overflow=%d success=%d", gotOverflow, gotSuccess) + } + } + + if gotOverflow != 1 { + t.Errorf("expected exactly 1 overflow rejection, got %d", gotOverflow) + } + if gotSuccess != pendingSessionBufferLimit { + t.Errorf("expected %d successful resolutions, got %d", pendingSessionBufferLimit, gotSuccess) + } +} + +// TestPendingRouting_GuardDropDistinctMessage verifies that when the last +// pending-routing guard drops without registration, parked waiters receive the +// distinct routing-ended error (not the overflow message) so the two paths are +// distinguishable in logs and debugging. +func TestPendingRouting_GuardDropDistinctMessage(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + + const pendingID = "guard-drop-session" + + resultCh := make(chan *jsonrpcError, 1) + go func() { + _, rpcErr := client.handleUserInputRequest(userInputRequest{ + SessionID: pendingID, + Question: "Proceed?", + }) + resultCh <- rpcErr + }() + + // Give the goroutine time to park. + time.Sleep(20 * time.Millisecond) + + // Drop the guard without registering — simulates session.create failing. 
+ dispose() + + select { + case rpcErr := <-resultCh: + if rpcErr == nil { + t.Fatal("expected an rpc error after guard drop without registration") + } + const want = "pending session routing ended before session was registered" + if rpcErr.Message != want { + t.Errorf("expected exact message %q, got %q", want, rpcErr.Message) + } + if rpcErr.Code != -32603 { + t.Errorf("expected code -32603, got: %d", rpcErr.Code) + } + // Must NOT contain the overflow message. + if strings.Contains(rpcErr.Message, "buffer overflow") { + t.Errorf("guard-drop path must not use overflow message, got: %s", rpcErr.Message) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for rejected waiter") + } +} + +// jsonrpcError is a local alias for jsonrpc2.Error used in test assertions. +type jsonrpcError = jsonrpc2.Error diff --git a/go/session.go b/go/session.go index f38b4be17..06d0447ab 100644 --- a/go/session.go +++ b/go/session.go @@ -52,6 +52,7 @@ type Session struct { // SessionID is the unique identifier for this session. SessionID string workspacePath string + remoteURL string client *jsonrpc2.Client clientSessionApis *rpc.ClientSessionApiHandlers handlers []sessionHandler @@ -94,6 +95,14 @@ func (s *Session) WorkspacePath() string { return s.workspacePath } +// RemoteURL returns the remote URL for a Mission Control–backed cloud session. +// Populated from the remoteUrl field in the session.create response for cloud +// sessions created via CreateCloudSession. Returns empty string for regular +// local sessions. +func (s *Session) RemoteURL() string { + return s.remoteURL +} + // newSession creates a new session wrapper with the given session ID and client. 
func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) *Session { s := &Session{ diff --git a/go/types.go b/go/types.go index be86a326c..a2e7523a7 100644 --- a/go/types.go +++ b/go/types.go @@ -1413,6 +1413,7 @@ type wireCommand struct { type createSessionResponse struct { SessionID string `json:"sessionId"` WorkspacePath string `json:"workspacePath"` + RemoteURL string `json:"remoteUrl,omitempty"` Capabilities *SessionCapabilities `json:"capabilities,omitempty"` } diff --git a/java/src/main/java/com/github/copilot/sdk/CopilotClient.java b/java/src/main/java/com/github/copilot/sdk/CopilotClient.java index 4d0770319..4bc8a0d09 100644 --- a/java/src/main/java/com/github/copilot/sdk/CopilotClient.java +++ b/java/src/main/java/com/github/copilot/sdk/CopilotClient.java @@ -81,6 +81,7 @@ public final class CopilotClient implements AutoCloseable { private final CliServerManager serverManager; private final LifecycleEventManager lifecycleManager = new LifecycleEventManager(); private final Map sessions = new ConcurrentHashMap<>(); + private final PendingRoutingState pendingRoutingState = new PendingRoutingState(); private volatile CompletableFuture connectionFuture; private volatile boolean disposed = false; private final String optionsHost; @@ -210,7 +211,7 @@ private Connection startCoreBody() { // Register handlers for server-to-client calls RpcHandlerDispatcher dispatcher = new RpcHandlerDispatcher(sessions, lifecycleManager::dispatch, - options.getExecutor()); + options.getExecutor(), pendingRoutingState); dispatcher.registerHandlers(rpc); // Verify protocol version @@ -426,6 +427,10 @@ public CompletableFuture createSession(SessionConfig config) { + "For example, to allow all permissions, use: " + "new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)")); } + if (config.getCloud() != null) { + return CompletableFuture.failedFuture(new IllegalArgumentException( + "CopilotClient.createSession does not support cloud 
sessions; use createCloudSession instead.")); + } return ensureConnected().thenCompose(connection -> { long totalNanos = System.nanoTime(); // Pre-generate session ID so the session can be registered before the RPC call, @@ -487,6 +492,156 @@ public CompletableFuture createSession(SessionConfig config) { }); } + /** + * Creates a Mission Control–backed cloud session. + * + *

+ * The runtime owns the session ID for cloud sessions. Do not + * set {@link SessionConfig#setSessionId(String) sessionId} or + * {@link SessionConfig#setProvider(com.github.copilot.sdk.json.ProviderConfig) + * provider} on the config; the SDK rejects both with + * {@link IllegalArgumentException}. The config must have + * {@link SessionConfig#setCloud(com.github.copilot.sdk.json.CloudSessionOptions) + * cloud} set; + * {@link SessionConfig#setOnPermissionRequest(com.github.copilot.sdk.json.PermissionHandler) + * onPermissionRequest} is required. + * + *

+ * The SDK omits {@code sessionId} from the {@code session.create} wire request + * and registers the returned session under the id the runtime assigns. Any + * {@code session.event} notifications or inbound JSON-RPC requests that arrive + * between sending {@code session.create} and receiving its response are + * buffered (bounded, drop-oldest, up to + * {@value PendingRoutingState#BUFFER_LIMIT} per session id) and replayed once + * the session is registered. + * + *

+ * Example: + * + *

{@code
+     * var session = client.createCloudSession(new SessionConfig()
+     * 		.setCloud(new CloudSessionOptions()
+     * 				.setRepository(new CloudSessionRepository().setOwner("github").setName("copilot-sdk")))
+     * 		.setOnPermissionRequest(PermissionHandler.APPROVE_ALL)).get();
+     * }
+ * + * @param config + * configuration for the cloud session; must have {@code cloud} set + * and an {@code onPermissionRequest} handler; must not have + * {@code sessionId} or {@code provider} set + * @return a future that resolves with the created {@link CopilotSession} + * @throws IllegalArgumentException + * if validation fails (see above) + * @see SessionConfig#setCloud(com.github.copilot.sdk.json.CloudSessionOptions) + * @see com.github.copilot.sdk.json.PermissionHandler#APPROVE_ALL + * @since 1.6.0 + */ + public CompletableFuture createCloudSession(SessionConfig config) { + if (config == null || config.getOnPermissionRequest() == null) { + return CompletableFuture.failedFuture( + new IllegalArgumentException("An onPermissionRequest handler is required when creating a session. " + + "For example, to allow all permissions, use: " + + "new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)")); + } + if (config.getCloud() == null) { + return CompletableFuture.failedFuture( + new IllegalArgumentException("CopilotClient.createCloudSession requires config.cloud to be set.")); + } + if (config.getSessionId() != null && !config.getSessionId().isEmpty()) { + return CompletableFuture.failedFuture(new IllegalArgumentException( + "CopilotClient.createCloudSession does not support a caller-provided sessionId; " + + "the runtime assigns one.")); + } + if (config.getProvider() != null) { + return CompletableFuture.failedFuture( + new IllegalArgumentException("CopilotClient.createCloudSession does not support config.provider; " + + "cloud sessions use the runtime's provider.")); + } + + return ensureConnected().thenCompose(connection -> { + long totalNanos = System.nanoTime(); + + // Enter pending-routing mode before sending session.create so that any + // session.event notifications or inbound RPC requests that arrive during + // the in-flight RPC are buffered rather than dropped. 
+ pendingRoutingState.incrementGuard(); + + var request = SessionRequestBuilder.buildCloudCreateRequest(config); + + // Extract transform callbacks from the system message config. + var extracted = SessionRequestBuilder.extractTransformCallbacks(config.getSystemMessage()); + if (extracted.wireSystemMessage() != config.getSystemMessage()) { + request.setSystemMessage(extracted.wireSystemMessage()); + } + + long rpcNanos = System.nanoTime(); + return connection.rpc.invoke("session.create", request, CreateSessionResponse.class).thenApply(response -> { + LoggingHelpers.logTiming(LOG, Level.FINE, + "CopilotClient.createCloudSession session.create completed. Elapsed={Elapsed}", rpcNanos); + + String returnedId = response.sessionId(); + if (returnedId == null || returnedId.isEmpty()) { + // No id: release the guard and surface the error. Any runtime session + // created on the other side may leak — we have nothing to destroy. + pendingRoutingState.decrementGuard(); + LOG.warning("Cloud session.create response missing sessionId; runtime session may leak"); + throw new RuntimeException( + "Cloud session.create response did not include a sessionId; cannot register session."); + } + + var session = new CopilotSession(returnedId, connection.rpc); + if (options.getExecutor() != null) { + session.setExecutor(options.getExecutor()); + } + SessionRequestBuilder.configureSession(session, config); + if (extracted.transformCallbacks() != null) { + session.registerTransformCallbacks(extracted.transformCallbacks()); + } + + try { + // Atomically register the session in the sessions map and drain any + // buffered events/parked waiters. The sessions.put happens inside the + // PendingRoutingState lock so concurrent tryBufferNotification / + // tryParkRequest calls see the session as registered. 
+ var flush = pendingRoutingState.registerAndFlush(returnedId, session, sessions); + + session.setWorkspacePath(response.workspacePath()); + session.setRemoteUrl(response.remoteUrl()); + session.setCapabilities(response.capabilities()); + + // Replay buffered session.event notifications + for (var event : flush.events()) { + session.dispatchEvent(event); + } + // Complete parked request waiters + for (var waiter : flush.waiters()) { + waiter.complete(session); + } + } catch (Exception e) { + // Roll back: remove session from map, release guard. + sessions.remove(returnedId); + pendingRoutingState.decrementGuard(); + LoggingHelpers.logTiming(LOG, Level.WARNING, e, + "CopilotClient.createCloudSession post-registration setup failed. Elapsed={Elapsed}", + totalNanos); + throw e instanceof RuntimeException re ? re : new RuntimeException(e); + } + + pendingRoutingState.decrementGuard(); + + LoggingHelpers.logTiming(LOG, Level.FINE, + "CopilotClient.createCloudSession complete. Elapsed={Elapsed}, SessionId=" + returnedId, + totalNanos); + return session; + }).exceptionally(ex -> { + pendingRoutingState.decrementGuard(); + LoggingHelpers.logTiming(LOG, Level.WARNING, ex, + "CopilotClient.createCloudSession failed. Elapsed={Elapsed}", totalNanos); + throw ex instanceof RuntimeException re ? re : new RuntimeException(ex); + }); + }); + } + /** * Resumes an existing Copilot session. *

diff --git a/java/src/main/java/com/github/copilot/sdk/CopilotSession.java b/java/src/main/java/com/github/copilot/sdk/CopilotSession.java index 5fb8733a2..e134a3b82 100644 --- a/java/src/main/java/com/github/copilot/sdk/CopilotSession.java +++ b/java/src/main/java/com/github/copilot/sdk/CopilotSession.java @@ -154,6 +154,7 @@ public final class CopilotSession implements AutoCloseable { */ private volatile String sessionId; private volatile String workspacePath; + private volatile String remoteUrl; private volatile SessionCapabilities capabilities = new SessionCapabilities(); private final SessionUiApi ui; private final JsonRpcClient rpc; @@ -272,6 +273,31 @@ void setWorkspacePath(String workspacePath) { this.workspacePath = workspacePath; } + /** + * Gets the remote URL for a Mission Control–backed cloud session. + *

+ * Populated from the {@code remoteUrl} field in the {@code session.create} + * response for cloud sessions created via + * {@link CopilotClient#createCloudSession}. Returns {@code null} for regular + * local sessions. + * + * @return the remote URL, or {@code null} for local sessions + */ + public String getRemoteUrl() { + return remoteUrl; + } + + /** + * Sets the remote URL. Package-private; called by CopilotClient after + * session.create RPC response for cloud sessions. + * + * @param remoteUrl + * the remote URL + */ + void setRemoteUrl(String remoteUrl) { + this.remoteUrl = remoteUrl; + } + /** * Gets the capabilities reported by the host for this session. *

diff --git a/java/src/main/java/com/github/copilot/sdk/PendingRoutingState.java b/java/src/main/java/com/github/copilot/sdk/PendingRoutingState.java new file mode 100644 index 000000000..7c0678e20 --- /dev/null +++ b/java/src/main/java/com/github/copilot/sdk/PendingRoutingState.java @@ -0,0 +1,204 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.sdk; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.logging.Logger; + +import com.github.copilot.sdk.generated.SessionEvent; + +/** + * Thread-safe state for pending-routing mode used by + * {@link CopilotClient#createCloudSession}. + * + *

+ * While one or more cloud {@code session.create} calls are in flight (guard + * count {@code > 0}), notifications and inbound RPC requests addressed to + * session ids that are not yet registered are buffered here rather than + * dropped. Once {@link CopilotClient#createCloudSession} receives the + * runtime-assigned session id, it calls {@link #registerAndFlush} to atomically + * insert the session into the sessions map, drain any buffered events into it, + * and complete any parked request waiters. + * + *

+ * All mutation methods synchronize on {@code this}. The sessions-map put inside + * {@link #registerAndFlush} is performed while holding the lock so that the + * {@link #tryBufferNotification} / {@link #tryParkRequest} check-then-act is + * free of TOCTOU races. + */ +final class PendingRoutingState { + + static final int BUFFER_LIMIT = 128; + + private static final Logger LOG = Logger.getLogger(PendingRoutingState.class.getName()); + + private int guardCount = 0; + /** Buffered session.event notifications keyed by session id. */ + private final Map> pendingEvents = new HashMap<>(); + /** + * Parked CompletableFutures for inbound RPC requests waiting for a session to + * be registered. + */ + private final Map>> pendingWaiters = new HashMap<>(); + + /** Increment the guard count. Must be matched by {@link #decrementGuard}. */ + synchronized void incrementGuard() { + guardCount++; + } + + /** + * Decrement the guard count. If the count reaches zero, clears all buffered + * events and completes all parked request waiters exceptionally with a + * canonical message that is distinct from the overflow-eviction path. + */ + synchronized void decrementGuard() { + guardCount = Math.max(0, guardCount - 1); + if (guardCount != 0) { + return; + } + pendingEvents.clear(); + var stale = new ArrayList>(); + for (var list : pendingWaiters.values()) { + stale.addAll(list); + } + pendingWaiters.clear(); + if (!stale.isEmpty()) { + // Use a distinct phrasing from the overflow-eviction path so that + // debugging can tell the two failure modes apart. Matches the Rust + // SDK message (PR #1394 commit e0ff254f) and the TS SDK (commit + // c167bc3e). 
+ LOG.warning("Pending session routing ended before session was registered; " + "completing " + stale.size() + + " parked request waiter(s) exceptionally"); + var ex = new RuntimeException("pending session routing ended before session was registered"); + for (var waiter : stale) { + waiter.completeExceptionally(ex); + } + } + } + + /** + * Attempt to buffer a {@code session.event} notification for a pending session. + * + *

+ * The {@code sessions} map is checked inside this synchronized method so that + * the "session not found → buffer" decision is atomic with + * {@link #registerAndFlush}'s "put in map → flush buffer" operation. + * + * @param sessionId + * the session id from the notification + * @param event + * the parsed event to buffer + * @param sessions + * the live sessions map (checked under lock) + * @return {@code true} if the event was buffered; {@code false} if the session + * is already registered (caller should dispatch directly) or pending + * routing is inactive (caller should drop) + */ + synchronized boolean tryBufferNotification(String sessionId, SessionEvent event, + Map sessions) { + if (sessions.containsKey(sessionId)) { + return false; // session found; caller dispatches directly + } + if (guardCount == 0) { + return false; // no pending routing; drop + } + var queue = pendingEvents.computeIfAbsent(sessionId, k -> new ArrayDeque<>()); + if (queue.size() >= BUFFER_LIMIT) { + queue.pollFirst(); + LOG.warning("Pending session notification buffer full for session " + sessionId + "; dropping oldest"); + } + queue.addLast(event); + return true; + } + + /** + * Attempt to park an inbound RPC request until the session is registered. + * + *

+ * Like {@link #tryBufferNotification}, the {@code sessions} map is checked + * under the lock to avoid TOCTOU races with {@link #registerAndFlush}. + * + * @param sessionId + * the session id from the request params + * @param sessions + * the live sessions map (checked under lock) + * @return a future that will be resolved with the {@link CopilotSession} when + * registered (or completed exceptionally when the guard is dropped), or + * {@code null} if the session is already registered (callers should use + * it directly) or if pending routing is inactive (caller should send + * error) + */ + synchronized CompletableFuture tryParkRequest(String sessionId, + Map sessions) { + CopilotSession existing = sessions.get(sessionId); + if (existing != null) { + return CompletableFuture.completedFuture(existing); + } + if (guardCount == 0) { + return null; // no pending; caller sends error + } + var future = new CompletableFuture(); + var list = pendingWaiters.computeIfAbsent(sessionId, k -> new ArrayList<>()); + if (list.size() >= BUFFER_LIMIT) { + // Cap parked waiters per session. When exceeded, evict the oldest + // and complete it with a distinct overflow message so the runtime + // gets an error response rather than hanging on a reply that will + // never arrive. Matches Rust PR #1394 (commit 491b4427) and TS + // (commit c167bc3e). + var oldest = list.remove(0); + LOG.warning("Pending session request waiter buffer full for session " + sessionId + " (limit=" + + BUFFER_LIMIT + "); evicting oldest request"); + oldest.completeExceptionally(new RuntimeException("pending session buffer overflow")); + } + list.add(future); + return future; + } + + /** + * Result of {@link #registerAndFlush}: buffered events to dispatch and parked + * waiters to complete. + */ + record FlushResult(List events, List> waiters) { + } + + /** + * Atomically register a session in the sessions map and drain any buffered + * events and parked waiters for that session. + * + *
<p>
+ * The {@code sessions.put} is performed inside the lock so that concurrent + * {@link #tryBufferNotification} / {@link #tryParkRequest} callers that haven't + * yet acquired the lock will see the session as registered. + * + * @param sessionId + * the session id to register + * @param session + * the session object + * @param sessions + * the live sessions map; the put is performed under lock + * @return buffered events and parked waiters to dispatch/complete outside the + * lock + */ + synchronized FlushResult registerAndFlush(String sessionId, CopilotSession session, + Map sessions) { + sessions.put(sessionId, session); + + var queue = pendingEvents.remove(sessionId); + var events = queue != null ? new ArrayList<>(queue) : Collections.emptyList(); + + var waiters = pendingWaiters.remove(sessionId); + var futures = waiters != null + ? new ArrayList<>(waiters) + : Collections.>emptyList(); + + return new FlushResult(events, futures); + } +} diff --git a/java/src/main/java/com/github/copilot/sdk/RpcHandlerDispatcher.java b/java/src/main/java/com/github/copilot/sdk/RpcHandlerDispatcher.java index 1d76d8b88..1b164aa28 100644 --- a/java/src/main/java/com/github/copilot/sdk/RpcHandlerDispatcher.java +++ b/java/src/main/java/com/github/copilot/sdk/RpcHandlerDispatcher.java @@ -9,8 +9,11 @@ import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.logging.Level; import java.util.logging.Logger; @@ -49,6 +52,10 @@ final class RpcHandlerDispatcher { private final Map sessions; private final LifecycleEventDispatcher lifecycleDispatcher; private final Executor executor; + private final PendingRoutingState pendingState; + + /** Timeout in seconds to wait for a pending cloud session to 
register. */ + private static final int PENDING_SESSION_TIMEOUT_SECONDS = 60; /** * Creates a dispatcher with session registry and lifecycle dispatcher. @@ -59,12 +66,16 @@ final class RpcHandlerDispatcher { * callback for dispatching lifecycle events * @param executor * the executor for async dispatch, or {@code null} for default + * @param pendingState + * the pending-routing state for cloud sessions, or {@code null} to + * disable pending routing */ RpcHandlerDispatcher(Map sessions, LifecycleEventDispatcher lifecycleDispatcher, - Executor executor) { + Executor executor, PendingRoutingState pendingState) { this.sessions = sessions; this.lifecycleDispatcher = lifecycleDispatcher; this.executor = executor; + this.pendingState = pendingState; } /** @@ -96,13 +107,28 @@ private void handleSessionEvent(JsonNode params) { JsonNode eventNode = params.get("event"); LOG.fine("Received session.event: " + eventNode); + if (eventNode == null) { + return; + } + SessionEvent event = MAPPER.treeToValue(eventNode, SessionEvent.class); + if (event == null) { + return; + } + + // Fast path: session already registered — dispatch directly. CopilotSession session = sessions.get(sessionId); - if (session != null && eventNode != null) { - SessionEvent event = MAPPER.treeToValue(eventNode, SessionEvent.class); - if (event != null) { - session.dispatchEvent(event); - } + if (session != null) { + session.dispatchEvent(event); + return; + } + + // Slow path: session not found. Attempt to buffer for a pending cloud + // session.create. The tryBufferNotification check is inside the pending + // state's lock so it's atomic with registerAndFlush's sessions.put. 
+ if (pendingState != null && pendingState.tryBufferNotification(sessionId, event, sessions)) { + return; // buffered; will be replayed when session is registered } + // session not registered and no pending routing active; silently drop } catch (Exception e) { LOG.log(Level.SEVERE, "Error handling session event", e); } @@ -141,9 +167,8 @@ private void handleToolCall(JsonRpcClient rpc, String requestId, JsonNode params String toolName = params.get("toolName").asText(); JsonNode arguments = params.get("arguments"); - CopilotSession session = sessions.get(sessionId); + CopilotSession session = resolveSessionForRequest(sessionId, requestIdLong, rpc); if (session == null) { - rpc.sendErrorResponse(requestIdLong, -32602, "Unknown session " + sessionId); return; } @@ -203,7 +228,20 @@ private void handlePermissionRequest(JsonRpcClient rpc, String requestId, JsonNo String sessionId = params.get("sessionId").asText(); JsonNode permissionRequest = params.get("permissionRequest"); + // Try to resolve the session; for a pending cloud session, park until + // registration. If not found and no pending routing, fall back to the + // protocol-correct DENIED_COULD_NOT_REQUEST_FROM_USER response. 
CopilotSession session = sessions.get(sessionId); + if (session == null && pendingState != null) { + CompletableFuture waiter = pendingState.tryParkRequest(sessionId, sessions); + if (waiter != null) { + try { + session = waiter.get(PENDING_SESSION_TIMEOUT_SECONDS, TimeUnit.SECONDS); + } catch (Exception e) { + session = null; + } + } + } if (session == null) { var result = new PermissionRequestResult() .setKind(PermissionRequestResultKind.DENIED_COULD_NOT_REQUEST_FROM_USER); @@ -254,11 +292,10 @@ private void handleUserInputRequest(JsonRpcClient rpc, String requestId, JsonNod JsonNode choicesNode = params.get("choices"); JsonNode allowFreeformNode = params.get("allowFreeform"); - CopilotSession session = sessions.get(sessionId); + CopilotSession session = resolveSessionForRequest(sessionId, requestIdLong, rpc); LOG.fine("Found session: " + (session != null)); if (session == null) { LOG.fine("Session not found, sending error"); - rpc.sendErrorResponse(requestIdLong, -32602, "Unknown session " + sessionId); return; } @@ -309,9 +346,8 @@ private void handleExitPlanModeRequest(JsonRpcClient rpc, String requestId, Json try { String sessionId = params.get("sessionId").asText(); - CopilotSession session = sessions.get(sessionId); + CopilotSession session = resolveSessionForRequest(sessionId, requestIdLong, rpc); if (session == null) { - rpc.sendErrorResponse(requestIdLong, -32602, "Unknown session " + sessionId); return; } @@ -363,9 +399,8 @@ private void handleAutoModeSwitchRequest(JsonRpcClient rpc, String requestId, Js try { String sessionId = params.get("sessionId").asText(); - CopilotSession session = sessions.get(sessionId); + CopilotSession session = resolveSessionForRequest(sessionId, requestIdLong, rpc); if (session == null) { - rpc.sendErrorResponse(requestIdLong, -32602, "Unknown session " + sessionId); return; } @@ -409,9 +444,8 @@ private void handleHooksInvoke(JsonRpcClient rpc, String requestId, JsonNode par String hookType = 
params.get("hookType").asText(); JsonNode input = params.get("input"); - CopilotSession session = sessions.get(sessionId); + CopilotSession session = resolveSessionForRequest(sessionId, requestIdLong, rpc); if (session == null) { - rpc.sendErrorResponse(requestIdLong, -32602, "Unknown session " + sessionId); return; } @@ -454,9 +488,8 @@ private void handleSystemMessageTransform(JsonRpcClient rpc, String requestId, J String sessionId = params.has("sessionId") ? params.get("sessionId").asText() : null; JsonNode sections = params.get("sections"); - CopilotSession session = sessionId != null ? sessions.get(sessionId) : null; + CopilotSession session = resolveSessionForRequest(sessionId, requestIdLong, rpc); if (session == null) { - rpc.sendErrorResponse(requestIdLong, -32602, "Unknown session " + sessionId); return; } @@ -499,6 +532,105 @@ private static long parseRequestId(String requestId, String methodName) { } } + /** + * Resolve a session for an inbound RPC request, optionally parking until a + * pending cloud session is registered. + * + *
<p>
+ * If the session is already registered, returns it immediately. If pending + * routing is active (a cloud {@code session.create} is in flight), parks this + * call until the session is registered or the guard expires. On failure, sends + * an error response and returns {@code null}. + * + * @param sessionId + * the session id extracted from the request params + * @param requestId + * the JSON-RPC request id (for error responses) + * @param rpc + * the JSON-RPC client + * @return the resolved session, or {@code null} if an error was sent + */ + private CopilotSession resolveSessionForRequest(String sessionId, long requestId, JsonRpcClient rpc) { + if (sessionId == null) { + try { + rpc.sendErrorResponse(requestId, -32602, "Missing sessionId in request"); + } catch (IOException e) { + LOG.log(Level.SEVERE, "Failed to send missing-sessionId error", e); + } + return null; + } + + // Fast path: session already registered + CopilotSession session = sessions.get(sessionId); + if (session != null) { + return session; + } + + if (pendingState != null) { + CompletableFuture waiter = pendingState.tryParkRequest(sessionId, sessions); + if (waiter != null) { + if (waiter.isDone()) { + // Session was already registered under tryParkRequest's lock — + // complete synchronously. If the waiter was completed exceptionally + // (e.g. overflow eviction just beat us to isDone()), send the + // same -32603 error as the blocking path so the runtime isn't + // left waiting for a reply that will never arrive. + try { + return waiter.get(); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + try { + rpc.sendErrorResponse(requestId, -32603, "Session " + sessionId + " not registered: " + + (cause != null ? 
cause.getMessage() : e.getMessage())); + } catch (IOException ioe) { + LOG.log(Level.SEVERE, "Failed to send synchronous registration-failed error", ioe); + } + return null; + } catch (Exception e) { + // fall through to send error + } + } else { + try { + return waiter.get(PENDING_SESSION_TIMEOUT_SECONDS, TimeUnit.SECONDS); + } catch (TimeoutException e) { + try { + rpc.sendErrorResponse(requestId, -32603, "Session " + sessionId + " not registered within " + + PENDING_SESSION_TIMEOUT_SECONDS + "s"); + } catch (IOException ioe) { + LOG.log(Level.SEVERE, "Failed to send timeout error", ioe); + } + return null; + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + try { + rpc.sendErrorResponse(requestId, -32603, "Session " + sessionId + " not registered: " + + (cause != null ? cause.getMessage() : e.getMessage())); + } catch (IOException ioe) { + LOG.log(Level.SEVERE, "Failed to send registration-failed error", ioe); + } + return null; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + try { + rpc.sendErrorResponse(requestId, -32603, "Interrupted waiting for session " + sessionId); + } catch (IOException ioe) { + LOG.log(Level.SEVERE, "Failed to send interrupted error", ioe); + } + return null; + } + } + } + } + + // No pending routing active or pending routing returned null; send error + try { + rpc.sendErrorResponse(requestId, -32602, "Unknown session " + sessionId); + } catch (IOException e) { + LOG.log(Level.SEVERE, "Failed to send unknown-session error", e); + } + return null; + } + private void runAsync(Runnable task) { try { if (executor != null) { diff --git a/java/src/main/java/com/github/copilot/sdk/SessionRequestBuilder.java b/java/src/main/java/com/github/copilot/sdk/SessionRequestBuilder.java index 0cdc4f942..fd0e4b97a 100644 --- a/java/src/main/java/com/github/copilot/sdk/SessionRequestBuilder.java +++ b/java/src/main/java/com/github/copilot/sdk/SessionRequestBuilder.java @@ -173,6 +173,82 @@ static 
CreateSessionRequest buildCreateRequest(SessionConfig config) { return buildCreateRequest(config, sessionId); } + /** + * Builds a CreateSessionRequest for cloud session creation. Unlike + * {@link #buildCreateRequest(SessionConfig, String)}, this method omits + * {@code sessionId} from the request — the runtime assigns the session id for + * cloud sessions. + * + * @param config + * the session configuration (may be null) + * @return the built request object with {@code sessionId} left null + */ + static CreateSessionRequest buildCloudCreateRequest(SessionConfig config) { + var request = new CreateSessionRequest(); + // Always request permission callbacks to enable deny-by-default behavior + request.setRequestPermission(true); + // Always send envValueMode=direct for MCP servers + request.setEnvValueMode("direct"); + // sessionId intentionally omitted: the runtime assigns the id for cloud + // sessions + if (config == null) { + return request; + } + + request.setModel(config.getModel()); + request.setClientName(config.getClientName()); + request.setReasoningEffort(config.getReasoningEffort()); + request.setTools(config.getTools()); + request.setSystemMessage(config.getSystemMessage()); + request.setAvailableTools(config.getAvailableTools()); + request.setExcludedTools(config.getExcludedTools()); + // provider intentionally omitted: cloud sessions use the runtime's provider + config.getEnableSessionTelemetry().ifPresent(request::setEnableSessionTelemetry); + if (config.getOnUserInputRequest() != null) { + request.setRequestUserInput(true); + } + if (config.getHooks() != null && config.getHooks().hasHooks()) { + request.setHooks(true); + } + request.setWorkingDirectory(config.getWorkingDirectory()); + if (config.isStreaming()) { + request.setStreaming(true); + } + config.getIncludeSubAgentStreamingEvents().ifPresent(request::setIncludeSubAgentStreamingEvents); + request.setMcpServers(config.getMcpServers()); + request.setCustomAgents(config.getCustomAgents()); + 
request.setDefaultAgent(config.getDefaultAgent()); + request.setAgent(config.getAgent()); + request.setInfiniteSessions(config.getInfiniteSessions()); + request.setSkillDirectories(config.getSkillDirectories()); + request.setInstructionDirectories(config.getInstructionDirectories()); + request.setDisabledSkills(config.getDisabledSkills()); + request.setConfigDir(config.getConfigDir()); + config.getEnableConfigDiscovery().ifPresent(request::setEnableConfigDiscovery); + request.setModelCapabilities(config.getModelCapabilities()); + + if (config.getCommands() != null && !config.getCommands().isEmpty()) { + var wireCommands = config.getCommands().stream() + .map(c -> new CommandWireDefinition(c.getName(), c.getDescription())) + .collect(java.util.stream.Collectors.toList()); + request.setCommands(wireCommands); + } + if (config.getOnElicitationRequest() != null) { + request.setRequestElicitation(true); + } + if (config.getOnExitPlanMode() != null) { + request.setRequestExitPlanMode(true); + } + if (config.getOnAutoModeSwitch() != null) { + request.setRequestAutoModeSwitch(true); + } + request.setGitHubToken(config.getGitHubToken()); + request.setRemoteSession(config.getRemoteSession()); + request.setCloud(config.getCloud()); + + return request; + } + /** * Builds a ResumeSessionRequest from the given session ID and configuration. 
* diff --git a/java/src/main/java/com/github/copilot/sdk/json/CreateSessionResponse.java b/java/src/main/java/com/github/copilot/sdk/json/CreateSessionResponse.java index b47af050b..bf4d47dc9 100644 --- a/java/src/main/java/com/github/copilot/sdk/json/CreateSessionResponse.java +++ b/java/src/main/java/com/github/copilot/sdk/json/CreateSessionResponse.java @@ -11,12 +11,15 @@ * @param workspacePath * the workspace path, or {@code null} if infinite sessions are * disabled + * @param remoteUrl + * the remote URL for a cloud session, or {@code null} for local + * sessions * @param capabilities * the capabilities reported by the host, or {@code null} * @since 1.0.0 */ @JsonInclude(JsonInclude.Include.NON_NULL) public record CreateSessionResponse(@JsonProperty("sessionId") String sessionId, - @JsonProperty("workspacePath") String workspacePath, + @JsonProperty("workspacePath") String workspacePath, @JsonProperty("remoteUrl") String remoteUrl, @JsonProperty("capabilities") SessionCapabilities capabilities) { } diff --git a/java/src/test/java/com/github/copilot/sdk/CloudSessionTest.java b/java/src/test/java/com/github/copilot/sdk/CloudSessionTest.java new file mode 100644 index 000000000..2177655a5 --- /dev/null +++ b/java/src/test/java/com/github/copilot/sdk/CloudSessionTest.java @@ -0,0 +1,515 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.sdk; + +import static org.junit.jupiter.api.Assertions.*; + +import java.io.InputStream; +import java.lang.reflect.Field; +import java.net.ServerSocket; +import java.net.Socket; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.copilot.sdk.json.CloudSessionOptions; +import com.github.copilot.sdk.json.CloudSessionRepository; +import com.github.copilot.sdk.json.CreateSessionRequest; +import com.github.copilot.sdk.json.PermissionHandler; +import com.github.copilot.sdk.json.ProviderConfig; +import com.github.copilot.sdk.json.SessionConfig; +import com.github.copilot.sdk.json.SessionLifecycleEvent; +import com.github.copilot.sdk.json.ToolDefinition; +import com.github.copilot.sdk.json.ToolResultObject; + +/** + * Tests for {@link CopilotClient#createCloudSession} and + * {@link CopilotClient#createSession} cloud-config rejection. + * + *
<p>
+ * Covers:
+ * <ol>
+ * <li>{@code createSession} rejects cloud config</li>
+ * <li>Wire payload omits {@code sessionId} and includes {@code cloud}</li>
+ * <li>{@code createCloudSession} rejects caller-provided {@code sessionId}</li>
+ * <li>{@code createCloudSession} rejects caller-provided {@code provider}</li>
+ * <li>{@code createCloudSession} requires {@code cloud} to be set</li>
+ * <li>Early {@code session.event} notifications are buffered and replayed</li>
+ * <li>Inbound RPC requests are parked until the session is registered</li>
+ * </ol>
+ */ +class CloudSessionTest { + + private static final ObjectMapper MAPPER = JsonRpcClient.getObjectMapper(); + private static final int SOCKET_TIMEOUT_MS = 5000; + + // Socket-pair fields used by routing-related tests + private Socket clientSideSocket; + private Socket serverSideSocket; + private JsonRpcClient rpc; + private Map sessions; + private CopyOnWriteArrayList lifecycleEvents; + private PendingRoutingState pendingState; + private RpcHandlerDispatcher dispatcher; + private InputStream responseStream; + private Map> handlers; + + @BeforeEach + void setup() throws Exception { + try (ServerSocket ss = new ServerSocket(0)) { + clientSideSocket = new Socket("localhost", ss.getLocalPort()); + serverSideSocket = ss.accept(); + } + serverSideSocket.setSoTimeout(SOCKET_TIMEOUT_MS); + + rpc = JsonRpcClient.fromSocket(clientSideSocket); + responseStream = serverSideSocket.getInputStream(); + + sessions = new ConcurrentHashMap<>(); + lifecycleEvents = new CopyOnWriteArrayList<>(); + pendingState = new PendingRoutingState(); + + dispatcher = new RpcHandlerDispatcher(sessions, lifecycleEvents::add, null, pendingState); + dispatcher.registerHandlers(rpc); + + // Extract registered handlers via reflection (same pattern as + // RpcHandlerDispatcherTest) + Field f = JsonRpcClient.class.getDeclaredField("notificationHandlers"); + f.setAccessible(true); + @SuppressWarnings("unchecked") + Map> h = (Map>) f.get(rpc); + handlers = h; + } + + @AfterEach + void teardown() throws Exception { + if (rpc != null) { + rpc.close(); + } + if (serverSideSocket != null) { + serverSideSocket.close(); + } + if (clientSideSocket != null) { + clientSideSocket.close(); + } + } + + private void invokeHandler(String method, String requestId, JsonNode params) { + handlers.get(method).accept(requestId, params); + } + + private JsonNode readResponse() throws Exception { + StringBuilder header = new StringBuilder(); + while (!header.toString().endsWith("\r\n\r\n")) { + int b = responseStream.read(); 
+ if (b == -1) { + throw new java.io.IOException("Unexpected end of stream"); + } + header.append((char) b); + } + String headerStr = header.toString().trim(); + int idx = headerStr.indexOf(':'); + int contentLength = Integer.parseInt(headerStr.substring(idx + 1).trim()); + byte[] body = responseStream.readNBytes(contentLength); + return MAPPER.readTree(body); + } + + // ========================================================================= + // Test 1: createSession rejects cloud config + // ========================================================================= + + @Test + void createSession_rejectsCloudConfig() { + var client = new CopilotClient(); + var config = new SessionConfig().setCloud(new CloudSessionOptions()) + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL); + + var future = client.createSession(config); + + var ex = assertThrows(ExecutionException.class, future::get, "createSession should fail with cloud config set"); + assertInstanceOf(IllegalArgumentException.class, ex.getCause(), "Cause should be IllegalArgumentException"); + assertTrue(ex.getCause().getMessage().contains("cloud"), "Error message should mention 'cloud'"); + + try { + client.forceStop().get(5, TimeUnit.SECONDS); + } catch (Exception ignored) { + } + } + + // ========================================================================= + // Test 2: wire payload omits sessionId and includes cloud + // ========================================================================= + + @Test + void buildCloudCreateRequest_omitsSessionIdAndIncludesCloud() throws Exception { + var cloud = new CloudSessionOptions() + .setRepository(new CloudSessionRepository().setOwner("github").setName("copilot-sdk")); + var config = new SessionConfig().setCloud(cloud).setModel("gpt-5") + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL); + + CreateSessionRequest request = SessionRequestBuilder.buildCloudCreateRequest(config); + + // Java-level assertions + assertNull(request.getSessionId(), 
"sessionId must be null on the cloud create request"); + assertNotNull(request.getCloud(), "cloud must be set on the cloud create request"); + assertEquals("gpt-5", request.getModel()); + + // Serialize to JSON and verify wire shape + String json = MAPPER.writeValueAsString(request); + JsonNode tree = MAPPER.readTree(json); + + assertFalse(tree.has("sessionId"), "sessionId must be absent from serialized JSON (NON_NULL omits it)"); + assertTrue(tree.has("cloud"), "cloud must be present in serialized JSON"); + assertTrue(tree.get("cloud").has("repository"), "cloud.repository must be present"); + assertEquals("github", tree.get("cloud").get("repository").get("owner").asText()); + } + + @Test + void buildCloudCreateRequest_sessionIdOmittedEvenWhenModelIsNull() throws Exception { + // Minimal config: only cloud set, no model + var config = new SessionConfig().setCloud(new CloudSessionOptions()); + + CreateSessionRequest request = SessionRequestBuilder.buildCloudCreateRequest(config); + + assertNull(request.getSessionId()); + String json = MAPPER.writeValueAsString(request); + JsonNode tree = MAPPER.readTree(json); + assertFalse(tree.has("sessionId"), "sessionId must never appear in cloud create wire payload"); + assertTrue(tree.has("cloud")); + } + + // ========================================================================= + // Test 3: createCloudSession rejects caller-provided sessionId + // ========================================================================= + + @Test + void createCloudSession_rejectsCallerSessionId() { + var client = new CopilotClient(); + var config = new SessionConfig().setCloud(new CloudSessionOptions()).setSessionId("my-caller-session") + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL); + + var future = client.createCloudSession(config); + + var ex = assertThrows(ExecutionException.class, future::get, + "createCloudSession should fail when sessionId is set"); + assertInstanceOf(IllegalArgumentException.class, ex.getCause()); + 
assertTrue(ex.getCause().getMessage().contains("sessionId"), "Error message should mention 'sessionId'"); + + try { + client.forceStop().get(5, TimeUnit.SECONDS); + } catch (Exception ignored) { + } + } + + // ========================================================================= + // Test 4: createCloudSession rejects caller-provided provider + // ========================================================================= + + @Test + void createCloudSession_rejectsCallerProvider() { + var client = new CopilotClient(); + var config = new SessionConfig().setCloud(new CloudSessionOptions()) + .setProvider(new ProviderConfig().setType("openai")) + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL); + + var future = client.createCloudSession(config); + + var ex = assertThrows(ExecutionException.class, future::get, + "createCloudSession should fail when provider is set"); + assertInstanceOf(IllegalArgumentException.class, ex.getCause()); + assertTrue(ex.getCause().getMessage().contains("provider"), "Error message should mention 'provider'"); + + try { + client.forceStop().get(5, TimeUnit.SECONDS); + } catch (Exception ignored) { + } + } + + // ========================================================================= + // Test 5: createCloudSession requires cloud to be set + // ========================================================================= + + @Test + void createCloudSession_requiresCloud() { + var client = new CopilotClient(); + var config = new SessionConfig().setModel("gpt-5").setOnPermissionRequest(PermissionHandler.APPROVE_ALL); + // cloud is NOT set + + var future = client.createCloudSession(config); + + var ex = assertThrows(ExecutionException.class, future::get, + "createCloudSession should fail when cloud is not set"); + assertInstanceOf(IllegalArgumentException.class, ex.getCause()); + assertTrue(ex.getCause().getMessage().contains("cloud"), "Error message should mention 'cloud'"); + + try { + client.forceStop().get(5, TimeUnit.SECONDS); + } 
catch (Exception ignored) { + } + } + + // ========================================================================= + // Test 6: early session.event notifications are buffered and replayed + // ========================================================================= + + @Test + void bufferEarlySessionEventNotifications() throws Exception { + // Enter pending routing mode (simulates createCloudSession in-flight) + pendingState.incrementGuard(); + + String pendingSessionId = "cloud-session-abc"; + + // Dispatch a session.event while no session is registered yet. + ObjectNode params = MAPPER.createObjectNode(); + params.put("sessionId", pendingSessionId); + ObjectNode event = params.putObject("event"); + event.put("type", "sessionStart"); + ObjectNode data = event.putObject("data"); + data.put("sessionId", pendingSessionId); + + invokeHandler("session.event", null, params); + + // Give the (synchronous) handler a moment — no session registered yet, so the + // event should be buffered, not dispatched. + Thread.sleep(50); + + // Create the session object and register it via registerAndFlush, which + // atomically inserts the session into the map and drains the buffer. 
+ var session = new CopilotSession(pendingSessionId, rpc); + var dispatched = new CopyOnWriteArrayList<>(); + session.on(dispatched::add); + + var flush = pendingState.registerAndFlush(pendingSessionId, session, sessions); + + // Replay buffered events into the session (simulates what createCloudSession + // does) + for (var buffered : flush.events()) { + session.dispatchEvent(buffered); + } + + // Complete parked waiters (none in this test) + for (var waiter : flush.waiters()) { + waiter.complete(session); + } + + // Release the guard + pendingState.decrementGuard(); + + // The buffered session.event should now have been replayed to the session + Thread.sleep(50); + assertEquals(1, dispatched.size(), "Buffered notification should have been replayed to the session"); + } + + @Test + void bufferRespectsSizeLimit() throws Exception { + pendingState.incrementGuard(); + String sid = "cloud-overflow-test"; + + // Send more than BUFFER_LIMIT events + int overLimit = PendingRoutingState.BUFFER_LIMIT + 10; + for (int i = 0; i < overLimit; i++) { + ObjectNode params = MAPPER.createObjectNode(); + params.put("sessionId", sid); + ObjectNode event = params.putObject("event"); + event.put("type", "assistantMessage"); + event.putObject("data").put("content", "msg-" + i); + invokeHandler("session.event", null, params); + } + + var session = new CopilotSession(sid, rpc); + var flush = pendingState.registerAndFlush(sid, session, sessions); + + // Should have been capped at BUFFER_LIMIT; oldest entries were dropped + assertEquals(PendingRoutingState.BUFFER_LIMIT, flush.events().size(), + "Buffer should be capped at BUFFER_LIMIT"); + + pendingState.decrementGuard(); + } + + // ========================================================================= + // Test 7: inbound RPC requests are parked until the session is registered + // ========================================================================= + + @Test + void parksInboundRequestsUntilRegistration() throws Exception { + 
String pendingSessionId = "cloud-session-xyz"; + + // Register a tool on the (not-yet-created) session by pre-creating it without + // registering in the sessions map yet. We'll use the pending state directly. + pendingState.incrementGuard(); + + // In a background thread, send a tool.call request for the pending session. + // The handler should park until the session is registered. + var toolCallFuture = CompletableFuture.runAsync(() -> { + ObjectNode params = MAPPER.createObjectNode(); + params.put("sessionId", pendingSessionId); + params.put("toolCallId", "tc-1"); + params.put("toolName", "say_hello"); + params.set("arguments", MAPPER.createObjectNode()); + invokeHandler("tool.call", "42", params); + }); + + // Brief pause to allow the handler thread to start and park + Thread.sleep(100); + + // Create and register the session with the requested tool + var session = new CopilotSession(pendingSessionId, rpc); + session.registerTools(java.util.List.of( + ToolDefinition.create("say_hello", "Greets the user", Map.of("type", "object", "properties", Map.of()), + inv -> CompletableFuture.completedFuture(ToolResultObject.success("hello!"))))); + + var flush = pendingState.registerAndFlush(pendingSessionId, session, sessions); + + // No buffered notifications in this test + assertTrue(flush.events().isEmpty(), "No buffered events expected"); + + // Complete any parked request waiters + for (var waiter : flush.waiters()) { + waiter.complete(session); + } + + pendingState.decrementGuard(); + + // Wait for the handler to finish (it was parked on the waiter) + toolCallFuture.get(5, TimeUnit.SECONDS); + + // The tool.call handler should have executed and sent a response back on the + // wire + JsonNode response = readResponse(); + assertNotNull(response, "Should have received a tool response"); + assertEquals(42, response.get("id").asInt(), "Response id should match request id"); + assertNotNull(response.get("result"), "Tool call should produce a result"); + } + + @Test + 
void parkedRequestFailsExceptionallyWhenGuardDroppedWithoutRegistration() throws Exception { + String pendingSessionId = "cloud-session-dropped"; + + pendingState.incrementGuard(); + + // Park a request in the background + var toolCallFuture = CompletableFuture.runAsync(() -> { + ObjectNode params = MAPPER.createObjectNode(); + params.put("sessionId", pendingSessionId); + params.put("toolCallId", "tc-2"); + params.put("toolName", "any_tool"); + params.set("arguments", MAPPER.createObjectNode()); + invokeHandler("tool.call", "99", params); + }); + + Thread.sleep(100); + + // Drop the guard WITHOUT registering the session. decrementGuard now + // completes parked waiters internally with the canonical message. + pendingState.decrementGuard(); + + // The handler should receive the exceptional completion and send an + // error response + toolCallFuture.get(5, TimeUnit.SECONDS); + + JsonNode response = readResponse(); + assertNotNull(response, "Should have received an error response"); + assertEquals(99, response.get("id").asInt()); + assertNotNull(response.get("error"), "Response should be an error (not a result)"); + String errorMessage = response.get("error").get("message").asText(); + assertTrue(errorMessage.contains("routing ended before session was registered"), + "Error message should contain the canonical guard-drop phrase; got: " + errorMessage); + } + + // ========================================================================= + // Test 8: overflow path — oldest parked waiter gets the overflow message + // ========================================================================= + + @Test + void parkedRequestWaiterBuffer_overflow_evictsOldestWithOverflowMessage() throws Exception { + pendingState.incrementGuard(); + String sid = "cloud-overflow-requests"; + + // Park BUFFER_LIMIT + 1 waiters via tryParkRequest. The 129th call must + // evict the very first waiter and complete it with the overflow message. 
+ var waiters = new java.util.ArrayList<CompletableFuture<CopilotSession>>(); + for (int i = 0; i < PendingRoutingState.BUFFER_LIMIT + 1; i++) { + waiters.add(pendingState.tryParkRequest(sid, sessions)); + } + + // The first waiter (oldest) must have been evicted with the overflow message. + CompletableFuture<CopilotSession> oldest = waiters.get(0); + assertTrue(oldest.isCompletedExceptionally(), "Oldest waiter should be completed exceptionally on overflow"); + ExecutionException ex = assertThrows(ExecutionException.class, oldest::get); + assertEquals("pending session buffer overflow", ex.getCause().getMessage()); + + // The remaining BUFFER_LIMIT waiters should still be pending. + for (int i = 1; i <= PendingRoutingState.BUFFER_LIMIT; i++) { + assertFalse(waiters.get(i).isDone(), "Waiter " + i + " should still be pending after overflow eviction"); + } + + // Registering the session resolves the remaining 128 waiters normally. + var session = new CopilotSession(sid, rpc); + var flush = pendingState.registerAndFlush(sid, session, sessions); + assertEquals(PendingRoutingState.BUFFER_LIMIT, flush.waiters().size(), + "registerAndFlush should return all non-evicted waiters"); + for (var waiter : flush.waiters()) { + waiter.complete(session); + } + for (int i = 1; i <= PendingRoutingState.BUFFER_LIMIT; i++) { + assertFalse(waiters.get(i).isCompletedExceptionally(), + "Waiter " + i + " should complete normally, not exceptionally"); + assertEquals(session, waiters.get(i).get(1, TimeUnit.SECONDS)); + } + + pendingState.decrementGuard(); + } + + // ========================================================================= + // Test 9: guard-drop message is distinct from overflow message + // ========================================================================= + + @Test + void parkedRequestWaiter_guardDropMessage_isDistinctFromOverflowMessage() throws Exception { + String pendingSessionId = "cloud-session-distinct-msg"; + + pendingState.incrementGuard(); + + // Park a request in the background via the full handler path 
so the + // response travels over the socket — this mirrors the real runtime flow. + var toolCallFuture = CompletableFuture.runAsync(() -> { + ObjectNode params = MAPPER.createObjectNode(); + params.put("sessionId", pendingSessionId); + params.put("toolCallId", "tc-distinct"); + params.put("toolName", "noop"); + params.set("arguments", MAPPER.createObjectNode()); + invokeHandler("tool.call", "77", params); + }); + + Thread.sleep(100); + + // Drop the guard without registration. decrementGuard completes waiters + // internally with the canonical guard-drop message. + pendingState.decrementGuard(); + + toolCallFuture.get(5, TimeUnit.SECONDS); + + JsonNode response = readResponse(); + assertEquals(77, response.get("id").asInt()); + assertNotNull(response.get("error"), "Should be an error response"); + String msg = response.get("error").get("message").asText(); + + // Must contain the guard-drop phrase — NOT the overflow phrase. + assertTrue(msg.contains("routing ended before session was registered"), + "Guard-drop error must use the routing-ended phrase; got: " + msg); + assertFalse(msg.contains("buffer overflow"), "Guard-drop error must NOT use the overflow phrase; got: " + msg); + } +} diff --git a/java/src/test/java/com/github/copilot/sdk/RpcHandlerDispatcherTest.java b/java/src/test/java/com/github/copilot/sdk/RpcHandlerDispatcherTest.java index 7453a7b26..5c5bcbfd6 100644 --- a/java/src/test/java/com/github/copilot/sdk/RpcHandlerDispatcherTest.java +++ b/java/src/test/java/com/github/copilot/sdk/RpcHandlerDispatcherTest.java @@ -66,7 +66,7 @@ void setup() throws Exception { sessions = new ConcurrentHashMap<>(); lifecycleEvents = new CopyOnWriteArrayList<>(); - dispatcher = new RpcHandlerDispatcher(sessions, lifecycleEvents::add, null); + dispatcher = new RpcHandlerDispatcher(sessions, lifecycleEvents::add, null, null); dispatcher.registerHandlers(rpc); // Extract the registered handlers via reflection so we can invoke them directly diff --git a/nodejs/README.md 
b/nodejs/README.md index 1cb6e7836..daa723392 100644 --- a/nodejs/README.md +++ b/nodejs/README.md @@ -127,6 +127,24 @@ Create a new conversation session. Resume an existing session. Returns the session with `workspacePath` populated if infinite sessions were enabled. +##### `createCloudSession(config: SessionConfig): Promise<CopilotSession>` + +Create a Mission Control–backed cloud session. The runtime owns the session ID for cloud sessions: do **not** set `sessionId` or `provider` on the config (the SDK rejects both). The SDK omits `sessionId` from the `session.create` request and registers the resulting session under the id the runtime returns. + +Any `session.event` notifications or inbound JSON-RPC requests that arrive between sending `session.create` and receiving its response are buffered (bounded, drop-oldest) and replayed once the returned id is registered, so early events aren't lost. + +`config.cloud` is required. Passing `cloud` to `createSession` instead throws — use this method. + +```typescript +const session = await client.createCloudSession({ + onPermissionRequest: approveAll, + cloud: { + repository: { owner: "github", name: "copilot-sdk", branch: "main" }, + }, +}); +console.log("cloud session id:", session.sessionId); +``` + ##### `ping(message?: string): Promise<{ message: string; timestamp: string }>` Ping the server to check connectivity. diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 21563d598..66f365084 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -249,6 +249,30 @@ export class CopilotClient { private actualHost: string = "localhost"; private state: "disconnected" | "connecting" | "connected" | "error" = "disconnected"; private sessions: Map<string, CopilotSession> = new Map(); + /** + * Refcount of in-flight cloud `session.create` calls. While >0, the + * notification/request handlers buffer messages addressed to session + * ids that are not yet registered, so events the runtime emits between + * `session.create` and its response are not lost. 
+ */ + private pendingRoutingCount: number = 0; + /** Buffered `session.event` payloads keyed by runtime-assigned sessionId. */ + private pendingSessionEvents: Map<string, SessionEvent[]> = new Map(); + /** + * Waiters parked by request handlers when the addressed session id is + * not yet registered but pending routing is active. Resolved when the + * session is registered; rejected when pending mode ends without + * registration so the JSON-RPC request surfaces a clear error. + */ + private pendingSessionWaiters: Map< + string, + Array<{ resolve: (s: CopilotSession) => void; reject: (e: Error) => void }> + > = new Map(); + /** + * Upper bound on buffered notifications per pending session id. Cloud + * handshakes are short; drop-oldest above this is acceptable. + */ + private static readonly PENDING_SESSION_BUFFER_LIMIT = 128; private stderrBuffer: string = ""; // Captures CLI stderr for error messages /** Resolved connection mode chosen in the constructor. */ private connectionConfig: InternalRuntimeConnection; @@ -788,6 +812,11 @@ export class CopilotClient { * ``` */ async createSession(config: SessionConfig): Promise<CopilotSession> { + if (config.cloud) { + throw new Error( + "CopilotClient.createSession does not support cloud sessions; use createCloudSession instead." + ); + } if (!this.connection) { await this.start(); } @@ -899,6 +928,249 @@ export class CopilotClient { return session; } + /** + * Creates a Mission Control–backed cloud session. + * + * The runtime owns the session ID for cloud sessions: do **not** set + * `sessionId` or `provider` on the config (the SDK rejects both). The + * SDK omits `sessionId` from the `session.create` request and registers + * the resulting session under the id the runtime returns. + * + * Any `session.event` notifications or inbound JSON-RPC requests that + * arrive between sending `session.create` and receiving its response are + * buffered (bounded, drop-oldest) and replayed once the returned id is + * registered, so early events aren't lost. 
+ * + * **Known limitation:** inbound `sessionFs.*` requests (the generated + * client-session API handlers) are not pending-buffered today. In practice + * the runtime does not initiate `sessionFs.*` calls before the + * `session.create` response, so this is theoretical; if needed, the + * generated `registerClientSessionApiHandlers` shim can be updated to + * support async session resolution. + * + * @example + * ```typescript + * const session = await client.createCloudSession({ + * onPermissionRequest: approveAll, + * cloud: { repository: { owner: "github", name: "copilot-sdk", branch: "main" } }, + * }); + * console.log(session.sessionId); + * ``` + */ + async createCloudSession(config: SessionConfig): Promise<CopilotSession> { + if (!config.cloud) { + throw new Error("CopilotClient.createCloudSession requires config.cloud to be set."); + } + if (config.sessionId !== undefined) { + throw new Error( + "CopilotClient.createCloudSession does not support a caller-provided sessionId; the runtime assigns one." + ); + } + if (config.provider !== undefined) { + throw new Error( + "CopilotClient.createCloudSession does not support config.provider; cloud sessions use the runtime's provider." + ); + } + if (!this.connection) { + await this.start(); + } + + const { wirePayload: wireSystemMessage, transformCallbacks } = extractTransformCallbacks( + config.systemMessage + ); + + const guard = this.beginPendingSessionRouting(); + + let response: { + sessionId?: unknown; + workspacePath?: string; + remoteUrl?: string; + capabilities?: { ui?: { elicitation?: boolean } }; + }; + try { + response = (await this.connection!.sendRequest("session.create", { + ...(await getTraceContext(this.onGetTraceContext)), + model: config.model, + // sessionId intentionally omitted: the runtime assigns the id for cloud sessions. 
+ clientName: config.clientName, + reasoningEffort: config.reasoningEffort, + tools: config.tools?.map((tool) => ({ + name: tool.name, + description: tool.description, + parameters: toJsonSchema(tool.parameters), + overridesBuiltInTool: tool.overridesBuiltInTool, + skipPermission: tool.skipPermission, + })), + commands: config.commands?.map((cmd) => ({ + name: cmd.name, + description: cmd.description, + })), + systemMessage: wireSystemMessage, + availableTools: config.availableTools, + excludedTools: config.excludedTools, + enableSessionTelemetry: config.enableSessionTelemetry, + modelCapabilities: config.modelCapabilities, + requestPermission: !!config.onPermissionRequest, + requestUserInput: !!config.onUserInputRequest, + requestElicitation: !!config.onElicitationRequest, + requestExitPlanMode: !!config.onExitPlanModeRequest, + requestAutoModeSwitch: !!config.onAutoModeSwitchRequest, + hooks: !!(config.hooks && Object.values(config.hooks).some(Boolean)), + workingDirectory: config.workingDirectory, + streaming: config.streaming, + includeSubAgentStreamingEvents: config.includeSubAgentStreamingEvents ?? 
true, + mcpServers: toWireMcpServers(config.mcpServers), + envValueMode: "direct", + customAgents: toWireCustomAgents(config.customAgents), + defaultAgent: config.defaultAgent, + agent: config.agent, + configDir: config.configDir, + enableConfigDiscovery: config.enableConfigDiscovery, + skillDirectories: config.skillDirectories, + instructionDirectories: config.instructionDirectories, + disabledSkills: config.disabledSkills, + infiniteSessions: config.infiniteSessions, + gitHubToken: config.gitHubToken, + remoteSession: config.remoteSession, + cloud: config.cloud, + })) as { + sessionId?: unknown; + workspacePath?: string; + remoteUrl?: string; + capabilities?: { ui?: { elicitation?: boolean } }; + }; + } catch (e) { + guard.dispose(); + throw e; + } + + if (!response || typeof response !== "object" || typeof response.sessionId !== "string") { + // No id to issue session.destroy against; release the guard and surface the error. + // Any runtime session created on the other side may leak. + + console.warn( + "Cloud session.create response missing sessionId; runtime session may leak." + ); + guard.dispose(); + throw new Error( + "Cloud session.create response did not include a sessionId; cannot register session." 
+ ); + } + + const sessionId = response.sessionId; + const session = new CopilotSession( + sessionId, + this.connection!, + undefined, + this.onGetTraceContext + ); + session.registerTools(config.tools); + session.registerCommands(config.commands); + session.registerPermissionHandler(config.onPermissionRequest); + if (config.onUserInputRequest) { + session.registerUserInputHandler(config.onUserInputRequest); + } + if (config.onElicitationRequest) { + session.registerElicitationHandler(config.onElicitationRequest); + } + if (config.onExitPlanModeRequest) { + session.registerExitPlanModeHandler(config.onExitPlanModeRequest); + } + if (config.onAutoModeSwitchRequest) { + session.registerAutoModeSwitchHandler(config.onAutoModeSwitchRequest); + } + if (config.hooks) { + session.registerHooks(config.hooks); + } + if (transformCallbacks) { + session.registerTransformCallbacks(transformCallbacks); + } + if (config.onEvent) { + session.on(config.onEvent); + } + try { + this.sessions.set(sessionId, session); + this.setupSessionFs(session, config); + session["_workspacePath"] = response.workspacePath; + session["_remoteUrl"] = response.remoteUrl; + session.setCapabilities(response.capabilities); + + // Drain anything that arrived during the in-flight session.create + // into the freshly-registered session before releasing the guard. + this.flushPendingForSession(sessionId, session); + } catch (e) { + // Roll back partial registration so a failed post-response setup + // (e.g. setupSessionFs throwing because sessionFs is misconfigured) + // doesn't leave a half-wired session in this.sessions. + this.sessions.delete(sessionId); + guard.dispose(); + throw e; + } + guard.dispose(); + + return session; + } + + /** + * Enter pending-routing mode. 
While the returned guard is undisposed, + * notifications and inbound requests addressed to session ids that are + * not yet registered are buffered (up to + * {@link CopilotClient.PENDING_SESSION_BUFFER_LIMIT} per id, drop-oldest) + * and replayed on registration. When the last guard is disposed, any + * still-buffered messages are dropped and parked request waiters are + * rejected so the calling JSON-RPC requests don't hang forever. + */ + private beginPendingSessionRouting(): { dispose: () => void } { + this.pendingRoutingCount++; + let disposed = false; + return { + dispose: () => { + if (disposed) return; + disposed = true; + this.pendingRoutingCount--; + if (this.pendingRoutingCount === 0) { + this.pendingSessionEvents.clear(); + for (const waiters of this.pendingSessionWaiters.values()) { + for (const w of waiters) { + // Distinct phrasing from the overflow-eviction path so the + // runtime / future debugging can tell the two cases apart. + // Matches the Rust SDK message in PR #1394 (commit e0ff254f). + w.reject( + new Error( + "pending session routing ended before session was registered" + ) + ); + } + } + this.pendingSessionWaiters.clear(); + } + }, + }; + } + + /** + * Drain buffered events and pending request waiters for `sessionId` into + * the freshly-registered session. Called from {@link createCloudSession} + * after the session is in `this.sessions` and before the pending guard + * is released. + */ + private flushPendingForSession(sessionId: string, session: CopilotSession): void { + const events = this.pendingSessionEvents.get(sessionId); + if (events) { + this.pendingSessionEvents.delete(sessionId); + for (const event of events) { + session._dispatchEvent(event); + } + } + const waiters = this.pendingSessionWaiters.get(sessionId); + if (waiters) { + this.pendingSessionWaiters.delete(sessionId); + for (const w of waiters) { + w.resolve(session); + } + } + } + /** * Resumes an existing conversation session by its ID. 
* @@ -1908,9 +2180,27 @@ export class CopilotClient { return; } - const session = this.sessions.get((notification as { sessionId: string }).sessionId); + const sessionId = (notification as { sessionId: string }).sessionId; + const event = (notification as { event: SessionEvent }).event; + const session = this.sessions.get(sessionId); if (session) { - session._dispatchEvent((notification as { event: SessionEvent }).event); + session._dispatchEvent(event); + return; + } + if (this.pendingRoutingCount > 0) { + let buf = this.pendingSessionEvents.get(sessionId); + if (!buf) { + buf = []; + this.pendingSessionEvents.set(sessionId, buf); + } + if (buf.length >= CopilotClient.PENDING_SESSION_BUFFER_LIMIT) { + buf.shift(); + + console.warn( + `pending session notification buffer full for ${sessionId}; dropping oldest` + ); + } + buf.push(event); } } @@ -1969,6 +2259,42 @@ export class CopilotClient { } } + /** + * Look up the session for an inbound request. If the session is not yet + * registered but a cloud `session.create` is in flight (pending routing + * mode is active), park the request until the session is registered or + * pending mode ends. Otherwise throw immediately. + */ + private resolveSession(sessionId: string): Promise<CopilotSession> { + const session = this.sessions.get(sessionId); + if (session) { + return Promise.resolve(session); + } + if (this.pendingRoutingCount > 0) { + return new Promise((resolve, reject) => { + let waiters = this.pendingSessionWaiters.get(sessionId); + if (!waiters) { + waiters = []; + this.pendingSessionWaiters.set(sessionId, waiters); + } + // Cap parked waiters per session id. When exceeded, reject the + // oldest with a distinct message; vscode-jsonrpc surfaces the + // rejection as a JSON-RPC error response to the runtime (rather + // than silently dropping, which would hang the runtime on the + // request id). Matches the Rust SDK fix in PR #1394 (commit + // 491b4427). 
+ if (waiters.length >= CopilotClient.PENDING_SESSION_BUFFER_LIMIT) { + const oldest = waiters.shift(); + if (oldest) { + oldest.reject(new Error("pending session buffer overflow")); + } + } + waiters.push({ resolve, reject }); + }); + } + return Promise.reject(new Error(`Session not found: ${sessionId}`)); + } + private async handleUserInputRequest(params: { sessionId: string; question: string; @@ -1983,10 +2309,7 @@ export class CopilotClient { throw new Error("Invalid user input request payload"); } - const session = this.sessions.get(params.sessionId); - if (!session) { - throw new Error(`Session not found: ${params.sessionId}`); - } + const session = await this.resolveSession(params.sessionId); const result = await session._handleUserInputRequest({ question: params.question, @@ -2009,10 +2332,7 @@ export class CopilotClient { throw new Error("Invalid exit plan mode request payload"); } - const session = this.sessions.get(params.sessionId); - if (!session) { - throw new Error(`Session not found: ${params.sessionId}`); - } + const session = await this.resolveSession(params.sessionId); return await session._handleExitPlanModeRequest({ summary: params.summary, @@ -2029,10 +2349,7 @@ export class CopilotClient { throw new Error("Invalid auto mode switch request payload"); } - const session = this.sessions.get(params.sessionId); - if (!session) { - throw new Error(`Session not found: ${params.sessionId}`); - } + const session = await this.resolveSession(params.sessionId); const response = await session._handleAutoModeSwitchRequest({ errorCode: params.errorCode, @@ -2054,10 +2371,7 @@ export class CopilotClient { throw new Error("Invalid hooks invoke payload"); } - const session = this.sessions.get(params.sessionId); - if (!session) { - throw new Error(`Session not found: ${params.sessionId}`); - } + const session = await this.resolveSession(params.sessionId); const output = await session._handleHooksInvoke(params.hookType, params.input); return { output }; @@ 
-2076,10 +2390,7 @@ export class CopilotClient { throw new Error("Invalid systemMessage.transform payload"); } - const session = this.sessions.get(params.sessionId); - if (!session) { - throw new Error(`Session not found: ${params.sessionId}`); - } + const session = await this.resolveSession(params.sessionId); return await session._handleSystemMessageTransform(params.sections); } diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index 6f2a002b1..6484c8593 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -152,6 +152,20 @@ export class CopilotSession { return this._workspacePath; } + /** + * Remote URL for the Mission Control–backed cloud session. + * + * Populated from the `remoteUrl` field in the `session.create` response + * for cloud sessions created via {@link CopilotClient.createCloudSession}. + * Undefined for regular local sessions. + */ + get remoteUrl(): string | undefined { + return this._remoteUrl; + } + + /** @internal Populated by CopilotClient after session.create returns. */ + _remoteUrl?: string; + /** * Host capabilities reported when the session was created or resumed. * Use this to check feature support before calling capability-gated APIs. 
diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index b9a34c214..e2cfe598a 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -31,7 +31,20 @@ describe("CopilotClient", () => { ); }); - it("forwards cloud options in session.create request", async () => { + it("createSession rejects cloud config", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + await expect( + client.createSession({ + onPermissionRequest: approveAll, + cloud: { repository: { owner: "github", name: "copilot-sdk" } }, + }) + ).rejects.toThrow(/createCloudSession/); + }); + + it("createCloudSession sends session.create with cloud and without sessionId", async () => { const client = new CopilotClient(); await client.start(); onTestFinished(() => client.forceStop()); @@ -39,20 +52,254 @@ describe("CopilotClient", () => { const spy = vi .spyOn((client as any).connection!, "sendRequest") .mockResolvedValue({ sessionId: "cloud-session" }); - await client.createSession({ + const session = await client.createCloudSession({ onPermissionRequest: approveAll, cloud: { repository: { owner: "github", name: "copilot-sdk", branch: "main" }, }, }); - expect(spy).toHaveBeenCalledWith( - "session.create", - expect.objectContaining({ - cloud: { - repository: { owner: "github", name: "copilot-sdk", branch: "main" }, - }, - }) + expect(session.sessionId).toBe("cloud-session"); + const call = spy.mock.calls.find(([m]) => m === "session.create"); + expect(call).toBeDefined(); + const payload = call![1] as Record; + expect(payload.cloud).toEqual({ + repository: { owner: "github", name: "copilot-sdk", branch: "main" }, + }); + // sessionId must be omitted: the runtime assigns it on the cloud path. 
+ expect("sessionId" in payload).toBe(false); + }); + + it("createCloudSession rejects caller-provided sessionId", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + await expect( + client.createCloudSession({ + onPermissionRequest: approveAll, + sessionId: "caller-id", + cloud: { repository: { owner: "github", name: "copilot-sdk" } }, + } as never) + ).rejects.toThrow(/sessionId/); + }); + + it("createCloudSession rejects caller-provided provider", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + await expect( + client.createCloudSession({ + onPermissionRequest: approveAll, + provider: { baseUrl: "https://example.com", apiKey: "k" } as never, + cloud: { repository: { owner: "github", name: "copilot-sdk" } }, + } as never) + ).rejects.toThrow(/provider/); + }); + + it("createCloudSession requires cloud option", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + await expect( + client.createCloudSession({ onPermissionRequest: approveAll } as never) + ).rejects.toThrow(/cloud/); + }); + + it("createCloudSession buffers early session.event notifications until registration", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + // Capture the assigned event handler so we can fire one before the + // sendRequest promise resolves. 
+ let dispatchEvent: + | ((event: { type: string; data: Record }) => void) + | undefined; + let resolveCreate: ((value: unknown) => void) | undefined; + vi.spyOn((client as any).connection!, "sendRequest").mockImplementation( + (method: unknown) => { + if (method !== "session.create") { + return Promise.resolve({}); + } + return new Promise((resolve) => { + resolveCreate = resolve; + }); + } + ); + + const events: Array<{ type: string }> = []; + const sessionPromise = client.createCloudSession({ + onPermissionRequest: approveAll, + cloud: { repository: { owner: "github", name: "copilot-sdk" } }, + onEvent: (e) => { + events.push(e as { type: string }); + dispatchEvent = undefined; + }, + }); + + // Yield so createCloudSession reaches the in-flight sendRequest. + await new Promise((r) => setImmediate(r)); + + // Simulate the runtime pushing an early session.event addressed to + // the not-yet-registered cloud session id. + ( + client as unknown as { + handleSessionEventNotification: (n: unknown) => void; + } + ).handleSessionEventNotification({ + sessionId: "cloud-session", + event: { type: "user.message", data: { text: "hi" } }, + }); + + // Events should be buffered, not yet delivered. + expect(events).toEqual([]); + + // Now resolve the create response. 
+ resolveCreate!({ sessionId: "cloud-session" }); + await sessionPromise; + + expect(events).toEqual([{ type: "user.message", data: { text: "hi" } }]); + // suppress unused-variable lint + void dispatchEvent; + }); + + it("createCloudSession parks inbound requests until registration", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + let resolveCreate: ((value: unknown) => void) | undefined; + vi.spyOn((client as any).connection!, "sendRequest").mockImplementation( + (method: unknown) => { + if (method !== "session.create") { + return Promise.resolve({}); + } + return new Promise((resolve) => { + resolveCreate = resolve; + }); + } + ); + + const sessionPromise = client.createCloudSession({ + onPermissionRequest: approveAll, + cloud: { repository: { owner: "github", name: "copilot-sdk" } }, + onUserInputRequest: async () => ({ answer: "yes", wasFreeform: false }), + }); + + // Yield so createCloudSession reaches the in-flight sendRequest. + await new Promise((r) => setImmediate(r)); + + // Simulate an inbound userInput.request for the not-yet-registered + // session id arriving during the in-flight session.create. + const parked = ( + client as unknown as { + handleUserInputRequest: (p: { + sessionId: string; + question: string; + }) => Promise<{ answer: string; wasFreeform: boolean }>; + } + ).handleUserInputRequest({ + sessionId: "cloud-session", + question: "ok?", + }); + + // Resolve the create response so the session gets registered. 
+ resolveCreate!({ sessionId: "cloud-session" }); + await sessionPromise; + + await expect(parked).resolves.toEqual({ answer: "yes", wasFreeform: false }); + }); + + it("rejects parked requests with overflow error when buffer cap exceeded", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + let resolveCreate: ((value: unknown) => void) | undefined; + vi.spyOn((client as any).connection!, "sendRequest").mockImplementation( + (method: unknown) => { + if (method !== "session.create") return Promise.resolve({}); + return new Promise((resolve) => { + resolveCreate = resolve; + }); + } + ); + + const sessionPromise = client.createCloudSession({ + onPermissionRequest: approveAll, + cloud: { repository: { owner: "github", name: "copilot-sdk" } }, + onUserInputRequest: async () => ({ answer: "ok", wasFreeform: false }), + }); + await new Promise((r) => setImmediate(r)); + + const handler = ( + client as unknown as { + handleUserInputRequest: (p: { + sessionId: string; + question: string; + }) => Promise<{ answer: string; wasFreeform: boolean }>; + } + ).handleUserInputRequest.bind(client); + + // Park 128 + 1 requests; the oldest must reject with overflow message. + // vscode-jsonrpc translates the rejection into a JSON-RPC error response + // back to the runtime so the request id isn't left hanging. + const parked: Promise[] = []; + for (let i = 0; i < 129; i++) { + parked.push(handler({ sessionId: "cloud-session", question: `q${i}` })); + } + + await expect(parked[0]).rejects.toThrow(/pending session buffer overflow/); + + resolveCreate!({ sessionId: "cloud-session" }); + await sessionPromise; + + // The remaining 128 parked requests resolve after registration. 
+ for (let i = 1; i < 129; i++) { + await expect(parked[i]).resolves.toEqual({ answer: "ok", wasFreeform: false }); + } + }); + + it("rejects parked requests when pending routing ends without registration", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + let rejectCreate: ((err: Error) => void) | undefined; + vi.spyOn((client as any).connection!, "sendRequest").mockImplementation( + (method: unknown) => { + if (method !== "session.create") return Promise.resolve({}); + return new Promise((_resolve, reject) => { + rejectCreate = reject; + }); + } + ); + + const sessionPromise = client.createCloudSession({ + onPermissionRequest: approveAll, + cloud: { repository: { owner: "github", name: "copilot-sdk" } }, + }); + await new Promise((r) => setImmediate(r)); + + const parked = ( + client as unknown as { + handleUserInputRequest: (p: { + sessionId: string; + question: string; + }) => Promise<{ answer: string; wasFreeform: boolean }>; + } + ).handleUserInputRequest({ sessionId: "cloud-session", question: "ok?" }); + + // session.create fails before registration; guard drop must reject all + // parked waiters with a distinct message (separate from overflow). + rejectCreate!(new Error("create failed")); + await expect(sessionPromise).rejects.toThrow(); + + await expect(parked).rejects.toThrow( + /pending session routing ended before session was registered/ ); }); diff --git a/python/README.md b/python/README.md index 3a504f966..44e409bfe 100644 --- a/python/README.md +++ b/python/README.md @@ -414,6 +414,40 @@ When `streaming=True`: Note: `assistant.message` and `assistant.reasoning` (final events) are always sent regardless of streaming setting. +## Cloud Sessions + +Use `create_cloud_session()` to create a Mission Control–backed cloud session. The +runtime owns the session ID for cloud sessions, so omit `session_id` and `provider` +(the SDK raises `ValueError` if either is set). 
+ +```python +from copilot import CopilotClient, RuntimeConnection +from copilot.client import CloudSessionOptions, CloudSessionRepository + +client = CopilotClient(connection=RuntimeConnection.for_stdio(path="/path/to/cli")) +await client.start() + +session = await client.create_cloud_session( + cloud=CloudSessionOptions( + repository=CloudSessionRepository( + owner="my-org", + name="my-repo", + branch="main", + ) + ), + on_event=lambda event: print(event), +) +print(session.remote_url) # URL of the remote cloud session +``` + +`create_cloud_session()` accepts the same keyword arguments as `create_session()` +(tools, streaming, model, hooks, etc.) with two restrictions: +- `session_id` must not be set (the runtime assigns the ID). +- `provider` must not be set (cloud sessions always use the Mission Control provider). + +Early `session.event` notifications and inbound RPC requests that arrive before the +session is fully registered are buffered and replayed once the session is ready. + ## Infinite Sessions By default, sessions use **infinite sessions** which automatically manage context window limits through background compaction and persist state to a workspace directory. diff --git a/python/copilot/client.py b/python/copilot/client.py index a52b8711f..354ff4d65 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -996,6 +996,55 @@ def _extract_transform_callbacks( return wire_payload, callbacks +_PENDING_SESSION_BUFFER_LIMIT = 128 +"""Upper bound on buffered notifications/requests per pending session id. + +Holds traffic that arrives between ``session.create`` being sent and the +SDK learning the runtime-assigned session id from the response (cloud path). +Drop-oldest behaviour is acceptable: cloud handshakes are short and 128 +entries is well above realistic init/replay bursts. +""" + + +class _PendingSessionRoutingGuard: + """Explicit-dispose guard that keeps pending-routing mode active for a cloud session.create. 
+ + While alive, notifications and inbound requests addressed to session ids + that are not yet registered are buffered instead of dropped, so events + the runtime emits between ``session.create`` and its response are not + lost. Dispose exactly once — either after successful registration (to + replay the buffered messages) or on any error path (to reject parked + request futures so callers don't hang). + """ + + __slots__ = ("_client", "_disposed") + + def __init__(self, client: CopilotClient) -> None: + self._client = client + self._disposed = False + + def dispose(self) -> None: + if self._disposed: + return + self._disposed = True + waiters_to_reject: list[asyncio.Future] = [] + with self._client._sessions_lock: + self._client._pending_routing_count -= 1 + if self._client._pending_routing_count == 0: + self._client._pending_session_events.clear() + for session_waiters in self._client._pending_session_waiters.values(): + waiters_to_reject.extend(session_waiters) + self._client._pending_session_waiters.clear() + for future in waiters_to_reject: + if not future.done(): + # Distinct phrasing from the overflow-eviction path so debugging + # can tell the two cases apart. Matches Rust SDK (commit e0ff254f) + # and TS SDK (commit c167bc3e). + future.set_exception( + ValueError("pending session routing ended before session was registered") + ) + + class CopilotClient: """ Main client for interacting with the Copilot CLI. @@ -1181,6 +1230,10 @@ def __init__( self._state: _ConnectionState = "disconnected" self._sessions: dict[str, CopilotSession] = {} self._sessions_lock = threading.Lock() + # Pending-routing state for create_cloud_session: guarded by _sessions_lock. 
+ self._pending_routing_count: int = 0 + self._pending_session_events: dict[str, list[SessionEvent]] = {} + self._pending_session_waiters: dict[str, list[asyncio.Future[CopilotSession]]] = {} self._models_cache: list[ModelInfo] | None = None self._models_cache_lock = asyncio.Lock() self._lifecycle_handlers: list[SessionLifecycleHandler] = [] @@ -1629,6 +1682,11 @@ async def create_session( ... streaming=True, ... ) """ + if cloud is not None: + raise ValueError( + "CopilotClient.create_session does not support cloud sessions; " + "use create_cloud_session instead." + ) if on_permission_request is not None and not callable(on_permission_request): raise ValueError("on_permission_request must be callable when provided.") if not self._client: @@ -1880,6 +1938,402 @@ async def create_session( ) return session + async def create_cloud_session( + self, + *, + cloud: CloudSessionOptions | None = None, + session_id: str | None = None, + provider: ProviderConfig | None = None, + on_permission_request: _PermissionHandlerFn | None = None, + model: str | None = None, + client_name: str | None = None, + reasoning_effort: ReasoningEffort | None = None, + tools: list[Tool] | None = None, + system_message: SystemMessageConfig | None = None, + available_tools: list[str] | None = None, + excluded_tools: list[str] | None = None, + on_user_input_request: UserInputHandler | None = None, + hooks: SessionHooks | None = None, + working_directory: str | None = None, + enable_session_telemetry: bool | None = None, + model_capabilities: ModelCapabilitiesOverride | None = None, + streaming: bool | None = None, + include_sub_agent_streaming_events: bool | None = None, + mcp_servers: dict[str, MCPServerConfig] | None = None, + custom_agents: list[CustomAgentConfig] | None = None, + default_agent: DefaultAgentConfig | dict[str, Any] | None = None, + agent: str | None = None, + config_dir: str | None = None, + enable_config_discovery: bool | None = None, + skill_directories: list[str] | None = 
None, + instruction_directories: list[str] | None = None, + disabled_skills: list[str] | None = None, + infinite_sessions: InfiniteSessionConfig | None = None, + on_event: Callable[[SessionEvent], None] | None = None, + commands: list[CommandDefinition] | None = None, + on_elicitation_request: ElicitationHandler | None = None, + on_exit_plan_mode_request: ExitPlanModeHandler | None = None, + on_auto_mode_switch_request: AutoModeSwitchHandler | None = None, + create_session_fs_handler: CreateSessionFsHandler | None = None, + github_token: str | None = None, + remote_session: RemoteSessionMode | None = None, + ) -> CopilotSession: + """ + Create a Mission Control–backed cloud session. + + The runtime owns the session ID for cloud sessions: do **not** set + ``session_id`` or ``provider`` on the call (the SDK rejects both with + :class:`ValueError`). The SDK omits ``sessionId`` from the + ``session.create`` wire payload and registers the resulting session + under the id that the runtime returns. + + Any ``session.event`` notifications or inbound JSON-RPC requests + (``userInput.request``, ``exitPlanMode.request``, etc.) that arrive + between sending ``session.create`` and receiving its response are + buffered (bounded, drop-oldest, up to + ``_PENDING_SESSION_BUFFER_LIMIT`` per id) and replayed once the + returned session id is registered, so early events are not lost. + + **Known limitation:** inbound ``sessionFs.*`` requests (the generated + client-session API handlers) are not pending-buffered. In practice the + runtime does not initiate ``sessionFs.*`` calls before the + ``session.create`` response, so this is theoretical. + + Args: + cloud: Required. Cloud session options (repository, branch, etc.). + session_id: Must be ``None``; the runtime assigns the id. + provider: Must be ``None``; cloud sessions use the runtime's provider. + on_permission_request: Handler for permission requests. + model: Model to use. + client_name: Client name for identification. 
+ reasoning_effort: Reasoning effort level. + tools: Custom tools to register. + system_message: System message configuration. + available_tools: Allowlist of tools. + excluded_tools: Tools to disable. + on_user_input_request: Handler for user input requests. + hooks: Lifecycle hooks. + working_directory: Working directory. + enable_session_telemetry: Enable/disable session telemetry. + model_capabilities: Model capabilities override. + streaming: Enable streaming responses. + include_sub_agent_streaming_events: Include sub-agent streaming events. + mcp_servers: MCP server configurations. + custom_agents: Custom agent configurations. + default_agent: Default agent configuration. + agent: Agent to use. + config_dir: Configuration directory override. + enable_config_discovery: Auto-discover MCP/skill config from cwd. + skill_directories: Directories to search for skills. + instruction_directories: Additional instruction file directories. + disabled_skills: Skills to disable. + infinite_sessions: Infinite session configuration. + on_event: Callback for session events. + commands: Commands to register. + on_elicitation_request: Handler for elicitation requests. + on_exit_plan_mode_request: Handler for exit-plan-mode requests. + on_auto_mode_switch_request: Handler for auto-mode-switch requests. + create_session_fs_handler: Session filesystem handler factory. + github_token: Per-session GitHub token. + remote_session: Remote session mode. + + Returns: + A :class:`CopilotSession` for the cloud session, with its + ``session_id`` set to the runtime-assigned id. + + Raises: + ValueError: If ``cloud`` is ``None``, ``session_id`` is set, or + ``provider`` is set. + + Example: + >>> session = await client.create_cloud_session( + ... cloud=CloudSessionOptions( + ... repository=CloudSessionRepository( + ... owner="github", name="copilot-sdk", branch="main" + ... ) + ... ), + ... 
) + >>> print(session.session_id) # runtime-assigned id + """ + if cloud is None: + raise ValueError( + "create_cloud_session requires cloud to be set; " + "use CloudSessionOptions to configure the repository." + ) + if session_id is not None: + raise ValueError( + "create_cloud_session does not accept session_id; the runtime assigns one." + ) + if provider is not None: + raise ValueError( + "create_cloud_session does not accept provider; " + "cloud sessions use the runtime's provider." + ) + + if not self._client: + await self.start() + + tool_defs = [] + if tools: + for tool in tools: + definition: dict[str, Any] = { + "name": tool.name, + "description": tool.description, + } + if tool.parameters: + definition["parameters"] = tool.parameters + if tool.overrides_built_in_tool: + definition["overridesBuiltInTool"] = True + if tool.skip_permission: + definition["skipPermission"] = True + tool_defs.append(definition) + + payload: dict[str, Any] = {} + if model: + payload["model"] = model + if client_name: + payload["clientName"] = client_name + if reasoning_effort: + payload["reasoningEffort"] = reasoning_effort + if tool_defs: + payload["tools"] = tool_defs + + wire_system_message, transform_callbacks = _extract_transform_callbacks(system_message) + if wire_system_message: + payload["systemMessage"] = wire_system_message + + if available_tools is not None: + payload["availableTools"] = available_tools + if excluded_tools is not None: + payload["excludedTools"] = excluded_tools + + payload["requestPermission"] = bool(on_permission_request) + if on_user_input_request: + payload["requestUserInput"] = True + payload["requestElicitation"] = bool(on_elicitation_request) + payload["requestExitPlanMode"] = bool(on_exit_plan_mode_request) + payload["requestAutoModeSwitch"] = bool(on_auto_mode_switch_request) + + if commands: + payload["commands"] = [ + {"name": cmd.name, "description": cmd.description} for cmd in commands + ] + if hooks and any(hooks.values()): + 
payload["hooks"] = True + if github_token is not None: + payload["gitHubToken"] = github_token + if remote_session is not None: + payload["remoteSession"] = remote_session.value + + # sessionId intentionally omitted: the runtime assigns the id for cloud sessions. + payload["cloud"] = _cloud_session_options_to_dict(cloud) + + if working_directory: + payload["workingDirectory"] = working_directory + if streaming is not None: + payload["streaming"] = streaming + payload["includeSubAgentStreamingEvents"] = ( + include_sub_agent_streaming_events + if include_sub_agent_streaming_events is not None + else True + ) + if enable_session_telemetry is not None: + payload["enableSessionTelemetry"] = enable_session_telemetry + if model_capabilities: + payload["modelCapabilities"] = _capabilities_to_dict(model_capabilities) + if mcp_servers: + payload["mcpServers"] = _mcp_servers_to_wire(mcp_servers) + payload["envValueMode"] = "direct" + if custom_agents: + payload["customAgents"] = [ + self._convert_custom_agent_to_wire_format(a) for a in custom_agents + ] + if default_agent: + payload["defaultAgent"] = self._convert_default_agent_to_wire_format(default_agent) + if agent: + payload["agent"] = agent + if config_dir: + payload["configDir"] = config_dir + if enable_config_discovery is not None: + payload["enableConfigDiscovery"] = enable_config_discovery + if skill_directories: + payload["skillDirectories"] = skill_directories + if instruction_directories is not None: + payload["instructionDirectories"] = instruction_directories + if disabled_skills: + payload["disabledSkills"] = disabled_skills + if infinite_sessions: + wire_config: dict[str, Any] = {} + if "enabled" in infinite_sessions: + wire_config["enabled"] = infinite_sessions["enabled"] + if "background_compaction_threshold" in infinite_sessions: + wire_config["backgroundCompactionThreshold"] = infinite_sessions[ + "background_compaction_threshold" + ] + if "buffer_exhaustion_threshold" in infinite_sessions: + 
wire_config["bufferExhaustionThreshold"] = infinite_sessions[ + "buffer_exhaustion_threshold" + ] + payload["infiniteSessions"] = wire_config + + if not self._client: + raise RuntimeError("Client not connected") + + trace_ctx = get_trace_context() + payload.update(trace_ctx) + + total_start = time.perf_counter() + guard = self._begin_pending_session_routing() + + try: + rpc_start = time.perf_counter() + response = await self._client.request("session.create", payload) + log_timing( + logger, + logging.DEBUG, + "CopilotClient.create_cloud_session session creation request completed", + rpc_start, + ) + except BaseException: + guard.dispose() + raise + + returned_session_id = response.get("sessionId") + if not isinstance(returned_session_id, str) or not returned_session_id: + logger.warning( + "Cloud session.create response missing sessionId; runtime session may leak" + ) + guard.dispose() + raise ValueError( + "Cloud session.create response did not include a sessionId; " + "cannot register session." 
+ ) + + session = CopilotSession(returned_session_id, self._client, workspace_path=None) + session._register_tools(tools) + session._register_commands(commands) + session._register_permission_handler(on_permission_request) + if on_user_input_request: + session._register_user_input_handler(on_user_input_request) + if on_elicitation_request: + session._register_elicitation_handler(on_elicitation_request) + if on_exit_plan_mode_request: + session._register_exit_plan_mode_handler(on_exit_plan_mode_request) + if on_auto_mode_switch_request: + session._register_auto_mode_switch_handler(on_auto_mode_switch_request) + if hooks: + session._register_hooks(hooks) + if transform_callbacks: + session._register_transform_callbacks(transform_callbacks) + if on_event: + session.on(on_event) + + try: + if self._session_fs_config: + if create_session_fs_handler is None: + raise ValueError( + "create_session_fs_handler is required in session config when " + "session_fs is enabled in client options." + ) + fs_provider: SessionFsProvider = create_session_fs_handler(session) + caps = self._session_fs_config.get("capabilities") + if caps and caps.get("sqlite"): + from .session_fs_provider import SessionFsSqliteProvider + + if not isinstance(fs_provider, SessionFsSqliteProvider): + raise ValueError( + "SessionFs capabilities declare SQLite support but the provider " + "does not implement SessionFsSqliteProvider" + ) + session._client_session_apis.session_fs = create_session_fs_adapter(fs_provider) + + with self._sessions_lock: + self._sessions[returned_session_id] = session + + session._workspace_path = response.get("workspacePath") + session._remote_url = response.get("remoteUrl") + capabilities = response.get("capabilities") + session._set_capabilities(capabilities) + self._flush_pending_for_session(returned_session_id, session) + except BaseException: + with self._sessions_lock: + self._sessions.pop(returned_session_id, None) + guard.dispose() + raise + + guard.dispose() + log_timing( 
+ logger, + logging.DEBUG, + "CopilotClient.create_cloud_session complete", + total_start, + session_id=returned_session_id, + ) + return session + + def _begin_pending_session_routing(self) -> _PendingSessionRoutingGuard: + """Enter pending-routing mode; return a guard that exits it on dispose(). + + While at least one guard is alive, ``session.event`` notifications and + inbound JSON-RPC requests addressed to session ids that are not yet + registered are buffered (bounded, drop-oldest) and replayed on + registration. When the last guard is disposed, any still-pending + messages are dropped and parked request futures are rejected so callers + don't hang. + """ + with self._sessions_lock: + self._pending_routing_count += 1 + return _PendingSessionRoutingGuard(self) + + def _flush_pending_for_session(self, session_id: str, session: CopilotSession) -> None: + """Drain buffered events and resolve parked request futures for ``session_id``. + + Called from :meth:`create_cloud_session` after the session has been + registered in ``_sessions`` and before the pending-routing guard is + released. + """ + events_to_dispatch: list[SessionEvent] = [] + waiters_to_resolve: list[asyncio.Future[CopilotSession]] = [] + with self._sessions_lock: + events_to_dispatch = self._pending_session_events.pop(session_id, []) + waiters_to_resolve = self._pending_session_waiters.pop(session_id, []) + for event in events_to_dispatch: + session._dispatch_event(event) + for future in waiters_to_resolve: + if not future.done(): + future.set_result(session) + + async def _resolve_session(self, session_id: str) -> CopilotSession: + """Look up the session for an inbound request. + + If the session is not yet registered but a cloud ``session.create`` is + in flight (pending-routing mode is active), park the caller on a + :class:`asyncio.Future` until the session is registered or pending mode + ends. Otherwise raise :class:`ValueError` immediately. 
+ """ + future: asyncio.Future[CopilotSession] | None = None + evicted: asyncio.Future[CopilotSession] | None = None + with self._sessions_lock: + session = self._sessions.get(session_id) + if session is None and self._pending_routing_count > 0: + loop = asyncio.get_running_loop() + future = loop.create_future() + waiters = self._pending_session_waiters.setdefault(session_id, []) + # Cap parked waiters at the same limit as notifications. When exceeded, + # reject the oldest so the runtime gets a JSON-RPC error response + # (code -32603) rather than hanging on the request id until timeout. + # Matches Rust SDK fix (commit 491b4427) and TS SDK (commit c167bc3e). + if len(waiters) >= _PENDING_SESSION_BUFFER_LIMIT: + evicted = waiters.pop(0) + waiters.append(future) + if evicted is not None and not evicted.done(): + evicted.set_exception(ValueError("pending session buffer overflow")) + if session is not None: + return session + if future is not None: + return await future + raise ValueError(f"unknown session {session_id}") + async def resume_session( self, session_id: str, @@ -2969,6 +3423,16 @@ def handle_notification(method: str, params: dict): event = session_event_from_dict(event_dict) with self._sessions_lock: session = self._sessions.get(session_id) + if session is None and self._pending_routing_count > 0: + buf = self._pending_session_events.setdefault(session_id, []) + if len(buf) >= _PENDING_SESSION_BUFFER_LIMIT: + buf.pop(0) + logger.warning( + "pending session event buffer full for %s; dropping oldest", + session_id, + ) + buf.append(event) + return if session: session._dispatch_event(event) elif method == "session.lifecycle": @@ -3087,7 +3551,18 @@ def handle_notification(method: str, params: dict): event_dict = params["event"] # Convert dict to SessionEvent object event = session_event_from_dict(event_dict) - session = self._sessions.get(session_id) + with self._sessions_lock: + session = self._sessions.get(session_id) + if session is None and 
self._pending_routing_count > 0: + buf = self._pending_session_events.setdefault(session_id, []) + if len(buf) >= _PENDING_SESSION_BUFFER_LIMIT: + buf.pop(0) + logger.warning( + "pending session event buffer full for %s; dropping oldest", + session_id, + ) + buf.append(event) + return if session: session._dispatch_event(event) elif method == "session.lifecycle": @@ -3153,11 +3628,7 @@ async def _handle_user_input_request(self, params: dict) -> dict: if not session_id or not question: raise ValueError("invalid user input request payload") - with self._sessions_lock: - session = self._sessions.get(session_id) - if not session: - raise ValueError(f"unknown session {session_id}") - + session = await self._resolve_session(session_id) result = await session._handle_user_input_request(params) return {"answer": result["answer"], "wasFreeform": result["wasFreeform"]} @@ -3173,11 +3644,7 @@ async def _handle_exit_plan_mode_request(self, params: dict) -> dict: if not isinstance(actions, list) or not isinstance(recommended_action, str): raise ValueError("invalid exit plan mode request payload") - with self._sessions_lock: - session = self._sessions.get(session_id) - if not session: - raise ValueError(f"unknown session {session_id}") - + session = await self._resolve_session(session_id) return dict(await session._handle_exit_plan_mode_request(params)) async def _handle_auto_mode_switch_request(self, params: dict) -> dict: @@ -3186,11 +3653,7 @@ async def _handle_auto_mode_switch_request(self, params: dict) -> dict: if not session_id: raise ValueError("invalid auto mode switch request payload") - with self._sessions_lock: - session = self._sessions.get(session_id) - if not session: - raise ValueError(f"unknown session {session_id}") - + session = await self._resolve_session(session_id) response = await session._handle_auto_mode_switch_request(params) return {"response": response} @@ -3214,11 +3677,7 @@ async def _handle_hooks_invoke(self, params: dict) -> dict: if not session_id 
or not hook_type: raise ValueError("invalid hooks invoke payload") - with self._sessions_lock: - session = self._sessions.get(session_id) - if not session: - raise ValueError(f"unknown session {session_id}") - + session = await self._resolve_session(session_id) output = await session._handle_hooks_invoke(hook_type, input_data) return {"output": output} @@ -3230,9 +3689,5 @@ async def _handle_system_message_transform(self, params: dict) -> dict: if not session_id or not sections: raise ValueError("invalid systemMessage.transform payload") - with self._sessions_lock: - session = self._sessions.get(session_id) - if not session: - raise ValueError(f"unknown session {session_id}") - + session = await self._resolve_session(session_id) return await session._handle_system_message_transform(sections) diff --git a/python/copilot/session.py b/python/copilot/session.py index c775ef58e..7bee0530a 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -1001,6 +1001,7 @@ def __init__( self.session_id = session_id self._client = client self._workspace_path = os.fsdecode(workspace_path) if workspace_path is not None else None + self._remote_url: str | None = None self._event_handlers: set[Callable[[SessionEvent], None]] = set() self._event_handlers_lock = threading.Lock() self._tool_handlers: dict[str, ToolHandler] = {} @@ -1069,6 +1070,16 @@ def workspace_path(self) -> pathlib.Path | None: # attribute to do the conversion, or just do the conversion lazily via a getter. return pathlib.Path(self._workspace_path) if self._workspace_path else None + @property + def remote_url(self) -> str | None: + """Remote URL for the Mission Control–backed cloud session. + + Set from the ``remoteUrl`` field in the ``session.create`` response for + cloud sessions created via :meth:`CopilotClient.create_cloud_session`. + ``None`` for regular local sessions. 
+ """ + return self._remote_url + async def send( self, prompt: str, diff --git a/python/test_client.py b/python/test_client.py index 14320b3a2..9869d9967 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -63,7 +63,7 @@ async def test_resume_session_allows_none_permission_handler(self): class TestCreateSessionConfig: @pytest.mark.asyncio - async def test_create_session_forwards_cloud_options(self): + async def test_create_cloud_session_forwards_cloud_options(self): client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) await client.start() try: @@ -72,12 +72,11 @@ async def test_create_session_forwards_cloud_options(self): async def mock_request(method, params): captured[method] = params if method == "session.create": - return {"sessionId": params["sessionId"], "workspacePath": None} + return {"sessionId": "cloud-session-id", "workspacePath": None} return {} client._client.request = mock_request - await client.create_session( - on_permission_request=PermissionHandler.approve_all, + await client.create_cloud_session( cloud=CloudSessionOptions( repository=CloudSessionRepository( owner="github", @@ -87,6 +86,9 @@ async def mock_request(method, params): ), ) + assert "sessionId" not in captured["session.create"], ( + "cloud sessions must not send a sessionId in the wire payload" + ) assert captured["session.create"]["cloud"] == { "repository": { "owner": "github", diff --git a/python/test_cloud_session.py b/python/test_cloud_session.py new file mode 100644 index 000000000..fae22d65f --- /dev/null +++ b/python/test_cloud_session.py @@ -0,0 +1,410 @@ +""" +Tests for CopilotClient.create_cloud_session. + +Ports the spirit of the Rust integration tests in rust/tests/session_test.rs. 
+""" + +from __future__ import annotations + +import asyncio +from datetime import datetime +from uuid import uuid4 + +import pytest + +from copilot import CopilotClient, RuntimeConnection +from copilot.client import ( + _PENDING_SESSION_BUFFER_LIMIT, + CloudSessionOptions, + CloudSessionRepository, +) +from copilot.session import ProviderConfig, UserInputRequest, UserInputResponse +from e2e.testharness import CLI_PATH + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _cloud_config() -> dict: + return dict( + cloud=CloudSessionOptions( + repository=CloudSessionRepository(owner="github", name="copilot-sdk", branch="main") + ) + ) + + +def _make_event_dict(event_type: str = "session.buffered_test", data: dict | None = None) -> dict: + """Build a minimal valid session-event dict for injection in tests.""" + return { + "id": str(uuid4()), + "timestamp": datetime.now().isoformat(), + "parentId": None, + "type": event_type, + "data": data or {}, + } + + +# --------------------------------------------------------------------------- +# Test 1: create_session rejects cloud config +# --------------------------------------------------------------------------- + + +class TestCreateSessionRejectsCloud: + @pytest.mark.asyncio + async def test_create_session_rejects_cloud_config(self): + """create_session must raise ValueError mentioning create_cloud_session.""" + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + with pytest.raises(ValueError, match="create_cloud_session"): + await client.create_session(**_cloud_config()) + finally: + await client.force_stop() + + +# --------------------------------------------------------------------------- +# Test 2: wire shape — sessionId omitted, cloud set, returned id used +# --------------------------------------------------------------------------- + + 
+class TestCreateCloudSessionWireShape: + @pytest.mark.asyncio + async def test_sends_cloud_without_session_id(self): + """session.create must carry cloud but omit sessionId; the response id is used.""" + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: dict = {} + + async def mock_request(method, params): + captured[method] = params + if method == "session.create": + return { + "sessionId": "remote-cloud-session", + "remoteUrl": "https://copilot.example.test/agents/remote-cloud-session", + "capabilities": {"ui": {"elicitation": True}}, + } + return {} + + client._client.request = mock_request + session = await client.create_cloud_session(**_cloud_config()) + + wire = captured["session.create"] + assert "sessionId" not in wire, "sessionId must be omitted from cloud create" + assert wire["cloud"]["repository"]["owner"] == "github" + assert wire["cloud"]["repository"]["name"] == "copilot-sdk" + assert wire["cloud"]["repository"]["branch"] == "main" + assert "provider" not in wire + + assert session.session_id == "remote-cloud-session" + assert session.remote_url == "https://copilot.example.test/agents/remote-cloud-session" + assert session.capabilities.get("ui", {}).get("elicitation") is True + finally: + await client.force_stop() + + +# --------------------------------------------------------------------------- +# Test 3: rejects caller-provided session_id +# --------------------------------------------------------------------------- + + +class TestCreateCloudSessionRejectsSessionId: + @pytest.mark.asyncio + async def test_rejects_caller_session_id(self): + """Passing session_id must raise ValueError naming session_id.""" + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + with pytest.raises(ValueError, match="session_id"): + await client.create_cloud_session(**_cloud_config(), session_id="caller-id") + + +# 
# ---------------------------------------------------------------------------
# Test 4: rejects caller-provided provider
# ---------------------------------------------------------------------------


class TestCreateCloudSessionRejectsProvider:
    @pytest.mark.asyncio
    async def test_rejects_caller_provider(self):
        """Passing provider must raise ValueError naming provider."""
        client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH))
        with pytest.raises(ValueError, match="provider"):
            await client.create_cloud_session(
                **_cloud_config(),
                provider=ProviderConfig(type="openai", base_url="https://api.example.test/v1"),
            )


# ---------------------------------------------------------------------------
# Test 5: requires cloud
# ---------------------------------------------------------------------------


class TestCreateCloudSessionRequiresCloud:
    @pytest.mark.asyncio
    async def test_requires_cloud(self):
        """Omitting cloud (or passing None) must raise ValueError mentioning cloud."""
        client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH))
        with pytest.raises(ValueError, match="cloud"):
            await client.create_cloud_session()


# ---------------------------------------------------------------------------
# Test 6: buffers early session.event notifications
# ---------------------------------------------------------------------------


class TestCreateCloudSessionBuffersEarlyNotifications:
    @pytest.mark.asyncio
    async def test_early_notifications_dispatched_after_registration(self):
        """session.event notifications arriving before registration are buffered and replayed."""
        client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH))
        await client.start()
        try:
            # Gate that holds the session.create response open until the test
            # releases it. get_running_loop() is used because we are inside a
            # coroutine and want the loop that is actually driving this test.
            create_response_gate: asyncio.Future[dict] = (
                asyncio.get_running_loop().create_future()
            )

            async def mock_request(method, params):
                if method == "session.create":
                    return await create_response_gate
                return {}

            client._client.request = mock_request

            session_id = "remote-cloud-session"
            received_events: list = []

            create_task = asyncio.ensure_future(
                client.create_cloud_session(
                    **_cloud_config(),
                    on_event=lambda e: received_events.append(e),
                )
            )

            # Yield control so create_cloud_session enters pending-routing mode.
            await asyncio.sleep(0)
            await asyncio.sleep(0)

            # Inject a session.event notification while the create is in flight.
            notification_handler = client._client.notification_handler
            assert notification_handler is not None, "notification handler not registered"
            notification_handler(
                "session.event",
                {
                    "sessionId": session_id,
                    "event": _make_event_dict(),
                },
            )

            # Verify it is buffered (not yet dispatched — session not registered yet).
            await asyncio.sleep(0)
            assert not received_events, "event dispatched before session was registered"

            # Allow session.create to respond; this registers the session.
            create_response_gate.set_result({"sessionId": session_id})
            await asyncio.wait_for(create_task, timeout=5.0)

            # Give the event loop a tick to flush the buffered event.
            await asyncio.sleep(0)

            assert len(received_events) == 1, (
                f"expected 1 buffered event to be replayed, got {len(received_events)}"
            )
            # Our synthetic event uses an unknown type; just confirm it was dispatched.
            assert received_events[0].raw_type == "session.buffered_test"
        finally:
            await client.force_stop()


# ---------------------------------------------------------------------------
# Test 7: parks inbound requests until registration
# ---------------------------------------------------------------------------


class TestCreateCloudSessionParksInboundRequests:
    @pytest.mark.asyncio
    async def test_parked_user_input_resolves_after_registration(self):
        """userInput.request that arrives before registration is parked, then resolved."""
        answered: list[str] = []

        async def color_picker(request: UserInputRequest, context: dict) -> UserInputResponse:
            # Record the question so the test can prove this handler was invoked.
            answered.append(request["question"])
            return UserInputResponse(answer="blue", wasFreeform=True)

        client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH))
        await client.start()
        try:
            # Gate that holds the session.create response open until released.
            create_response_gate: asyncio.Future[dict] = (
                asyncio.get_running_loop().create_future()
            )

            async def mock_request(method, params):
                if method == "session.create":
                    return await create_response_gate
                return {}

            client._client.request = mock_request

            session_id = "remote-cloud-session"
            create_task = asyncio.ensure_future(
                client.create_cloud_session(**_cloud_config(), on_user_input_request=color_picker)
            )

            # Yield so pending-routing mode is entered.
            await asyncio.sleep(0)
            await asyncio.sleep(0)

            # Dispatch a userInput.request while the create is in flight.
            user_input_handler = client._client.request_handlers.get("userInput.request")
            assert user_input_handler is not None, "userInput.request handler not registered"

            input_task = asyncio.ensure_future(
                user_input_handler(
                    {
                        "sessionId": session_id,
                        "question": "Pick a color",
                        "choices": ["red", "blue"],
                        "allowFreeform": True,
                    }
                )
            )

            # Yield to let the handler park on the pending future.
            await asyncio.sleep(0)
            assert not input_task.done(), "handler should be parked waiting for session"

            # Now let the create response arrive; this registers the session.
            create_response_gate.set_result({"sessionId": session_id})
            await asyncio.wait_for(create_task, timeout=5.0)

            # The parked userInput handler should now complete.
            result = await asyncio.wait_for(input_task, timeout=5.0)
            assert result["answer"] == "blue"
            assert result["wasFreeform"] is True
            assert answered == ["Pick a color"]
        finally:
            await client.force_stop()


# ---------------------------------------------------------------------------
# Test 8: pending request buffer overflow emits an error (not silent drop)
# ---------------------------------------------------------------------------


class TestPendingRequestBufferOverflow:
    @pytest.mark.asyncio
    async def test_oldest_waiter_rejected_on_overflow(self):
        """When the parked-request buffer is full, the oldest waiter is rejected.

        The rejection causes the JSON-RPC dispatch layer to send a JSON-RPC error
        response (code -32603) rather than silently hanging the runtime on that
        request id. The remaining _PENDING_SESSION_BUFFER_LIMIT waiters resolve
        normally once the session is registered.
        """
        session_id = "overflow-session"
        client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH))
        await client.start()
        try:
            # Enter pending-routing mode manually so _resolve_session parks futures.
            guard = client._begin_pending_session_routing()

            # One waiter more than the buffer can hold, to force exactly one eviction.
            total = _PENDING_SESSION_BUFFER_LIMIT + 1
            tasks = [
                asyncio.ensure_future(client._resolve_session(session_id)) for _ in range(total)
            ]

            # Yield so all _resolve_session calls park on futures.
            await asyncio.sleep(0)
            await asyncio.sleep(0)

            # The oldest (tasks[0]) should now be rejected with the overflow message.
            assert tasks[0].done(), "oldest waiter should have been rejected synchronously"
            with pytest.raises(ValueError, match="pending session buffer overflow"):
                tasks[0].result()

            # The remaining _PENDING_SESSION_BUFFER_LIMIT waiters are still parked.
            assert all(not t.done() for t in tasks[1:]), "remaining waiters should still be parked"

            # Register the session so the remaining waiters resolve.
            from copilot.session import CopilotSession

            session = CopilotSession(session_id, client._client, workspace_path=None)
            with client._sessions_lock:
                client._sessions[session_id] = session
            client._flush_pending_for_session(session_id, session)
            guard.dispose()

            # Let the event loop settle.
            await asyncio.sleep(0)

            # Bound the wait like every other test in this file: if the flush
            # regresses and a waiter never resolves, fail fast with a timeout
            # instead of hanging the whole test run.
            resolved_sessions = await asyncio.wait_for(
                asyncio.gather(*tasks[1:], return_exceptions=True),
                timeout=5.0,
            )
            assert all(s is session for s in resolved_sessions), (
                "all remaining parked waiters should resolve to the registered session"
            )
        finally:
            await client.force_stop()


# ---------------------------------------------------------------------------
# Test 9: guard drop without registration rejects parked requests
# ---------------------------------------------------------------------------


class TestPendingRequestGuardDropWithoutRegistration:
    @pytest.mark.asyncio
    async def test_parked_request_rejected_when_create_fails(self):
        """When session.create fails, parked request waiters get a distinct error.

        The error message "pending session routing ended before session was registered"
        must differ from the overflow message so the two failure modes are
        distinguishable in logs and the runtime gets a proper JSON-RPC error
        response rather than hanging.
        """
        client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH))
        await client.start()
        try:
            # Gate that holds the session.create response open until the test
            # makes it fail.
            create_response_gate: asyncio.Future[dict] = (
                asyncio.get_running_loop().create_future()
            )

            async def mock_request(method, params):
                if method == "session.create":
                    return await create_response_gate
                return {}

            client._client.request = mock_request

            session_id = "failing-cloud-session"
            create_task = asyncio.ensure_future(client.create_cloud_session(**_cloud_config()))

            # Yield so create_cloud_session enters pending-routing mode.
            await asyncio.sleep(0)
            await asyncio.sleep(0)

            # Park an inbound request while the create is in flight.
            user_input_handler = client._client.request_handlers.get("userInput.request")
            assert user_input_handler is not None, "userInput.request handler not registered"

            input_task = asyncio.ensure_future(
                user_input_handler(
                    {
                        "sessionId": session_id,
                        "question": "Pick a color",
                        "choices": ["red", "blue"],
                        "allowFreeform": True,
                    }
                )
            )

            await asyncio.sleep(0)
            assert not input_task.done(), "handler should be parked waiting for session"

            # Make session.create fail; this causes create_cloud_session to call
            # guard.dispose() without registering any session id.
            create_response_gate.set_exception(RuntimeError("simulated session.create failure"))
            with pytest.raises(RuntimeError, match="simulated session.create failure"):
                await asyncio.wait_for(create_task, timeout=5.0)

            # The parked waiter should now be rejected with the routing-ended message.
            await asyncio.sleep(0)
            assert input_task.done(), "parked waiter should be rejected after guard drop"
            expected_msg = "pending session routing ended before session was registered"
            with pytest.raises(ValueError, match=expected_msg):
                await input_task
        finally:
            await client.force_stop()