From e7734f8f1f50bf28f5cfd80487a2ff5aa82e6cf3 Mon Sep 17 00:00:00 2001 From: Timothy Clem Date: Fri, 22 May 2026 19:25:01 -0700 Subject: [PATCH 1/2] Go SDK: add Client.CreateCloudSession - CreateSession now rejects config.Cloud with an error pointing to CreateCloudSession - CreateCloudSession validates config (requires Cloud, rejects caller-supplied SessionID and Provider), omits sessionId from the session.create wire payload so the runtime assigns one - Pending-routing support: a refcounted beginPendingSessionRouting guard buffers session.event notifications (bounded drop-oldest, 128 per id) and parks inbound request handlers (userInput.request, exitPlanMode.request, autoModeSwitch.request, hooks.invoke, systemMessage.transform) until the runtime-assigned session id is registered; waiters are rejected with a clear error if pending mode ends without registration - TOCTOU race fixed: after acquiring pending.mu, both handleSessionEvent and waitForSession re-check c.sessions so a notification/request that races with flushPendingForSession is dispatched directly rather than buffered and abandoned - Waiter buffer is also bounded at 128; oldest waiter is rejected on overflow - README and inline godoc updated Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- go/README.md | 24 +++ go/client.go | 397 ++++++++++++++++++++++++++++++++++++--- go/cloud_session_test.go | 329 ++++++++++++++++++++++++++++++++ 3 files changed, 724 insertions(+), 26 deletions(-) create mode 100644 go/cloud_session_test.go diff --git a/go/README.md b/go/README.md index da77033f8..62f3b9198 100644 --- a/go/README.md +++ b/go/README.md @@ -105,6 +105,7 @@ That's it! When your application calls `copilot.NewClient` without a `Connection - `Stop() error` - Stop the CLI server - `ForceStop()` - Forcefully stop without graceful cleanup - `CreateSession(config *SessionConfig) (*Session, error)` - Create a new session +- `CreateCloudSession(ctx context.Context, config *SessionConfig) (*Session, error)` - Create a Mission Control–backed cloud session - `ResumeSession(sessionID string, config *ResumeSessionConfig) (*Session, error)` - Resume an existing session - `ResumeSessionWithOptions(sessionID string, config *ResumeSessionConfig) (*Session, error)` - Resume with additional configuration - `ListSessions(filter *SessionListFilter) ([]SessionMetadata, error)` - List sessions (with optional filter) @@ -170,6 +171,8 @@ Event types: `SessionLifecycleCreated`, `SessionLifecycleDeleted`, `SessionLifec - `Commands` ([]CommandDefinition): Slash-commands registered for this session. See [Commands](#commands) section. - `OnElicitationRequest` (ElicitationHandler): Handler for elicitation requests from the server. See [Elicitation Requests](#elicitation-requests-serverclient) section. +- `Cloud` (\*CloudSessionOptions): Cloud session configuration. When set, `CreateSession` rejects the config; use `CreateCloudSession` instead. Do not set `SessionID` or `Provider` when using cloud sessions. + **ResumeSessionConfig:** - `OnPermissionRequest` (PermissionHandlerFunc): Optional handler called before each tool execution to approve or deny it. See [Permission Handling](#permission-handling) section. @@ -487,6 +490,27 @@ When enabled, sessions emit compaction events: - `session.compaction_start` - Background compaction started - `session.compaction_complete` - Compaction finished (includes token counts) +## Cloud Sessions + +`CreateCloudSession` creates a Mission Control–backed cloud session. The runtime assigns the session ID; do not set `SessionID` or `Provider` on the config (the SDK rejects both). `CreateSession` also rejects any config that has `Cloud` set. + +Any `session.event` notifications or inbound JSON-RPC requests that arrive between sending `session.create` and receiving its response are buffered (bounded, drop-oldest, limit 128 per id) and replayed once the runtime-assigned session ID is registered. + +```go +session, err := client.CreateCloudSession(context.Background(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Cloud: &copilot.CloudSessionOptions{ + Repository: &copilot.CloudSessionRepository{ + Owner: "github", Name: "copilot-sdk", + }, + }, +}) +if err != nil { + log.Fatal(err) +} +fmt.Println("cloud session id:", session.SessionID) +``` + ## Custom Providers The SDK supports custom OpenAI-compatible API providers (BYOK - Bring Your Own Key), including local providers like Ollama. When using a custom provider, you must specify the `Model` explicitly. diff --git a/go/client.go b/go/client.go index 9491eb199..0f5db2419 100644 --- a/go/client.go +++ b/go/client.go @@ -87,6 +87,27 @@ func validateSessionFsConfig(config *SessionFsConfig) error { // log.Fatal(err) // } // defer client.Stop() +// +// pendingResult carries the outcome of a parked inbound-request session lookup. +type pendingResult struct { + session *Session + err error +} + +// pendingRouting buffers session.event notifications and parks inbound request +// handlers that arrive before a cloud session.create response is received. +// A refcount tracks how many cloud creates are in flight; when it reaches zero +// the buffers are cleared and parked handlers are rejected. +type pendingRouting struct { + mu sync.Mutex + count int + events map[string][]sessionEventRequest + waiters map[string][]chan pendingResult +} + +// pendingSessionBufferLimit caps buffered notifications per in-flight session id. +const pendingSessionBufferLimit = 128 + type Client struct { options ClientOptions process *exec.Cmd @@ -121,6 +142,10 @@ type Client struct { effectiveConnectionToken string onListModels func(ctx context.Context) ([]ModelInfo, error) + // pending buffers traffic that arrives between session.create being sent and + // the response for cloud sessions. + pending pendingRouting + // RPC provides typed server-scoped RPC methods. // This field is nil until the client is connected via Start(). RPC *rpc.ServerRpc @@ -162,6 +187,10 @@ func NewClient(options *ClientOptions) *Client { actualHost: "localhost", isExternalServer: false, useStdio: true, + pending: pendingRouting{ + events: make(map[string][]sessionEventRequest), + waiters: make(map[string][]chan pendingResult), + }, } if options != nil { @@ -593,6 +622,10 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses config = &SessionConfig{} } + if config.Cloud != nil { + return nil, fmt.Errorf("CreateSession does not support cloud sessions; use CreateCloudSession instead") + } + if err := c.ensureConnected(ctx); err != nil { return nil, err } @@ -754,6 +787,307 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses return session, nil } +// CreateCloudSession creates a Mission Control–backed cloud session. +// +// The runtime owns the session ID for cloud sessions: do not set SessionID or +// Provider on the config (the SDK rejects both). Build the config with Cloud +// set to a [CloudSessionOptions] value; [Client.CreateSession] rejects any +// config that has Cloud set. +// +// Any session.event notifications or inbound JSON-RPC requests that arrive +// between sending session.create and receiving its response are buffered +// (bounded, drop-oldest, limit 128 per id) and replayed once the +// runtime-assigned session ID is registered. +// +// Known limitation: inbound sessionFs.* requests from the generated +// client-session API handlers are not pending-buffered. In practice the +// runtime does not initiate sessionFs.* calls before the session.create +// response, so this is theoretical. +// +// Example: +// +// session, err := client.CreateCloudSession(context.Background(), &copilot.SessionConfig{ +// OnPermissionRequest: copilot.PermissionHandler.ApproveAll, +// Cloud: &copilot.CloudSessionOptions{ +// Repository: &copilot.CloudSessionRepository{ +// Owner: "github", Name: "copilot-sdk", +// }, +// }, +// }) +func (c *Client) CreateCloudSession(ctx context.Context, config *SessionConfig) (*Session, error) { + if config == nil { + config = &SessionConfig{} + } + + if config.Cloud == nil { + return nil, fmt.Errorf("CreateCloudSession requires config.Cloud to be set") + } + if config.SessionID != "" { + return nil, fmt.Errorf("CreateCloudSession does not accept a caller-provided SessionID; the runtime assigns one") + } + if config.Provider != nil { + return nil, fmt.Errorf("CreateCloudSession does not accept config.Provider; cloud sessions use the runtime's provider") + } + + if err := c.ensureConnected(ctx); err != nil { + return nil, err + } + + req := createSessionRequest{} + req.Model = config.Model + req.ClientName = config.ClientName + req.ReasoningEffort = config.ReasoningEffort + req.ConfigDir = config.ConfigDir + if config.EnableConfigDiscovery { + req.EnableConfigDiscovery = Bool(true) + } + req.Tools = config.Tools + wireSystemMessage, transformCallbacks := extractTransformCallbacks(config.SystemMessage) + req.SystemMessage = wireSystemMessage + req.AvailableTools = config.AvailableTools + req.ExcludedTools = config.ExcludedTools + req.EnableSessionTelemetry = config.EnableSessionTelemetry + req.ModelCapabilities = config.ModelCapabilities + req.WorkingDirectory = config.WorkingDirectory + req.MCPServers = config.MCPServers + req.EnvValueMode = "direct" + req.CustomAgents = config.CustomAgents + req.DefaultAgent = config.DefaultAgent + req.Agent = config.Agent + req.SkillDirectories = config.SkillDirectories + req.InstructionDirectories = config.InstructionDirectories + req.DisabledSkills = config.DisabledSkills + req.InfiniteSessions = config.InfiniteSessions + req.GitHubToken = config.GitHubToken + req.RemoteSession = config.RemoteSession + req.Cloud = config.Cloud + // SessionID intentionally omitted: the runtime assigns the id for cloud sessions. + + if len(config.Commands) > 0 { + cmds := make([]wireCommand, 0, len(config.Commands)) + for _, cmd := range config.Commands { + cmds = append(cmds, wireCommand{Name: cmd.Name, Description: cmd.Description}) + } + req.Commands = cmds + } + if config.OnElicitationRequest != nil { + req.RequestElicitation = Bool(true) + } + if config.OnExitPlanModeRequest != nil { + req.RequestExitPlanMode = Bool(true) + } + if config.OnAutoModeSwitchRequest != nil { + req.RequestAutoModeSwitch = Bool(true) + } + if config.Streaming != nil { + req.Streaming = config.Streaming + } + if config.IncludeSubAgentStreamingEvents != nil { + req.IncludeSubAgentStreamingEvents = config.IncludeSubAgentStreamingEvents + } else { + req.IncludeSubAgentStreamingEvents = Bool(true) + } + if config.OnUserInputRequest != nil { + req.RequestUserInput = Bool(true) + } + if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || + config.Hooks.OnPreMcpToolCall != nil || + config.Hooks.OnPostToolUse != nil || + config.Hooks.OnUserPromptSubmitted != nil || + config.Hooks.OnSessionStart != nil || + config.Hooks.OnSessionEnd != nil || + config.Hooks.OnErrorOccurred != nil) { + req.Hooks = Bool(true) + } + if config.OnPermissionRequest != nil { + req.RequestPermission = Bool(true) + } + + traceparent, tracestate := getTraceContext(ctx) + req.Traceparent = traceparent + req.Tracestate = tracestate + + dispose := c.beginPendingSessionRouting() + + result, err := c.client.Request("session.create", req) + if err != nil { + dispose() + return nil, fmt.Errorf("failed to create cloud session: %w", err) + } + + var response createSessionResponse + if err := json.Unmarshal(result, &response); err != nil { + dispose() + return nil, fmt.Errorf("failed to unmarshal cloud session response: %w", err) + } + + if response.SessionID == "" { + fmt.Println("warning: cloud session.create response missing sessionId; runtime session may leak") + dispose() + return nil, fmt.Errorf("cloud session.create response did not include a sessionId; cannot register session") + } + + sessionID := response.SessionID + session := newSession(sessionID, c.client, response.WorkspacePath) + + session.registerTools(config.Tools) + session.registerPermissionHandler(config.OnPermissionRequest) + if config.OnUserInputRequest != nil { + session.registerUserInputHandler(config.OnUserInputRequest) + } + if config.Hooks != nil { + session.registerHooks(config.Hooks) + } + if transformCallbacks != nil { + session.registerTransformCallbacks(transformCallbacks) + } + if config.OnEvent != nil { + session.On(config.OnEvent) + } + if len(config.Commands) > 0 { + session.registerCommands(config.Commands) + } + if config.OnElicitationRequest != nil { + session.registerElicitationHandler(config.OnElicitationRequest) + } + if config.OnExitPlanModeRequest != nil { + session.registerExitPlanModeHandler(config.OnExitPlanModeRequest) + } + if config.OnAutoModeSwitchRequest != nil { + session.registerAutoModeSwitchHandler(config.OnAutoModeSwitchRequest) + } + session.setCapabilities(response.Capabilities) + + c.sessionsMux.Lock() + c.sessions[sessionID] = session + c.sessionsMux.Unlock() + + if c.options.SessionFs != nil { + if config.CreateSessionFsProvider == nil { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + dispose() + return nil, fmt.Errorf("CreateSessionFsProvider is required in session config when SessionFs is enabled in client options") + } + provider := config.CreateSessionFsProvider(session) + if c.options.SessionFs.Capabilities != nil && c.options.SessionFs.Capabilities.Sqlite { + if _, ok := provider.(SessionFsSqliteProvider); !ok { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + dispose() + return nil, fmt.Errorf("SessionFs capabilities declare SQLite support but the provider does not implement SessionFsSqliteProvider") + } + } + session.clientSessionApis.SessionFs = newSessionFsAdapter(provider) + } + + // Drain buffered events and unblock parked request handlers before + // releasing the guard, so they see a fully-wired session. + c.flushPendingForSession(sessionID, session) + dispose() + + return session, nil +} + +// beginPendingSessionRouting increments the pending-routing refcount and +// returns a disposer. While any disposer is undisposed, session.event +// notifications and inbound JSON-RPC requests addressed to unknown session ids +// are buffered/parked rather than dropped. When the last disposer fires, any +// remaining buffers are cleared and parked handlers receive an error so they +// don't block forever. +func (c *Client) beginPendingSessionRouting() func() { + c.pending.mu.Lock() + c.pending.count++ + c.pending.mu.Unlock() + + var once sync.Once + return func() { + once.Do(func() { + c.pending.mu.Lock() + c.pending.count-- + if c.pending.count > 0 { + c.pending.mu.Unlock() + return + } + // Last guard: swap out the maps so we can signal waiters without + // holding the lock (buffered channels make sends non-blocking, but + // releasing the lock is cleaner). + c.pending.events = make(map[string][]sessionEventRequest) + waiters := c.pending.waiters + c.pending.waiters = make(map[string][]chan pendingResult) + c.pending.mu.Unlock() + + for _, chs := range waiters { + for _, ch := range chs { + ch <- pendingResult{err: fmt.Errorf("request dropped: cloud session.create completed without registering this session id")} + } + } + }) + } +} + +// flushPendingForSession drains buffered events and resolves parked request +// handlers for sessionID into the freshly-registered session. Called from +// CreateCloudSession after the session is in c.sessions and before the pending +// guard is released. +func (c *Client) flushPendingForSession(sessionID string, session *Session) { + c.pending.mu.Lock() + events := c.pending.events[sessionID] + delete(c.pending.events, sessionID) + waiters := c.pending.waiters[sessionID] + delete(c.pending.waiters, sessionID) + c.pending.mu.Unlock() + + for _, req := range events { + session.dispatchEvent(req.Event) + } + for _, ch := range waiters { + ch <- pendingResult{session: session} + } +} + +// waitForSession looks up the session by id. If the session is not yet +// registered but pending routing is active, the call parks until the session +// is registered (or pending routing ends without registration). +func (c *Client) waitForSession(sessionID string) (*Session, error) { + c.sessionsMux.Lock() + session, ok := c.sessions[sessionID] + c.sessionsMux.Unlock() + if ok { + return session, nil + } + + c.pending.mu.Lock() + if c.pending.count == 0 { + c.pending.mu.Unlock() + return nil, fmt.Errorf("unknown session %s", sessionID) + } + // Re-check under pending.mu: the session may have been registered and + // flushed between the first lookup and acquiring this lock. + c.sessionsMux.Lock() + session, ok = c.sessions[sessionID] + c.sessionsMux.Unlock() + if ok { + c.pending.mu.Unlock() + return session, nil + } + ch := make(chan pendingResult, 1) + waiters := c.pending.waiters[sessionID] + if len(waiters) >= pendingSessionBufferLimit { + // Reject the oldest waiter to keep the queue bounded. + oldest := waiters[0] + waiters = waiters[1:] + oldest <- pendingResult{err: fmt.Errorf("request dropped: pending session waiter buffer full for %s", sessionID)} + } + c.pending.waiters[sessionID] = append(waiters, ch) + c.pending.mu.Unlock() + + result := <-ch + return result.session, result.err +} + // ResumeSession resumes an existing conversation session by its ID. // // This is a convenience method that calls [Client.ResumeSessionWithOptions]. @@ -1761,14 +2095,35 @@ func (c *Client) handleSessionEvent(req sessionEventRequest) { if req.SessionID == "" { return } - // Dispatch to session c.sessionsMux.Lock() session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if ok { session.dispatchEvent(req.Event) + return + } + + // Buffer if a cloud session.create is in flight for this id. + c.pending.mu.Lock() + if c.pending.count > 0 { + // Re-check under pending.mu: the session may have been registered and + // flushed between the first lookup and acquiring this lock. + c.sessionsMux.Lock() + session, ok = c.sessions[req.SessionID] + c.sessionsMux.Unlock() + if ok { + c.pending.mu.Unlock() + session.dispatchEvent(req.Event) + return + } + buf := c.pending.events[req.SessionID] + if len(buf) >= pendingSessionBufferLimit { + buf = buf[1:] // drop oldest + } + c.pending.events[req.SessionID] = append(buf, req) } + c.pending.mu.Unlock() } // handleUserInputRequest handles a user input request from the CLI server. @@ -1777,11 +2132,9 @@ func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputRespons return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid user input request payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} } response, err := session.handleUserInputRequest(UserInputRequest{ @@ -1806,11 +2159,9 @@ func (c *Client) handleExitPlanModeRequest(req exitPlanModeRequest) (*ExitPlanMo recommendedAction = "autopilot" } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} } response, err := session.handleExitPlanModeRequest(ExitPlanModeRequest{ @@ -1832,11 +2183,9 @@ func (c *Client) handleAutoModeSwitchRequest(req autoModeSwitchRequest) (*autoMo return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid auto mode switch request payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} } response, err := session.handleAutoModeSwitchRequest(AutoModeSwitchRequest{ @@ -1856,11 +2205,9 @@ func (c *Client) handleHooksInvoke(req hooksInvokeRequest) (map[string]any, *jso return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid hooks invoke payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} } output, err := session.handleHooksInvoke(req.Type, req.Input) @@ -1881,11 +2228,9 @@ func (c *Client) handleSystemMessageTransform(req systemMessageTransformRequest) return systemMessageTransformResponse{}, &jsonrpc2.Error{Code: -32602, Message: "invalid system message transform payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return systemMessageTransformResponse{}, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return systemMessageTransformResponse{}, &jsonrpc2.Error{Code: -32602, Message: err.Error()} } resp, err := session.handleSystemMessageTransform(req.Sections) diff --git a/go/cloud_session_test.go b/go/cloud_session_test.go new file mode 100644 index 000000000..ce827dbb0 --- /dev/null +++ b/go/cloud_session_test.go @@ -0,0 +1,329 @@ +package copilot + +import ( + "encoding/json" + "strings" + "sync" + "testing" + "time" + + "github.com/github/copilot-sdk/go/internal/jsonrpc2" +) + +// newCloudTestClient returns a Client with pending routing initialized and a +// pre-populated sessions map, suitable for unit-testing cloud session logic +// without a real network connection. +func newCloudTestClient() *Client { + return &Client{ + sessions: make(map[string]*Session), + pending: pendingRouting{ + events: make(map[string][]sessionEventRequest), + waiters: make(map[string][]chan pendingResult), + }, + } +} + +// TestCreateSession_RejectsCloudConfig verifies that CreateSession returns a +// clear error when config.Cloud is set. +func TestCreateSession_RejectsCloudConfig(t *testing.T) { + client := NewClient(&ClientOptions{Connection: StdioConnection{Path: "/__nonexistent__"}}) + _, err := client.CreateSession(t.Context(), &SessionConfig{ + Cloud: &CloudSessionOptions{}, + }) + if err == nil { + t.Fatal("expected error when cloud config is set") + } + if !strings.Contains(err.Error(), "CreateCloudSession") { + t.Errorf("error should mention CreateCloudSession, got: %v", err) + } +} + +// TestCreateCloudSession_RejectsCallerSessionID verifies the SDK rejects a +// caller-supplied SessionID. +func TestCreateCloudSession_RejectsCallerSessionID(t *testing.T) { + client := NewClient(&ClientOptions{Connection: StdioConnection{Path: "/__nonexistent__"}}) + _, err := client.CreateCloudSession(t.Context(), &SessionConfig{ + Cloud: &CloudSessionOptions{}, + SessionID: "caller-supplied-id", + }) + if err == nil { + t.Fatal("expected error when SessionID is set") + } + if !strings.Contains(err.Error(), "SessionID") { + t.Errorf("error should mention SessionID, got: %v", err) + } +} + +// TestCreateCloudSession_RejectsCallerProvider verifies the SDK rejects a +// caller-supplied Provider. +func TestCreateCloudSession_RejectsCallerProvider(t *testing.T) { + client := NewClient(&ClientOptions{Connection: StdioConnection{Path: "/__nonexistent__"}}) + _, err := client.CreateCloudSession(t.Context(), &SessionConfig{ + Cloud: &CloudSessionOptions{}, + Provider: &ProviderConfig{ModelID: "gpt-4"}, + }) + if err == nil { + t.Fatal("expected error when Provider is set") + } + if !strings.Contains(err.Error(), "Provider") { + t.Errorf("error should mention Provider, got: %v", err) + } +} + +// TestCreateCloudSession_RequiresCloud verifies the SDK rejects configs without +// Cloud set. +func TestCreateCloudSession_RequiresCloud(t *testing.T) { + client := NewClient(&ClientOptions{Connection: StdioConnection{Path: "/__nonexistent__"}}) + _, err := client.CreateCloudSession(t.Context(), &SessionConfig{}) + if err == nil { + t.Fatal("expected error when Cloud is nil") + } + if !strings.Contains(err.Error(), "Cloud") { + t.Errorf("error should mention Cloud, got: %v", err) + } +} + +// TestCreateCloudSession_WirePayload verifies that the session.create wire +// payload includes the cloud field and omits sessionId when built by the cloud +// path. +func TestCreateCloudSession_WirePayload(t *testing.T) { + req := createSessionRequest{ + Cloud: &CloudSessionOptions{ + Repository: &CloudSessionRepository{Owner: "github", Name: "copilot-sdk"}, + }, + // SessionID intentionally left empty + } + + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("marshal error: %v", err) + } + + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + if _, ok := m["sessionId"]; ok { + t.Error("sessionId must be omitted from the cloud session.create wire payload") + } + + cloud, ok := m["cloud"] + if !ok { + t.Fatal("cloud field must be present in the wire payload") + } + cloudMap, ok := cloud.(map[string]any) + if !ok { + t.Fatalf("cloud field should be a map, got %T", cloud) + } + repo, ok := cloudMap["repository"].(map[string]any) + if !ok { + t.Fatal("cloud.repository should be a map") + } + if repo["owner"] != "github" || repo["name"] != "copilot-sdk" { + t.Errorf("unexpected cloud.repository: %v", repo) + } +} + +// TestPendingRouting_BuffersEarlyNotifications verifies that session.event +// notifications arriving before the session is registered are buffered and +// replayed when flushPendingForSession is called. +func TestPendingRouting_BuffersEarlyNotifications(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + defer dispose() + + const pendingID = "runtime-assigned-id" + + // Simulate two session.event notifications arriving before the session is + // registered. + client.handleSessionEvent(sessionEventRequest{ + SessionID: pendingID, + Event: SessionEvent{Data: &SessionIdleData{}}, + }) + client.handleSessionEvent(sessionEventRequest{ + SessionID: pendingID, + Event: SessionEvent{Data: &SessionIdleData{}}, + }) + + // Verify they are buffered. + client.pending.mu.Lock() + bufLen := len(client.pending.events[pendingID]) + client.pending.mu.Unlock() + if bufLen != 2 { + t.Fatalf("expected 2 buffered events, got %d", bufLen) + } + + // Now register the session and flush. + session, cleanup := newTestSession() + defer cleanup() + session.SessionID = pendingID + + var received []SessionEvent + var mu sync.Mutex + var wg sync.WaitGroup + wg.Add(2) + session.On(func(event SessionEvent) { + mu.Lock() + received = append(received, event) + mu.Unlock() + wg.Done() + }) + + client.sessionsMux.Lock() + client.sessions[pendingID] = session + client.sessionsMux.Unlock() + + client.flushPendingForSession(pendingID, session) + + // Wait for the event handler goroutine to process. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for buffered events to be dispatched") + } + + mu.Lock() + got := len(received) + mu.Unlock() + if got != 2 { + t.Errorf("expected 2 events replayed, got %d", got) + } + + // Buffer should be cleared after flush. + client.pending.mu.Lock() + remaining := len(client.pending.events[pendingID]) + client.pending.mu.Unlock() + if remaining != 0 { + t.Errorf("buffer should be empty after flush, got %d", remaining) + } +} + +// TestPendingRouting_ParksInboundRequests verifies that inbound request handlers +// (e.g. userInput.request) park until the session is registered when pending +// routing is active. +func TestPendingRouting_ParksInboundRequests(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + + const pendingID = "runtime-assigned-id-2" + + // Launch a goroutine that simulates an inbound userInput.request arriving + // before the session is registered. + type result struct { + resp *userInputResponse + err *jsonrpcError + } + resultCh := make(chan result, 1) + go func() { + resp, rpcErr := client.handleUserInputRequest(userInputRequest{ + SessionID: pendingID, + Question: "Proceed?", + }) + resultCh <- result{resp, rpcErr} + }() + + // Give the goroutine time to park. + time.Sleep(20 * time.Millisecond) + + // Register the session. + session, cleanup := newTestSession() + defer cleanup() + session.SessionID = pendingID + session.registerUserInputHandler(func(req UserInputRequest, _ UserInputInvocation) (UserInputResponse, error) { + return UserInputResponse{Answer: "yes"}, nil + }) + + client.sessionsMux.Lock() + client.sessions[pendingID] = session + client.sessionsMux.Unlock() + + client.flushPendingForSession(pendingID, session) + dispose() + + select { + case r := <-resultCh: + if r.err != nil { + t.Fatalf("expected success, got rpc error: %v", r.err) + } + if r.resp == nil || r.resp.Answer != "yes" { + t.Errorf("expected answer 'yes', got %+v", r.resp) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for parked request to be resolved") + } +} + +// TestPendingRouting_DropOldestWhenBufferFull verifies drop-oldest behaviour +// when the notification buffer is full. +func TestPendingRouting_DropOldestWhenBufferFull(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + defer dispose() + + const pendingID = "overflow-session" + + // Fill buffer beyond the limit. + for i := range pendingSessionBufferLimit + 5 { + client.handleSessionEvent(sessionEventRequest{ + SessionID: pendingID, + Event: SessionEvent{ + // Embed the index so we can verify drop-oldest. + Data: &SessionIdleData{}, + }, + }) + _ = i + } + + client.pending.mu.Lock() + bufLen := len(client.pending.events[pendingID]) + client.pending.mu.Unlock() + + if bufLen != pendingSessionBufferLimit { + t.Errorf("expected buffer capped at %d, got %d", pendingSessionBufferLimit, bufLen) + } +} + +// TestPendingRouting_RejectsWaitersOnDispose verifies that waiters are +// rejected with an error when pending mode ends without registration. +func TestPendingRouting_RejectsWaitersOnDispose(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + + const pendingID = "never-registered" + + resultCh := make(chan *jsonrpcError, 1) + go func() { + _, rpcErr := client.handleUserInputRequest(userInputRequest{ + SessionID: pendingID, + Question: "Proceed?", + }) + resultCh <- rpcErr + }() + + // Give the goroutine time to park. + time.Sleep(20 * time.Millisecond) + + // Dispose without registering the session. + dispose() + + select { + case rpcErr := <-resultCh: + if rpcErr == nil { + t.Fatal("expected an rpc error after dispose without registration") + } + if !strings.Contains(rpcErr.Message, "dropped") { + t.Errorf("expected 'dropped' in error message, got: %s", rpcErr.Message) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for rejected waiter") + } +} + +// jsonrpcError is a local alias for jsonrpc2.Error used in test assertions. +type jsonrpcError = jsonrpc2.Error From 4c61c9558b517eafeab6f5d9077347176883f374 Mon Sep 17 00:00:00 2001 From: Timothy Clem Date: Fri, 22 May 2026 19:52:03 -0700 Subject: [PATCH 2/2] Go SDK: emit JSON-RPC error on pending-buffer overflow + guard drop Carries the Rust SDK PR #1394 follow-up review fixes into the Go port: 1. Cap the per-session parked-waiter list at 128. When exceeded, reject the oldest waiter with errPendingSessionBufferOverflow ('pending session buffer overflow'). The handler returns a *jsonrpc2.Error with code -32603 via the new pendingRoutingRPCError helper, so the runtime receives a proper error response instead of hanging on the request id until its own timeout. Mirrors Rust commit 491b4427 and TS commit c167bc3e. 2. When the last pending-routing guard drops without RegisterSession (e.g. session.create failed mid-RPC), signal all parked waiters with errPendingSessionRoutingEnded ('pending session routing ended before session was registered'). Distinct phrasing from the overflow path so debugging can tell the two cases apart. Mirrors Rust commit e0ff254f and TS commit c167bc3e. Adds pendingRoutingRPCError helper that routes sentinel errors to -32603 while unknown-session errors keep -32602. Adds two tests: - TestPendingRouting_OverflowEmitsError: 129 parked waiters, oldest gets -32603 overflow error, remaining 128 resolve normally after registration. - TestPendingRouting_GuardDropDistinctMessage: parks a request, drops the guard without registration, verifies exact routing-ended message and -32603 code. Updates TestPendingRouting_RejectsWaitersOnDispose to assert the new exact message and code instead of the old 'dropped' substring check. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- go/client.go | 37 ++++++++--- go/cloud_session_test.go | 129 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 156 insertions(+), 10 deletions(-) diff --git a/go/client.go b/go/client.go index 0f5db2419..6ba816b59 100644 --- a/go/client.go +++ b/go/client.go @@ -88,6 +88,13 @@ func validateSessionFsConfig(config *SessionFsConfig) error { // } // defer client.Stop() // +// Sentinel errors for the two pending-routing rejection paths. Using distinct +// values lets callers (and debugging) tell overflow eviction from guard-drop. +var ( + errPendingSessionBufferOverflow = errors.New("pending session buffer overflow") + errPendingSessionRoutingEnded = errors.New("pending session routing ended before session was registered") +) + // pendingResult carries the outcome of a parked inbound-request session lookup. type pendingResult struct { session *Session @@ -1021,7 +1028,7 @@ func (c *Client) beginPendingSessionRouting() func() { for _, chs := range waiters { for _, ch := range chs { - ch <- pendingResult{err: fmt.Errorf("request dropped: cloud session.create completed without registering this session id")} + ch <- pendingResult{err: errPendingSessionRoutingEnded} } } }) @@ -1076,10 +1083,12 @@ func (c *Client) waitForSession(sessionID string) (*Session, error) { ch := make(chan pendingResult, 1) waiters := c.pending.waiters[sessionID] if len(waiters) >= pendingSessionBufferLimit { - // Reject the oldest waiter to keep the queue bounded. + // Reject the oldest waiter to keep the queue bounded. Send a JSON-RPC + // error response via the handler return so the runtime doesn't hang on + // the request id waiting for a reply that would never come. oldest := waiters[0] waiters = waiters[1:] - oldest <- pendingResult{err: fmt.Errorf("request dropped: pending session waiter buffer full for %s", sessionID)} + oldest <- pendingResult{err: errPendingSessionBufferOverflow} } c.pending.waiters[sessionID] = append(waiters, ch) c.pending.mu.Unlock() @@ -2126,6 +2135,18 @@ func (c *Client) handleSessionEvent(req sessionEventRequest) { c.pending.mu.Unlock() } +// pendingRoutingRPCError maps an error from waitForSession to the appropriate +// JSON-RPC error. Overflow and guard-drop rejections use -32603 (internal +// error) so the runtime gets a proper error response instead of hanging on the +// request id. All other waitForSession errors (e.g. unknown session) keep the +// existing -32602 (invalid params) code. +func pendingRoutingRPCError(err error) *jsonrpc2.Error { + if errors.Is(err, errPendingSessionBufferOverflow) || errors.Is(err, errPendingSessionRoutingEnded) { + return &jsonrpc2.Error{Code: -32603, Message: err.Error()} + } + return &jsonrpc2.Error{Code: -32602, Message: err.Error()} +} + // handleUserInputRequest handles a user input request from the CLI server. func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputResponse, *jsonrpc2.Error) { if req.SessionID == "" || req.Question == "" { @@ -2134,7 +2155,7 @@ func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputRespons session, err := c.waitForSession(req.SessionID) if err != nil { - return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} + return nil, pendingRoutingRPCError(err) } response, err := session.handleUserInputRequest(UserInputRequest{ @@ -2161,7 +2182,7 @@ func (c *Client) handleExitPlanModeRequest(req exitPlanModeRequest) (*ExitPlanMo session, err := c.waitForSession(req.SessionID) if err != nil { - return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} + return nil, pendingRoutingRPCError(err) } response, err := session.handleExitPlanModeRequest(ExitPlanModeRequest{ @@ -2185,7 +2206,7 @@ func (c *Client) handleAutoModeSwitchRequest(req autoModeSwitchRequest) (*autoMo session, err := c.waitForSession(req.SessionID) if err != nil { - return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} + return nil, pendingRoutingRPCError(err) } response, err := session.handleAutoModeSwitchRequest(AutoModeSwitchRequest{ @@ -2207,7 +2228,7 @@ func (c *Client) handleHooksInvoke(req hooksInvokeRequest) (map[string]any, *jso session, err := c.waitForSession(req.SessionID) if err != nil { - return nil, &jsonrpc2.Error{Code: -32602, Message: err.Error()} + return nil, pendingRoutingRPCError(err) } output, err := session.handleHooksInvoke(req.Type, req.Input) @@ -2230,7 +2251,7 @@ func (c *Client) handleSystemMessageTransform(req systemMessageTransformRequest) session, err := c.waitForSession(req.SessionID) if err != nil { - return systemMessageTransformResponse{}, &jsonrpc2.Error{Code: -32602, Message: err.Error()} + return systemMessageTransformResponse{}, pendingRoutingRPCError(err) } resp, err := session.handleSystemMessageTransform(req.Sections) diff --git a/go/cloud_session_test.go b/go/cloud_session_test.go index ce827dbb0..63c66ab78 100644 --- a/go/cloud_session_test.go +++ b/go/cloud_session_test.go @@ -317,8 +317,133 @@ func TestPendingRouting_RejectsWaitersOnDispose(t *testing.T) { if rpcErr == nil { t.Fatal("expected an rpc error after dispose without registration") } - if !strings.Contains(rpcErr.Message, "dropped") { - t.Errorf("expected 'dropped' in error message, got: %s", rpcErr.Message) + if !strings.Contains(rpcErr.Message, "routing ended before session was registered") { + t.Errorf("expected routing-ended message, got: %s", rpcErr.Message) + } + if rpcErr.Code != -32603 { + t.Errorf("expected code -32603, got: %d", rpcErr.Code) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for rejected waiter") + } +} + +// TestPendingRouting_OverflowEmitsError verifies that when the parked-waiter +// buffer reaches its cap, the oldest waiter receives the overflow error response +// and the remaining 128 waiters resolve normally after registration. +func TestPendingRouting_OverflowEmitsError(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + + const pendingID = "overflow-request-session" + const total = pendingSessionBufferLimit + 1 // 129 + + type result struct { + resp *userInputResponse + err *jsonrpcError + } + + // Register a user-input handler so the session resolves successfully. + session, cleanup := newTestSession() + defer cleanup() + session.SessionID = pendingID + session.registerUserInputHandler(func(req UserInputRequest, _ UserInputInvocation) (UserInputResponse, error) { + return UserInputResponse{Answer: "yes"}, nil + }) + + results := make([]chan result, total) + for i := range total { + results[i] = make(chan result, 1) + go func(ch chan result) { + resp, rpcErr := client.handleUserInputRequest(userInputRequest{ + SessionID: pendingID, + Question: "Proceed?", + }) + ch <- result{resp, rpcErr} + }(results[i]) + } + + // Give goroutines time to park. + time.Sleep(50 * time.Millisecond) + + // Register the session and flush — this resolves the 128 remaining waiters. + client.sessionsMux.Lock() + client.sessions[pendingID] = session + client.sessionsMux.Unlock() + client.flushPendingForSession(pendingID, session) + dispose() + + // Collect all results with a timeout. + var gotOverflow int + var gotSuccess int + deadline := time.After(2 * time.Second) + for _, ch := range results { + select { + case r := <-ch: + if r.err != nil { + if !strings.Contains(r.err.Message, "pending session buffer overflow") { + t.Errorf("unexpected error message: %s", r.err.Message) + } + if r.err.Code != -32603 { + t.Errorf("expected code -32603 for overflow, got: %d", r.err.Code) + } + gotOverflow++ + } else { + gotSuccess++ + } + case <-deadline: + t.Fatalf("timed out: overflow=%d success=%d", gotOverflow, gotSuccess) + } + } + + if gotOverflow != 1 { + t.Errorf("expected exactly 1 overflow rejection, got %d", gotOverflow) + } + if gotSuccess != pendingSessionBufferLimit { + t.Errorf("expected %d successful resolutions, got %d", pendingSessionBufferLimit, gotSuccess) + } +} + +// TestPendingRouting_GuardDropDistinctMessage verifies that when the last +// pending-routing guard drops without registration, parked waiters receive the +// distinct routing-ended error (not the overflow message) so the two paths are +// distinguishable in logs and debugging. +func TestPendingRouting_GuardDropDistinctMessage(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + + const pendingID = "guard-drop-session" + + resultCh := make(chan *jsonrpcError, 1) + go func() { + _, rpcErr := client.handleUserInputRequest(userInputRequest{ + SessionID: pendingID, + Question: "Proceed?", + }) + resultCh <- rpcErr + }() + + // Give the goroutine time to park. + time.Sleep(20 * time.Millisecond) + + // Drop the guard without registering — simulates session.create failing. + dispose() + + select { + case rpcErr := <-resultCh: + if rpcErr == nil { + t.Fatal("expected an rpc error after guard drop without registration") + } + const want = "pending session routing ended before session was registered" + if rpcErr.Message != want { + t.Errorf("expected exact message %q, got %q", want, rpcErr.Message) + } + if rpcErr.Code != -32603 { + t.Errorf("expected code -32603, got: %d", rpcErr.Code) + } + // Must NOT contain the overflow message. + if strings.Contains(rpcErr.Message, "buffer overflow") { + t.Errorf("guard-drop path must not use overflow message, got: %s", rpcErr.Message) } case <-time.After(2 * time.Second): t.Fatal("timed out waiting for rejected waiter")