diff --git a/dotnet/README.md b/dotnet/README.md index a9527f447..b7bea362a 100644 --- a/dotnet/README.md +++ b/dotnet/README.md @@ -133,6 +133,47 @@ Resume an existing session. Returns the session with `WorkspacePath` populated i - `OnPermissionRequest` - Optional handler called before each tool execution to approve or deny it. See [Permission Handling](#permission-handling) section. +##### `CreateCloudSessionAsync(CloudSessionConfig config): Task` + +Create a session that is hosted by and routed to a remote Copilot cloud server rather than the local CLI. The returned `CopilotSession` behaves identically to a locally-created session: send messages, receive events, and use tools through the same API. + +A `Cloud` configuration is required; passing `Cloud = null` throws `ArgumentException`. Conversely, passing a `Cloud` config to the regular `CreateSessionAsync` also throws, preventing accidental misconfiguration. + +The session ID is assigned by the cloud server. After `CreateCloudSessionAsync` returns, `session.SessionId` contains the server-assigned ID and `session.RemoteUrl` contains the URL of the cloud server that owns the session. Any RPC requests (user-input, permission, hooks, etc.) that arrive from the server before the session is fully registered are buffered and replayed automatically — your handlers will never miss an early request. + +**CloudSessionConfig:** + +- `Cloud` *(required)* — Cloud routing options: + - `Repository` *(required)* — The repository to route the session to: + - `Nwo` *(required)* — Full `owner/repo` name (e.g. `"github/my-repo"`) + - `Ref` *(optional)* — Branch or tag ref (e.g. `"refs/heads/main"`) + - `SdkOptions` *(optional)* — Pass-through options forwarded to the cloud provider +- `Model`, `ReasoningEffort`, `Tools`, `SystemMessage`, `AvailableTools`, `ExcludedTools`, `Provider`, `Streaming`, `InfiniteSessions`, `OnPermissionRequest`, `OnUserInputRequest`, `Hooks` — Same as `SessionConfig` + +**Returned session properties:** + +- `SessionId` — Server-assigned session ID +- `RemoteUrl` — URL of the cloud server that owns this session + +```csharp +await using var session = await client.CreateCloudSessionAsync(new CloudSessionConfig +{ + Cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository { Nwo = "github/my-repo" } + }, + OnUserInputRequest = async (request, _) => + { + Console.Write($"{request.Question} "); + return new UserInputResponse { Answer = Console.ReadLine()! }; + } +}); + +Console.WriteLine($"Cloud session: {session.SessionId} at {session.RemoteUrl}"); +var reply = await session.SendAsync(new MessageOptions { Content = "Hello from the cloud!" }); +Console.WriteLine(reply); +``` + ##### `PingAsync(string? message = null): Task` Ping the server to check connectivity. @@ -193,6 +234,7 @@ Represents a single conversation session. - `SessionId` - The unique identifier for this session - `WorkspacePath` - Path to the session workspace directory when infinite sessions are enabled. Contains `checkpoints/`, `plan.md`, and `files/` subdirectories. Null if infinite sessions are disabled. +- `RemoteUrl` - URL of the cloud server that owns this session. Non-null only for sessions created via `CreateCloudSessionAsync`. #### Methods diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index a5cc62354..6d90a668e 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -83,6 +83,16 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable private readonly object _lifecycleHandlersLock = new(); private ServerRpc? _serverRpc; + // Pending-routing state for cloud session.create in-flight windows. + // While _pendingRoutingCount > 0, notifications and inbound requests + // addressed to not-yet-registered session ids are buffered and replayed + // once the runtime-assigned id is registered. + private int _pendingRoutingCount; + private readonly Dictionary> _pendingSessionEvents = new(); + private readonly Dictionary>> _pendingSessionWaiters = new(); + private readonly object _pendingLock = new(); + private const int PendingSessionBufferLimit = 128; + private sealed record LifecycleSubscription(Type EventType, Action Handler); /// @@ -520,6 +530,12 @@ public async Task CreateSessionAsync(SessionConfig config, Cance { ArgumentNullException.ThrowIfNull(config); + if (config.Cloud != null) + { + throw new InvalidOperationException( + "CopilotClient.CreateSessionAsync does not support cloud sessions; use CreateCloudSessionAsync instead."); + } + var connection = await EnsureConnectedAsync(cancellationToken); var totalTimestamp = Stopwatch.GetTimestamp(); @@ -651,6 +667,359 @@ public async Task CreateSessionAsync(SessionConfig config, Cance return session; } + /// + /// Creates a Mission Control–backed cloud session. + /// + /// + /// + /// The runtime owns the session ID for cloud sessions: do not set + /// or on the + /// config (the SDK rejects both). The SDK omits sessionId from the + /// session.create wire payload and registers the resulting session under the id + /// the runtime returns. + /// + /// + /// Any session.event notifications or inbound JSON-RPC requests that arrive between + /// sending session.create and receiving its response are buffered (bounded, + /// drop-oldest, limit 128 per id) and replayed once the returned id is registered, + /// so early events are not lost. + /// + /// + /// Known limitation: inbound sessionFs.* requests are not pending-buffered. + /// In practice the runtime does not initiate sessionFs.* before the + /// session.create response, so this is theoretical. + /// + /// + /// Configuration for the cloud session. is required. + /// A that can be used to cancel the operation. + /// A task that resolves to the created . + /// Thrown when is null. + /// Thrown when is null, or when + /// or is set on the config. + /// + /// + /// var session = await client.CreateCloudSessionAsync(new SessionConfig + /// { + /// OnPermissionRequest = PermissionHandler.ApproveAll, + /// Cloud = new CloudSessionOptions + /// { + /// Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk", Branch = "main" } + /// } + /// }); + /// Console.WriteLine($"Cloud session id: {session.SessionId}"); + /// + /// + public async Task CreateCloudSessionAsync(SessionConfig config, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(config); + + if (config.Cloud == null) + { + throw new InvalidOperationException( + "CopilotClient.CreateCloudSessionAsync requires config.Cloud to be set."); + } + if (!string.IsNullOrEmpty(config.SessionId)) + { + throw new InvalidOperationException( + "CopilotClient.CreateCloudSessionAsync does not support a caller-provided SessionId; the runtime assigns one."); + } + if (config.Provider != null) + { + throw new InvalidOperationException( + "CopilotClient.CreateCloudSessionAsync does not support config.Provider; cloud sessions use the runtime's provider."); + } + + var connection = await EnsureConnectedAsync(cancellationToken); + var totalTimestamp = Stopwatch.GetTimestamp(); + + var hasHooks = config.Hooks != null && ( + config.Hooks.OnPreToolUse != null || + config.Hooks.OnPreMcpToolCall != null || + config.Hooks.OnPostToolUse != null || + config.Hooks.OnUserPromptSubmitted != null || + config.Hooks.OnSessionStart != null || + config.Hooks.OnSessionEnd != null || + config.Hooks.OnErrorOccurred != null); + + var (wireSystemMessage, transformCallbacks) = ExtractTransformCallbacks(config.SystemMessage); + + // Begin pending-routing mode so notifications/requests that arrive + // before the runtime returns the session id are buffered. + var guard = BeginPendingSessionRouting(); + + CreateSessionResponse response; + try + { + var (traceparent, tracestate) = TelemetryHelpers.GetTraceContext(); + + // sessionId is intentionally omitted (null) on the cloud path: + // the runtime assigns the Mission Control session id. + var request = new CreateSessionRequest( + config.Model, + null /* sessionId omitted */, + config.ClientName, + config.ReasoningEffort, + config.Tools?.Select(ToolDefinition.FromAIFunction).ToList(), + wireSystemMessage, + config.AvailableTools, + config.ExcludedTools, + null /* Provider not allowed on cloud path */, + config.EnableSessionTelemetry, + config.OnPermissionRequest != null ? true : null, + config.OnUserInputRequest != null ? true : null, + config.OnExitPlanModeRequest != null ? true : null, + config.OnAutoModeSwitchRequest != null ? true : null, + hasHooks ? true : null, + config.WorkingDirectory, + config.Streaming is true ? true : null, + config.IncludeSubAgentStreamingEvents, + config.McpServers, + "direct", + config.CustomAgents, + config.DefaultAgent, + config.Agent, + config.ConfigDir, + config.EnableConfigDiscovery, + config.SkillDirectories, + config.DisabledSkills, + config.InfiniteSessions, + Commands: config.Commands?.Select(c => new CommandWireDefinition(c.Name, c.Description)).ToList(), + RequestElicitation: config.OnElicitationRequest != null, + Traceparent: traceparent, + Tracestate: tracestate, + ModelCapabilities: config.ModelCapabilities, + GitHubToken: config.GitHubToken, + RemoteSession: config.RemoteSession, + Cloud: config.Cloud, + InstructionDirectories: config.InstructionDirectories); + + response = await InvokeRpcAsync( + connection.Rpc, "session.create", [request], cancellationToken); + } + catch (Exception ex) + { + guard.Dispose(); + if (ex is not OperationCanceledException) + { + LoggingHelpers.LogTiming(_logger, LogLevel.Warning, ex, + "CopilotClient.CreateCloudSessionAsync failed during session.create RPC. Elapsed={Elapsed}", + totalTimestamp); + } + throw; + } + + if (string.IsNullOrEmpty(response.SessionId)) + { + // No id to issue session.destroy against; release the guard and surface the error. + // Any runtime session created on the other side may leak. + _logger.LogWarning("Cloud session.create response missing sessionId; runtime session may leak."); + guard.Dispose(); + throw new InvalidOperationException( + "Cloud session.create response did not include a sessionId; cannot register session."); + } + + var sessionId = response.SessionId; + var setupTimestamp = Stopwatch.GetTimestamp(); + var session = new CopilotSession(sessionId, connection.Rpc, _logger, this); + session.RegisterTools(config.Tools ?? []); + session.RegisterPermissionHandler(config.OnPermissionRequest); + session.RegisterCommands(config.Commands); + session.RegisterElicitationHandler(config.OnElicitationRequest); + session.RegisterExitPlanModeHandler(config.OnExitPlanModeRequest); + session.RegisterAutoModeSwitchHandler(config.OnAutoModeSwitchRequest); + if (config.OnUserInputRequest != null) + { + session.RegisterUserInputHandler(config.OnUserInputRequest); + } + if (config.Hooks != null) + { + session.RegisterHooks(config.Hooks); + } + if (transformCallbacks != null) + { + session.RegisterTransformCallbacks(transformCallbacks); + } + if (config.OnEvent != null) + { + session.On(config.OnEvent); + } + LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, + "CopilotClient.CreateCloudSessionAsync local setup complete. Elapsed={Elapsed}, SessionId={SessionId}", + setupTimestamp, + sessionId); + + try + { + RegisterSession(session); + ConfigureSessionFsHandlers(session, config.CreateSessionFsProvider); + session.StartProcessingEvents(); + session.WorkspacePath = response.WorkspacePath; + session.SetCapabilities(response.Capabilities); + session.RemoteUrl = response.RemoteUrl; + + // Flush buffered notifications and unblock parked request waiters + // now that the session is registered. Must happen before the guard + // is released so nothing races into a still-pending buffer. + FlushPendingForSession(sessionId, session); + } + catch (Exception ex) + { + session.RemoveFromClient(); + guard.Dispose(); + // Destroy the cloud session on the server — we already have the sessionId from the + // session.create response, so we can send session.destroy even though setup failed. + try { await session.DisposeAsync().ConfigureAwait(false); } catch { /* best effort */ } + LoggingHelpers.LogTiming(_logger, LogLevel.Warning, ex, + "CopilotClient.CreateCloudSessionAsync failed during post-response setup. Elapsed={Elapsed}, SessionId={SessionId}", + totalTimestamp, + sessionId); + throw; + } + + guard.Dispose(); + + LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, + "CopilotClient.CreateCloudSessionAsync complete. Elapsed={Elapsed}, SessionId={SessionId}", + totalTimestamp, + sessionId); + return session; + } + + /// + /// Enter pending-routing mode. While the returned guard is undisposed, + /// notifications and inbound requests addressed to session ids that are + /// not yet registered are buffered (up to + /// per id, drop-oldest) and replayed on registration. When the last guard is + /// disposed without registration, buffered events are dropped and parked request + /// waiters are faulted with an . + /// + private PendingSessionRoutingGuard BeginPendingSessionRouting() + { + lock (_pendingLock) + { + _pendingRoutingCount++; + } + return new PendingSessionRoutingGuard(this); + } + + private void EndPendingSessionRouting() + { + List>? waiters = null; + lock (_pendingLock) + { + _pendingRoutingCount--; + if (_pendingRoutingCount == 0) + { + _pendingSessionEvents.Clear(); + waiters = new List>(); + foreach (var list in _pendingSessionWaiters.Values) + waiters.AddRange(list); + _pendingSessionWaiters.Clear(); + } + } + + // Fault pending waiters outside the lock so TCS continuations don't run under it. + // Distinct phrasing from the overflow-eviction path so the runtime / debugging can tell + // the two cases apart. Matches the Rust SDK message in PR #1394 (commit e0ff254f). + if (waiters != null) + { + foreach (var tcs in waiters) + { + tcs.TrySetException(new InvalidOperationException( + "pending session routing ended before session was registered")); + } + } + } + + private void FlushPendingForSession(string sessionId, CopilotSession session) + { + List events; + List> waiters; + + lock (_pendingLock) + { + _pendingSessionEvents.TryGetValue(sessionId, out var rawEvents); + _pendingSessionEvents.Remove(sessionId); + events = rawEvents ?? []; + + _pendingSessionWaiters.TryGetValue(sessionId, out var rawWaiters); + _pendingSessionWaiters.Remove(sessionId); + waiters = rawWaiters ?? []; + } + + foreach (var evt in events) + { + session.DispatchEvent(evt); + } + + // Resolve waiters outside the lock so TCS continuations don't run under it. + foreach (var tcs in waiters) + { + tcs.TrySetResult(session); + } + } + + private Task ResolveSessionAsync(string sessionId) + { + var session = GetSession(sessionId); + if (session != null) + { + return Task.FromResult(session); + } + + lock (_pendingLock) + { + // Re-check inside lock to avoid the race where the session was registered + // between the unlock-free GetSession call and acquiring _pendingLock. + session = GetSession(sessionId); + if (session != null) + { + return Task.FromResult(session); + } + + if (_pendingRoutingCount > 0) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + if (!_pendingSessionWaiters.TryGetValue(sessionId, out var list)) + { + list = []; + _pendingSessionWaiters[sessionId] = list; + } + + // Cap parked waiters per session id. When exceeded, reject the oldest with a + // distinct message so the runtime isn't left waiting on a hung request id. + // RunContinuationsAsynchronously ensures the continuation won't run under this lock. + // Matches the Rust SDK fix in PR #1394 (commit 491b4427). + if (list.Count >= PendingSessionBufferLimit) + { + var oldest = list[0]; + list.RemoveAt(0); + oldest.TrySetException(new InvalidOperationException("pending session buffer overflow")); + } + + list.Add(tcs); + return tcs.Task; + } + } + + return Task.FromException(new ArgumentException($"Unknown session {sessionId}")); + } + + private sealed class PendingSessionRoutingGuard : IDisposable + { + private readonly CopilotClient _client; + private bool _disposed; + + internal PendingSessionRoutingGuard(CopilotClient client) => _client = client; + + public void Dispose() + { + if (_disposed) return; + _disposed = true; + _client.EndPendingSessionRouting(); + } + } + /// /// Resumes an existing Copilot session with the specified configuration. /// @@ -1710,14 +2079,45 @@ private class RpcHandler(CopilotClient client) { public void OnSessionEvent(string sessionId, JsonElement? @event) { + if (@event == null) return; + var session = client.GetSession(sessionId); - if (session != null && @event != null) + if (session != null) { var evt = SessionEvent.FromJson(@event.Value.GetRawText()); if (evt != null) { session.DispatchEvent(evt); } + return; + } + + // Session not yet registered — buffer if pending routing is active. + lock (client._pendingLock) + { + if (client._pendingRoutingCount == 0) + { + return; + } + + var parsed = SessionEvent.FromJson(@event.Value.GetRawText()); + if (parsed == null) return; + + if (!client._pendingSessionEvents.TryGetValue(sessionId, out var buf)) + { + buf = []; + client._pendingSessionEvents[sessionId] = buf; + } + + if (buf.Count >= PendingSessionBufferLimit) + { + buf.RemoveAt(0); + client._logger.LogWarning( + "Pending session event buffer full for session {SessionId}; dropping oldest event.", + sessionId); + } + + buf.Add(parsed); } } @@ -1747,7 +2147,7 @@ public void OnSessionLifecycle(string type, string sessionId, JsonElement? metad public async ValueTask OnUserInputRequest(string sessionId, string question, IList? choices = null, bool? allowFreeform = null) { - var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + var session = await client.ResolveSessionAsync(sessionId); var request = new UserInputRequest { Question = question, @@ -1766,7 +2166,7 @@ public async ValueTask OnExitPlanModeRequest( IList? actions = null, string? recommendedAction = null) { - var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + var session = await client.ResolveSessionAsync(sessionId); var request = new ExitPlanModeRequest { Summary = summary, @@ -1783,7 +2183,7 @@ public async ValueTask OnAutoModeSwitchRequest( string? errorCode = null, double? retryAfterSeconds = null) { - var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + var session = await client.ResolveSessionAsync(sessionId); var response = await session.HandleAutoModeSwitchRequestAsync(new AutoModeSwitchRequest { ErrorCode = errorCode, @@ -1794,14 +2194,14 @@ public async ValueTask OnAutoModeSwitchRequest( public async ValueTask OnHooksInvoke(string sessionId, string hookType, JsonElement input) { - var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + var session = await client.ResolveSessionAsync(sessionId); var output = await session.HandleHooksInvokeAsync(hookType, input); return new HooksInvokeResponse(output); } public async ValueTask OnSystemMessageTransform(string sessionId, JsonElement sections) { - var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + var session = await client.ResolveSessionAsync(sessionId); return await session.HandleSystemMessageTransformAsync(sections); } } @@ -1888,7 +2288,8 @@ public static ToolDefinition FromAIFunction(AIFunctionDeclaration function) internal record CreateSessionResponse( string SessionId, string? WorkspacePath, - SessionCapabilities? Capabilities = null); + SessionCapabilities? Capabilities = null, + string? RemoteUrl = null); internal record ResumeSessionRequest( string SessionId, diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 6ad8e14d9..646cc256f 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -107,6 +107,15 @@ private sealed record EventSubscription(Type EventType, Action Han /// public string? WorkspacePath { get; internal set; } + /// + /// Gets the remote URL of this cloud session, if available. + /// + /// + /// The Mission Control remote URL returned by the runtime when creating a cloud session, + /// or null for local sessions. + /// + public string? RemoteUrl { get; internal set; } + /// /// Gets the capabilities reported by the host for this session. /// diff --git a/dotnet/test/Unit/CloudSessionTests.cs b/dotnet/test/Unit/CloudSessionTests.cs new file mode 100644 index 000000000..97dfe9f02 --- /dev/null +++ b/dotnet/test/Unit/CloudSessionTests.cs @@ -0,0 +1,766 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +#if NET8_0_OR_GREATER +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Text.Json; +using Xunit; + +namespace GitHub.Copilot.Test.Unit; + +/// +/// Unit tests for and the rejection guard +/// on for cloud configs. +/// +public sealed class CloudSessionTests +{ + // ------------------------------------------------------------------------- + // 1. CreateSessionAsync rejects cloud config + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateSessionAsync_Rejects_CloudConfig() + { + await using var server = await FakeCloudServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + var ex = await Assert.ThrowsAsync(() => + client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions { Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } } + })); + + Assert.Contains("CreateCloudSessionAsync", ex.Message); + } + + // ------------------------------------------------------------------------- + // 2. CreateCloudSessionAsync sends session.create with cloud and without sessionId + // (wire-shape correctness: assert sessionId absent from serialized JSON) + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateCloudSessionAsync_Sends_Create_With_Cloud_And_Without_SessionId() + { + await using var server = await FakeCloudServer.StartAsync(cloudSessionId: "remote-cloud-session"); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + await using var session = await client.CreateCloudSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk", Branch = "main" } + } + }); + + // Verify the session id is the runtime-assigned one. + Assert.Equal("remote-cloud-session", session.SessionId); + + // Verify the wire payload captured by the server had no sessionId and had cloud. + var payload = server.LastCreatePayload; + Assert.NotNull(payload); + Assert.False(payload!.Value.TryGetProperty("sessionId", out _), + "session.create payload must not contain 'sessionId' on the cloud path."); + Assert.True(payload.Value.TryGetProperty("cloud", out var cloud)); + Assert.Equal("github", cloud.GetProperty("repository").GetProperty("owner").GetString()); + Assert.Equal("copilot-sdk", cloud.GetProperty("repository").GetProperty("name").GetString()); + Assert.Equal("main", cloud.GetProperty("repository").GetProperty("branch").GetString()); + } + + // Supplementary serialization-layer assertion: JsonSerializer.Serialize on a + // CreateSessionRequest with SessionId=null must not emit the key. + [Fact] + public void CreateSessionRequest_WithNullSessionId_DoesNotEmitSessionIdKey() + { + // Retrieve the private serializer options the SDK uses (same approach as SerializationTests). + var prop = typeof(CopilotClient) + .GetProperty("SerializerOptionsForMessageFormatter", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static); + var options = (JsonSerializerOptions?)prop?.GetValue(null); + Assert.NotNull(options); + + var requestType = typeof(CopilotClient) + .GetNestedType("CreateSessionRequest", System.Reflection.BindingFlags.NonPublic); + Assert.NotNull(requestType); + + // Build a request with SessionId = null (the cloud path). + var instance = System.Runtime.CompilerServices.RuntimeHelpers.GetUninitializedObject(requestType!); + var cloudField = requestType!.GetField("k__BackingField", + System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic); + cloudField?.SetValue(instance, new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "o", Name = "r" } + }); + + var json = JsonSerializer.Serialize(instance, requestType, options!); + + Assert.DoesNotContain("\"sessionId\"", json); + Assert.Contains("\"cloud\"", json); + } + + // ------------------------------------------------------------------------- + // 3. Rejects caller-provided SessionId + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateCloudSessionAsync_Rejects_CallerSessionId() + { + await using var server = await FakeCloudServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + var ex = await Assert.ThrowsAsync(() => + client.CreateCloudSessionAsync(new SessionConfig + { + SessionId = "caller-id", + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions { Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } } + })); + + Assert.Contains("SessionId", ex.Message); + } + + // ------------------------------------------------------------------------- + // 4. Rejects caller-provided Provider + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateCloudSessionAsync_Rejects_CallerProvider() + { + await using var server = await FakeCloudServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + var ex = await Assert.ThrowsAsync(() => + client.CreateCloudSessionAsync(new SessionConfig + { + Provider = new ProviderConfig { BaseUrl = "https://api.example.com/v1" }, + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions { Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } } + })); + + Assert.Contains("Provider", ex.Message); + } + + // ------------------------------------------------------------------------- + // 5. Requires Cloud option + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateCloudSessionAsync_Requires_Cloud() + { + await using var server = await FakeCloudServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + var ex = await Assert.ThrowsAsync(() => + client.CreateCloudSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll + // Cloud deliberately absent + })); + + Assert.Contains("Cloud", ex.Message); + } + + // ------------------------------------------------------------------------- + // 6. Buffers early session.event notifications until session id is registered + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateCloudSessionAsync_Buffers_Early_Notifications_Until_Registration() + { + const string cloudId = "remote-cloud-session"; + + // Server is configured to send a session.event notification before + // responding to session.create. + await using var server = await FakeCloudServer.StartAsync( + cloudSessionId: cloudId, + earlyNotification: new Dictionary + { + ["method"] = "session.event", + ["params"] = new Dictionary + { + ["sessionId"] = cloudId, + ["event"] = new Dictionary + { + ["type"] = "capabilities.changed", + ["data"] = new Dictionary + { + ["ui"] = new Dictionary { ["elicitation"] = true } + } + } + } + }); + + var receivedEvents = new List(); + + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + await using var session = await client.CreateCloudSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } + }, + OnEvent = evt => + { + if (evt is CapabilitiesChangedEvent capEvt) + { + receivedEvents.Add(capEvt); + } + } + }); + + // Allow the event channel to drain. + var deadline = DateTime.UtcNow.AddSeconds(5); + while (receivedEvents.Count == 0 && DateTime.UtcNow < deadline) + { + await Task.Delay(20); + } + + Assert.Single(receivedEvents); + Assert.True(receivedEvents[0].Data?.Ui?.Elicitation == true); + } + + // ------------------------------------------------------------------------- + // 7. Parks inbound requests until session id is registered + // ------------------------------------------------------------------------- + + [Fact] + public async Task CreateCloudSessionAsync_Parks_Inbound_Requests_Until_Registration() + { + const string cloudId = "remote-cloud-session"; + + // Server sends a userInput.request before responding to session.create. + await using var server = await FakeCloudServer.StartAsync( + cloudSessionId: cloudId, + earlyInboundRequest: new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = 301, + ["method"] = "userInput.request", + ["params"] = new Dictionary + { + ["sessionId"] = cloudId, + ["question"] = "Pick a color", + ["choices"] = new object?[] { "red", "blue" }, + ["allowFreeform"] = true + } + }); + + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + await using var session = await client.CreateCloudSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } + }, + OnUserInputRequest = (_, _) => Task.FromResult(new UserInputResponse { Answer = "blue", WasFreeform = true }) + }); + + // Wait for the server to receive the userInput response. + var response = await server.WaitForUserInputResponse(TimeSpan.FromSeconds(5)); + Assert.NotNull(response); + Assert.Equal("blue", response!["answer"]?.ToString()); + } + + // ------------------------------------------------------------------------- + // 8. Pending-waiter overflow: oldest is rejected, remaining 128 succeed + // ------------------------------------------------------------------------- + + [Fact] + public async Task PendingRequestWaiterOverflow_RejectsOldestWithOverflowMessage() + { + const string cloudId = "overflow-session"; + const int requestCount = 129; // one beyond the 128-waiter cap + + await using var server = await FakeCloudServer.StartAsync( + cloudSessionId: cloudId, + earlyInboundRequestCount: requestCount); + + await using var client = new CopilotClient(new CopilotClientOptions + { Connection = RuntimeConnection.ForUri(server.Url) }); + + await using var _ = await client.CreateCloudSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } + }, + OnUserInputRequest = (_, _) => Task.FromResult(new UserInputResponse { Answer = "yes", WasFreeform = false }) + }); + + var responses = await server.WaitForAllInboundResponses(requestCount, TimeSpan.FromSeconds(10)); + + // Exactly one overflow eviction, 128 successful completions. + Assert.Equal(1, responses.Count(r => r.IsError)); + Assert.Equal(128, responses.Count(r => !r.IsError)); + + var err = responses.Single(r => r.IsError); + Assert.Contains("pending session buffer overflow", err.ErrorMessage ?? ""); + } + + // ------------------------------------------------------------------------- + // 9. Guard-drop path: parked requests are rejected with distinct message + // ------------------------------------------------------------------------- + + [Fact] + public async Task PendingSessionGuardDrop_RejectsParkedRequestWithDistinctMessage() + { + const string cloudId = "guard-drop-session"; + const int inboundRequestId = 500; + + await using var server = await FakeCloudServer.StartAsync( + cloudSessionId: cloudId, + failSessionCreate: true, + earlyInboundRequest: new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = inboundRequestId, + ["method"] = "userInput.request", + ["params"] = new Dictionary + { + ["sessionId"] = cloudId, + ["question"] = "Color?", + ["choices"] = new object?[] { "red", "blue" }, + ["allowFreeform"] = false + } + }); + + await using var client = new CopilotClient(new CopilotClientOptions + { Connection = RuntimeConnection.ForUri(server.Url) }); + + await Assert.ThrowsAnyAsync(() => + client.CreateCloudSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository { Owner = "github", Name = "copilot-sdk" } + } + })); + + // The parked request must have been rejected with the guard-drop message (not the overflow message). + var responses = await server.WaitForAllInboundResponses(1, TimeSpan.FromSeconds(5)); + + Assert.Single(responses); + Assert.True(responses[0].IsError); + Assert.Contains( + "pending session routing ended before session was registered", + responses[0].ErrorMessage ?? ""); + } + + // ========================================================================= + // Fake server infrastructure + // ========================================================================= + + private sealed class FakeCloudServer : IAsyncDisposable + { + private readonly TcpListener _listener; + private readonly CancellationTokenSource _cts = new(); + private readonly SemaphoreSlim _writeLock = new(1, 1); + private readonly Task _serverTask; + private readonly string _cloudSessionId; + private readonly Dictionary? _earlyNotification; + private readonly Dictionary? _earlyInboundRequest; + private readonly int _earlyInboundRequestCount; + private readonly bool _failSessionCreate; + private readonly TaskCompletionSource?> _userInputResponseTcs = + new(TaskCreationOptions.RunContinuationsAsynchronously); + + // Response tracking for overflow / guard-drop tests. + private readonly object _inboundResponsesLock = new(); + private readonly List _collectedInboundResponses = []; + private int _waitForInboundResponseCount; + private TaskCompletionSource>? _allInboundResponsesTcs; + + public JsonElement? LastCreatePayload { get; private set; } + + public record InboundResponse(int RequestId, bool IsError, string? ErrorMessage); + + private FakeCloudServer( + TcpListener listener, + string cloudSessionId, + Dictionary? earlyNotification, + Dictionary? earlyInboundRequest, + int earlyInboundRequestCount, + bool failSessionCreate) + { + _listener = listener; + _cloudSessionId = cloudSessionId; + _earlyNotification = earlyNotification; + _earlyInboundRequest = earlyInboundRequest; + _earlyInboundRequestCount = earlyInboundRequestCount; + _failSessionCreate = failSessionCreate; + _serverTask = RunAsync(); + } + + public string Url + { + get + { + var endpoint = (IPEndPoint)_listener.LocalEndpoint; + return $"http://127.0.0.1:{endpoint.Port}"; + } + } + + public static Task StartAsync( + string cloudSessionId = "cloud-session-id", + Dictionary? earlyNotification = null, + Dictionary? earlyInboundRequest = null, + int earlyInboundRequestCount = 0, + bool failSessionCreate = false) + { + var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + return Task.FromResult(new FakeCloudServer( + listener, cloudSessionId, earlyNotification, earlyInboundRequest, + earlyInboundRequestCount, failSessionCreate)); + } + + public Task?> WaitForUserInputResponse(TimeSpan timeout) + => _userInputResponseTcs.Task.WaitAsync(timeout); + + /// + /// Waits until the server has collected responses (error or success) + /// from the client for inbound requests. Used by overflow and guard-drop tests. + /// + public Task> WaitForAllInboundResponses(int count, TimeSpan timeout) + { + var tcs = new TaskCompletionSource>( + TaskCreationOptions.RunContinuationsAsynchronously); + lock (_inboundResponsesLock) + { + _waitForInboundResponseCount = count; + _allInboundResponsesTcs = tcs; + if (_collectedInboundResponses.Count >= count) + tcs.TrySetResult(new List(_collectedInboundResponses)); + } + return tcs.Task.WaitAsync(timeout); + } + + private void RecordInboundResponse(InboundResponse response) + { + TaskCompletionSource>? tcs = null; + IReadOnlyList? snapshot = null; + lock (_inboundResponsesLock) + { + _collectedInboundResponses.Add(response); + if (_allInboundResponsesTcs != null && + _collectedInboundResponses.Count >= _waitForInboundResponseCount) + { + tcs = _allInboundResponsesTcs; + snapshot = new List(_collectedInboundResponses); + } + } + tcs?.TrySetResult(snapshot!); + } + + public async ValueTask DisposeAsync() + { + _cts.Cancel(); + _listener.Stop(); + + try { await _serverTask; } + catch (Exception ex) when (ex is OperationCanceledException or ObjectDisposedException or IOException or SocketException) { } + + _cts.Dispose(); + _writeLock.Dispose(); + } + + private async Task RunAsync() + { + using var tcpClient = await _listener.AcceptTcpClientAsync(_cts.Token); + using var stream = tcpClient.GetStream(); + + while (!_cts.Token.IsCancellationRequested) + { + using var request = await ReadMessageAsync(stream, _cts.Token); + if (request is null) return; + await HandleRequestAsync(stream, request.RootElement, _cts.Token); + } + } + + private async Task HandleRequestAsync(Stream stream, JsonElement request, CancellationToken cancellationToken) + { + // Identify the message type: + // - Response from SDK: has "id" and "result"/"error" but no "method" + // - Request from SDK: has "id" and "method" + // - Notification from SDK: has "method" but no "id" + + bool hasId = request.TryGetProperty("id", out var idElement); + bool hasMethod = request.TryGetProperty("method", out var methodEl); + bool hasResult = request.TryGetProperty("result", out var resultEl); + + if (hasId && !hasMethod) + { + // This is a response from the SDK (e.g. userInput reply). + if (hasResult) + { + var dict = new Dictionary(); + foreach (var prop in resultEl.EnumerateObject()) + { + dict[prop.Name] = prop.Value.ValueKind switch + { + JsonValueKind.String => prop.Value.GetString(), + JsonValueKind.True => true, + JsonValueKind.False => false, + _ => prop.Value.GetRawText() + }; + } + _userInputResponseTcs.TrySetResult(dict); + + if (idElement.ValueKind == JsonValueKind.Number && idElement.TryGetInt32(out var successId)) + RecordInboundResponse(new InboundResponse(successId, IsError: false, null)); + } + else if (request.TryGetProperty("error", out var errorEl)) + { + var requestId = idElement.ValueKind == JsonValueKind.Number && idElement.TryGetInt32(out var errId) + ? errId : -1; + var msg = errorEl.TryGetProperty("message", out var msgEl) ? msgEl.GetString() : null; + RecordInboundResponse(new InboundResponse(requestId, IsError: true, msg)); + } + return; + } + + if (!hasId) + { + // Notification — nothing to respond to. + return; + } + + var id = idElement.Clone(); + var method = methodEl.GetString(); + + if (method == "connect") + { + await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = id, + ["result"] = new Dictionary + { + ["ok"] = true, + ["protocolVersion"] = 3, + ["version"] = "test" + } + }, cancellationToken); + return; + } + + if (method == "session.create") + { + // Capture the params for assertions. + if (request.TryGetProperty("params", out var paramsEl)) + { + LastCreatePayload = paramsEl.Clone(); + } + + // Optionally send an early notification before responding. + if (_earlyNotification != null) + { + await WriteMessageAsync(stream, _earlyNotification, cancellationToken); + } + + // Optionally send an early inbound request before responding. + if (_earlyInboundRequest != null) + { + await WriteMessageAsync(stream, _earlyInboundRequest, cancellationToken); + // Give the SDK a moment to park the request before we unblock create. + await Task.Delay(50, cancellationToken); + } + + // For overflow tests: send N inbound requests to exercise the buffer cap. + if (_earlyInboundRequestCount > 0) + { + for (var i = 1; i <= _earlyInboundRequestCount; i++) + { + await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = i, + ["method"] = "userInput.request", + ["params"] = new Dictionary + { + ["sessionId"] = _cloudSessionId, + ["question"] = $"Question {i}", + ["choices"] = new object?[] { "yes", "no" }, + ["allowFreeform"] = false + } + }, cancellationToken); + } + + // Give the client time to park/overflow all requests before responding. + await Task.Delay(100, cancellationToken); + } + + if (_failSessionCreate) + { + await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = id, + ["error"] = new Dictionary + { + ["code"] = -32603, + ["message"] = "session.create failed (test-induced failure)" + } + }, cancellationToken); + return; + } + + await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = id, + ["result"] = new Dictionary + { + ["sessionId"] = _cloudSessionId, + ["workspacePath"] = null, + ["capabilities"] = null + } + }, cancellationToken); + return; + } + + if (method == "session.destroy") + { + await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = id, + ["result"] = new Dictionary() + }, cancellationToken); + return; + } + + // Default: return an empty success result. + await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = id, + ["result"] = new Dictionary() + }, cancellationToken); + } + + private async Task WriteMessageAsync(Stream stream, object payload, CancellationToken cancellationToken) + { + using var bodyStream = new MemoryStream(); + using (var writer = new Utf8JsonWriter(bodyStream)) + { + WriteJsonValue(writer, payload); + } + + var body = bodyStream.ToArray(); + var header = Encoding.ASCII.GetBytes($"Content-Length: {body.Length}\r\n\r\n"); + + await _writeLock.WaitAsync(cancellationToken); + try + { + await stream.WriteAsync(header, cancellationToken); + await stream.WriteAsync(body, cancellationToken); + await stream.FlushAsync(cancellationToken); + } + finally + { + _writeLock.Release(); + } + } + + private static void WriteJsonValue(Utf8JsonWriter writer, object? value) + { + switch (value) + { + case null: + writer.WriteNullValue(); + break; + case string s: + writer.WriteStringValue(s); + break; + case bool b: + writer.WriteBooleanValue(b); + break; + case int i: + writer.WriteNumberValue(i); + break; + case long l: + writer.WriteNumberValue(l); + break; + case JsonElement je: + je.WriteTo(writer); + break; + case Dictionary dict: + writer.WriteStartObject(); + foreach (var (k, v) in dict) + { + writer.WritePropertyName(k); + WriteJsonValue(writer, v); + } + writer.WriteEndObject(); + break; + case object?[] arr: + writer.WriteStartArray(); + foreach (var item in arr) + { + WriteJsonValue(writer, item); + } + writer.WriteEndArray(); + break; + default: + throw new InvalidOperationException($"Unexpected JSON value type '{value.GetType().Name}'."); + } + } + + private static async Task ReadMessageAsync(Stream stream, CancellationToken cancellationToken) + { + var headerBytes = new List(); + while (true) + { + var value = await ReadByteAsync(stream, cancellationToken); + if (value < 0) return null; + + headerBytes.Add((byte)value); + var count = headerBytes.Count; + if (count >= 4 && + headerBytes[count - 4] == '\r' && + headerBytes[count - 3] == '\n' && + headerBytes[count - 2] == '\r' && + headerBytes[count - 1] == '\n') + { + break; + } + } + + var header = Encoding.ASCII.GetString([.. headerBytes]); + var contentLength = header + .Split(["\r\n"], StringSplitOptions.RemoveEmptyEntries) + .Select(line => line.Split(':', 2)) + .Where(parts => parts.Length == 2 && parts[0].Equals("Content-Length", StringComparison.OrdinalIgnoreCase)) + .Select(parts => int.Parse(parts[1].Trim(), System.Globalization.CultureInfo.InvariantCulture)) + .Single(); + + var body = new byte[contentLength]; + var offset = 0; + while (offset < body.Length) + { + var read = await stream.ReadAsync(body.AsMemory(offset, body.Length - offset), cancellationToken); + if (read == 0) return null; + offset += read; + } + + return JsonDocument.Parse(body); + } + + private static async Task ReadByteAsync(Stream stream, CancellationToken cancellationToken) + { + var buffer = new byte[1]; + var read = await stream.ReadAsync(buffer, cancellationToken); + return read == 0 ? -1 : buffer[0]; + } + } +} +#endif