diff --git a/go/README.md b/go/README.md index da77033f8..62f3b9198 100644 --- a/go/README.md +++ b/go/README.md @@ -105,6 +105,7 @@ That's it! When your application calls `copilot.NewClient` without a `Connection - `Stop() error` - Stop the CLI server - `ForceStop()` - Forcefully stop without graceful cleanup - `CreateSession(config *SessionConfig) (*Session, error)` - Create a new session +- `CreateCloudSession(ctx context.Context, config *SessionConfig) (*Session, error)` - Create a Mission Control–backed cloud session - `ResumeSession(sessionID string, config *ResumeSessionConfig) (*Session, error)` - Resume an existing session - `ResumeSessionWithOptions(sessionID string, config *ResumeSessionConfig) (*Session, error)` - Resume with additional configuration - `ListSessions(filter *SessionListFilter) ([]SessionMetadata, error)` - List sessions (with optional filter) @@ -170,6 +171,8 @@ Event types: `SessionLifecycleCreated`, `SessionLifecycleDeleted`, `SessionLifec - `Commands` ([]CommandDefinition): Slash-commands registered for this session. See [Commands](#commands) section. - `OnElicitationRequest` (ElicitationHandler): Handler for elicitation requests from the server. See [Elicitation Requests](#elicitation-requests-serverclient) section. +- `Cloud` (\*CloudSessionOptions): Cloud session configuration. When set, `CreateSession` rejects the config; use `CreateCloudSession` instead. Do not set `SessionID` or `Provider` when using cloud sessions. + **ResumeSessionConfig:** - `OnPermissionRequest` (PermissionHandlerFunc): Optional handler called before each tool execution to approve or deny it. See [Permission Handling](#permission-handling) section. @@ -487,6 +490,27 @@ When enabled, sessions emit compaction events: - `session.compaction_start` - Background compaction started - `session.compaction_complete` - Compaction finished (includes token counts) +## Cloud Sessions + +`CreateCloudSession` creates a Mission Control–backed cloud session. 
The runtime assigns the session ID; do not set `SessionID` or `Provider` on the config (the SDK rejects both). `CreateSession` also rejects any config that has `Cloud` set. + +Any `session.event` notifications or inbound JSON-RPC requests that arrive between sending `session.create` and receiving its response are buffered (bounded, drop-oldest, limit 128 per id) and replayed once the runtime-assigned session ID is registered. + +```go +session, err := client.CreateCloudSession(context.Background(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Cloud: &copilot.CloudSessionOptions{ + Repository: &copilot.CloudSessionRepository{ + Owner: "github", Name: "copilot-sdk", + }, + }, +}) +if err != nil { + log.Fatal(err) +} +fmt.Println("cloud session id:", session.SessionID) +``` + ## Custom Providers The SDK supports custom OpenAI-compatible API providers (BYOK - Bring Your Own Key), including local providers like Ollama. When using a custom provider, you must specify the `Model` explicitly. diff --git a/go/client.go b/go/client.go index 9491eb199..6ba816b59 100644 --- a/go/client.go +++ b/go/client.go @@ -87,6 +87,34 @@ func validateSessionFsConfig(config *SessionFsConfig) error { // log.Fatal(err) // } // defer client.Stop() +// +// Sentinel errors for the two pending-routing rejection paths. Using distinct +// values lets callers (and debugging) tell overflow eviction from guard-drop. +var ( + errPendingSessionBufferOverflow = errors.New("pending session buffer overflow") + errPendingSessionRoutingEnded = errors.New("pending session routing ended before session was registered") +) + +// pendingResult carries the outcome of a parked inbound-request session lookup. +type pendingResult struct { + session *Session + err error +} + +// pendingRouting buffers session.event notifications and parks inbound request +// handlers that arrive before a cloud session.create response is received. 
+// A refcount tracks how many cloud creates are in flight; when it reaches zero +// the buffers are cleared and parked handlers are rejected. +type pendingRouting struct { + mu sync.Mutex + count int + events map[string][]sessionEventRequest + waiters map[string][]chan pendingResult +} + +// pendingSessionBufferLimit caps buffered notifications per in-flight session id. +const pendingSessionBufferLimit = 128 + type Client struct { options ClientOptions process *exec.Cmd @@ -121,6 +149,10 @@ type Client struct { effectiveConnectionToken string onListModels func(ctx context.Context) ([]ModelInfo, error) + // pending buffers traffic that arrives between session.create being sent and + // the response for cloud sessions. + pending pendingRouting + // RPC provides typed server-scoped RPC methods. // This field is nil until the client is connected via Start(). RPC *rpc.ServerRpc @@ -162,6 +194,10 @@ func NewClient(options *ClientOptions) *Client { actualHost: "localhost", isExternalServer: false, useStdio: true, + pending: pendingRouting{ + events: make(map[string][]sessionEventRequest), + waiters: make(map[string][]chan pendingResult), + }, } if options != nil { @@ -593,6 +629,10 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses config = &SessionConfig{} } + if config.Cloud != nil { + return nil, fmt.Errorf("CreateSession does not support cloud sessions; use CreateCloudSession instead") + } + if err := c.ensureConnected(ctx); err != nil { return nil, err } @@ -754,6 +794,309 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses return session, nil } +// CreateCloudSession creates a Mission Control–backed cloud session. +// +// The runtime owns the session ID for cloud sessions: do not set SessionID or +// Provider on the config (the SDK rejects both). Build the config with Cloud +// set to a [CloudSessionOptions] value; [Client.CreateSession] rejects any +// config that has Cloud set. 
+// +// Any session.event notifications or inbound JSON-RPC requests that arrive +// between sending session.create and receiving its response are buffered +// (bounded, drop-oldest, limit 128 per id) and replayed once the +// runtime-assigned session ID is registered. +// +// Known limitation: inbound sessionFs.* requests from the generated +// client-session API handlers are not pending-buffered. In practice the +// runtime does not initiate sessionFs.* calls before the session.create +// response, so this is theoretical. +// +// Example: +// +// session, err := client.CreateCloudSession(context.Background(), &copilot.SessionConfig{ +// OnPermissionRequest: copilot.PermissionHandler.ApproveAll, +// Cloud: &copilot.CloudSessionOptions{ +// Repository: &copilot.CloudSessionRepository{ +// Owner: "github", Name: "copilot-sdk", +// }, +// }, +// }) +func (c *Client) CreateCloudSession(ctx context.Context, config *SessionConfig) (*Session, error) { + if config == nil { + config = &SessionConfig{} + } + + if config.Cloud == nil { + return nil, fmt.Errorf("CreateCloudSession requires config.Cloud to be set") + } + if config.SessionID != "" { + return nil, fmt.Errorf("CreateCloudSession does not accept a caller-provided SessionID; the runtime assigns one") + } + if config.Provider != nil { + return nil, fmt.Errorf("CreateCloudSession does not accept config.Provider; cloud sessions use the runtime's provider") + } + + if err := c.ensureConnected(ctx); err != nil { + return nil, err + } + + req := createSessionRequest{} + req.Model = config.Model + req.ClientName = config.ClientName + req.ReasoningEffort = config.ReasoningEffort + req.ConfigDir = config.ConfigDir + if config.EnableConfigDiscovery { + req.EnableConfigDiscovery = Bool(true) + } + req.Tools = config.Tools + wireSystemMessage, transformCallbacks := extractTransformCallbacks(config.SystemMessage) + req.SystemMessage = wireSystemMessage + req.AvailableTools = config.AvailableTools + req.ExcludedTools = 
config.ExcludedTools + req.EnableSessionTelemetry = config.EnableSessionTelemetry + req.ModelCapabilities = config.ModelCapabilities + req.WorkingDirectory = config.WorkingDirectory + req.MCPServers = config.MCPServers + req.EnvValueMode = "direct" + req.CustomAgents = config.CustomAgents + req.DefaultAgent = config.DefaultAgent + req.Agent = config.Agent + req.SkillDirectories = config.SkillDirectories + req.InstructionDirectories = config.InstructionDirectories + req.DisabledSkills = config.DisabledSkills + req.InfiniteSessions = config.InfiniteSessions + req.GitHubToken = config.GitHubToken + req.RemoteSession = config.RemoteSession + req.Cloud = config.Cloud + // SessionID intentionally omitted: the runtime assigns the id for cloud sessions. + + if len(config.Commands) > 0 { + cmds := make([]wireCommand, 0, len(config.Commands)) + for _, cmd := range config.Commands { + cmds = append(cmds, wireCommand{Name: cmd.Name, Description: cmd.Description}) + } + req.Commands = cmds + } + if config.OnElicitationRequest != nil { + req.RequestElicitation = Bool(true) + } + if config.OnExitPlanModeRequest != nil { + req.RequestExitPlanMode = Bool(true) + } + if config.OnAutoModeSwitchRequest != nil { + req.RequestAutoModeSwitch = Bool(true) + } + if config.Streaming != nil { + req.Streaming = config.Streaming + } + if config.IncludeSubAgentStreamingEvents != nil { + req.IncludeSubAgentStreamingEvents = config.IncludeSubAgentStreamingEvents + } else { + req.IncludeSubAgentStreamingEvents = Bool(true) + } + if config.OnUserInputRequest != nil { + req.RequestUserInput = Bool(true) + } + if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || + config.Hooks.OnPreMcpToolCall != nil || + config.Hooks.OnPostToolUse != nil || + config.Hooks.OnUserPromptSubmitted != nil || + config.Hooks.OnSessionStart != nil || + config.Hooks.OnSessionEnd != nil || + config.Hooks.OnErrorOccurred != nil) { + req.Hooks = Bool(true) + } + if config.OnPermissionRequest != nil { + 
req.RequestPermission = Bool(true) + } + + traceparent, tracestate := getTraceContext(ctx) + req.Traceparent = traceparent + req.Tracestate = tracestate + + dispose := c.beginPendingSessionRouting() + + result, err := c.client.Request("session.create", req) + if err != nil { + dispose() + return nil, fmt.Errorf("failed to create cloud session: %w", err) + } + + var response createSessionResponse + if err := json.Unmarshal(result, &response); err != nil { + dispose() + return nil, fmt.Errorf("failed to unmarshal cloud session response: %w", err) + } + + if response.SessionID == "" { + // No stdout logging from library code: the returned error carries the leak warning. + dispose() + return nil, fmt.Errorf("cloud session.create response did not include a sessionId; cannot register session; runtime session may leak") + } + + sessionID := response.SessionID + session := newSession(sessionID, c.client, response.WorkspacePath) + + session.registerTools(config.Tools) + session.registerPermissionHandler(config.OnPermissionRequest) + if config.OnUserInputRequest != nil { + session.registerUserInputHandler(config.OnUserInputRequest) + } + if config.Hooks != nil { + session.registerHooks(config.Hooks) + } + if transformCallbacks != nil { + session.registerTransformCallbacks(transformCallbacks) + } + if config.OnEvent != nil { + session.On(config.OnEvent) + } + if len(config.Commands) > 0 { + session.registerCommands(config.Commands) + } + if config.OnElicitationRequest != nil { + session.registerElicitationHandler(config.OnElicitationRequest) + } + if config.OnExitPlanModeRequest != nil { + session.registerExitPlanModeHandler(config.OnExitPlanModeRequest) + } + if config.OnAutoModeSwitchRequest != nil { + session.registerAutoModeSwitchHandler(config.OnAutoModeSwitchRequest) + } + session.setCapabilities(response.Capabilities) + + c.sessionsMux.Lock() + c.sessions[sessionID] = session + c.sessionsMux.Unlock() + + if c.options.SessionFs != nil { + if config.CreateSessionFsProvider == nil { +
c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + dispose() + return nil, fmt.Errorf("CreateSessionFsProvider is required in session config when SessionFs is enabled in client options") + } + provider := config.CreateSessionFsProvider(session) + if c.options.SessionFs.Capabilities != nil && c.options.SessionFs.Capabilities.Sqlite { + if _, ok := provider.(SessionFsSqliteProvider); !ok { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + dispose() + return nil, fmt.Errorf("SessionFs capabilities declare SQLite support but the provider does not implement SessionFsSqliteProvider") + } + } + session.clientSessionApis.SessionFs = newSessionFsAdapter(provider) + } + + // Drain buffered events and unblock parked request handlers before + // releasing the guard, so they see a fully-wired session. + c.flushPendingForSession(sessionID, session) + dispose() + + return session, nil +} + +// beginPendingSessionRouting increments the pending-routing refcount and +// returns a disposer. While any disposer is undisposed, session.event +// notifications and inbound JSON-RPC requests addressed to unknown session ids +// are buffered/parked rather than dropped. When the last disposer fires, any +// remaining buffers are cleared and parked handlers receive an error so they +// don't block forever. +func (c *Client) beginPendingSessionRouting() func() { + c.pending.mu.Lock() + c.pending.count++ + c.pending.mu.Unlock() + + var once sync.Once + return func() { + once.Do(func() { + c.pending.mu.Lock() + c.pending.count-- + if c.pending.count > 0 { + c.pending.mu.Unlock() + return + } + // Last guard: swap out the maps so we can signal waiters without + // holding the lock (buffered channels make sends non-blocking, but + // releasing the lock is cleaner). 
+ c.pending.events = make(map[string][]sessionEventRequest) + waiters := c.pending.waiters + c.pending.waiters = make(map[string][]chan pendingResult) + c.pending.mu.Unlock() + + for _, chs := range waiters { + for _, ch := range chs { + ch <- pendingResult{err: errPendingSessionRoutingEnded} + } + } + }) + } +} + +// flushPendingForSession drains buffered events and resolves parked request +// handlers for sessionID into the freshly-registered session. Called from +// CreateCloudSession after the session is in c.sessions and before the pending +// guard is released. +func (c *Client) flushPendingForSession(sessionID string, session *Session) { + c.pending.mu.Lock() + events := c.pending.events[sessionID] + delete(c.pending.events, sessionID) + waiters := c.pending.waiters[sessionID] + delete(c.pending.waiters, sessionID) + c.pending.mu.Unlock() + + for _, req := range events { + session.dispatchEvent(req.Event) + } + for _, ch := range waiters { + ch <- pendingResult{session: session} + } +} + +// waitForSession looks up the session by id. If the session is not yet +// registered but pending routing is active, the call parks until the session +// is registered (or pending routing ends without registration). +func (c *Client) waitForSession(sessionID string) (*Session, error) { + c.sessionsMux.Lock() + session, ok := c.sessions[sessionID] + c.sessionsMux.Unlock() + if ok { + return session, nil + } + + c.pending.mu.Lock() + if c.pending.count == 0 { + c.pending.mu.Unlock() + return nil, fmt.Errorf("unknown session %s", sessionID) + } + // Re-check under pending.mu: the session may have been registered and + // flushed between the first lookup and acquiring this lock. 
+ c.sessionsMux.Lock() + session, ok = c.sessions[sessionID] + c.sessionsMux.Unlock() + if ok { + c.pending.mu.Unlock() + return session, nil + } + ch := make(chan pendingResult, 1) + waiters := c.pending.waiters[sessionID] + if len(waiters) >= pendingSessionBufferLimit { + // Reject the oldest waiter to keep the queue bounded. Send a JSON-RPC + // error response via the handler return so the runtime doesn't hang on + // the request id waiting for a reply that would never come. + oldest := waiters[0] + waiters = waiters[1:] + oldest <- pendingResult{err: errPendingSessionBufferOverflow} + } + c.pending.waiters[sessionID] = append(waiters, ch) + c.pending.mu.Unlock() + + result := <-ch + return result.session, result.err +} + // ResumeSession resumes an existing conversation session by its ID. // // This is a convenience method that calls [Client.ResumeSessionWithOptions]. @@ -1761,14 +2104,47 @@ func (c *Client) handleSessionEvent(req sessionEventRequest) { if req.SessionID == "" { return } - // Dispatch to session c.sessionsMux.Lock() session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if ok { session.dispatchEvent(req.Event) + return + } + + // Buffer if a cloud session.create is in flight for this id. + c.pending.mu.Lock() + if c.pending.count > 0 { + // Re-check under pending.mu: the session may have been registered and + // flushed between the first lookup and acquiring this lock. + c.sessionsMux.Lock() + session, ok = c.sessions[req.SessionID] + c.sessionsMux.Unlock() + if ok { + c.pending.mu.Unlock() + session.dispatchEvent(req.Event) + return + } + buf := c.pending.events[req.SessionID] + if len(buf) >= pendingSessionBufferLimit { + buf = buf[1:] // drop oldest + } + c.pending.events[req.SessionID] = append(buf, req) + } + c.pending.mu.Unlock() +} + +// pendingRoutingRPCError maps an error from waitForSession to the appropriate +// JSON-RPC error. 
Overflow and guard-drop rejections use -32603 (internal +// error) so the runtime gets a proper error response instead of hanging on the +// request id. All other waitForSession errors (e.g. unknown session) keep the +// existing -32602 (invalid params) code. +func pendingRoutingRPCError(err error) *jsonrpc2.Error { + if errors.Is(err, errPendingSessionBufferOverflow) || errors.Is(err, errPendingSessionRoutingEnded) { + return &jsonrpc2.Error{Code: -32603, Message: err.Error()} } + return &jsonrpc2.Error{Code: -32602, Message: err.Error()} } // handleUserInputRequest handles a user input request from the CLI server. @@ -1777,11 +2153,9 @@ func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputRespons return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid user input request payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return nil, pendingRoutingRPCError(err) } response, err := session.handleUserInputRequest(UserInputRequest{ @@ -1806,11 +2180,9 @@ func (c *Client) handleExitPlanModeRequest(req exitPlanModeRequest) (*ExitPlanMo recommendedAction = "autopilot" } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return nil, pendingRoutingRPCError(err) } response, err := session.handleExitPlanModeRequest(ExitPlanModeRequest{ @@ -1832,11 +2204,9 @@ func (c *Client) handleAutoModeSwitchRequest(req autoModeSwitchRequest) (*autoMo return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid auto mode switch request payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - 
c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return nil, pendingRoutingRPCError(err) } response, err := session.handleAutoModeSwitchRequest(AutoModeSwitchRequest{ @@ -1856,11 +2226,9 @@ func (c *Client) handleHooksInvoke(req hooksInvokeRequest) (map[string]any, *jso return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid hooks invoke payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return nil, pendingRoutingRPCError(err) } output, err := session.handleHooksInvoke(req.Type, req.Input) @@ -1881,11 +2249,9 @@ func (c *Client) handleSystemMessageTransform(req systemMessageTransformRequest) return systemMessageTransformResponse{}, &jsonrpc2.Error{Code: -32602, Message: "invalid system message transform payload"} } - c.sessionsMux.Lock() - session, ok := c.sessions[req.SessionID] - c.sessionsMux.Unlock() - if !ok { - return systemMessageTransformResponse{}, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} + session, err := c.waitForSession(req.SessionID) + if err != nil { + return systemMessageTransformResponse{}, pendingRoutingRPCError(err) } resp, err := session.handleSystemMessageTransform(req.Sections) diff --git a/go/cloud_session_test.go b/go/cloud_session_test.go new file mode 100644 index 000000000..63c66ab78 --- /dev/null +++ b/go/cloud_session_test.go @@ -0,0 +1,454 @@ +package copilot + +import ( + "encoding/json" + "strings" + "sync" + "testing" + "time" + + "github.com/github/copilot-sdk/go/internal/jsonrpc2" +) + +// newCloudTestClient returns a Client with pending routing initialized and a +// pre-populated sessions map, 
suitable for unit-testing cloud session logic +// without a real network connection. +func newCloudTestClient() *Client { + return &Client{ + sessions: make(map[string]*Session), + pending: pendingRouting{ + events: make(map[string][]sessionEventRequest), + waiters: make(map[string][]chan pendingResult), + }, + } +} + +// TestCreateSession_RejectsCloudConfig verifies that CreateSession returns a +// clear error when config.Cloud is set. +func TestCreateSession_RejectsCloudConfig(t *testing.T) { + client := NewClient(&ClientOptions{Connection: StdioConnection{Path: "/__nonexistent__"}}) + _, err := client.CreateSession(t.Context(), &SessionConfig{ + Cloud: &CloudSessionOptions{}, + }) + if err == nil { + t.Fatal("expected error when cloud config is set") + } + if !strings.Contains(err.Error(), "CreateCloudSession") { + t.Errorf("error should mention CreateCloudSession, got: %v", err) + } +} + +// TestCreateCloudSession_RejectsCallerSessionID verifies the SDK rejects a +// caller-supplied SessionID. +func TestCreateCloudSession_RejectsCallerSessionID(t *testing.T) { + client := NewClient(&ClientOptions{Connection: StdioConnection{Path: "/__nonexistent__"}}) + _, err := client.CreateCloudSession(t.Context(), &SessionConfig{ + Cloud: &CloudSessionOptions{}, + SessionID: "caller-supplied-id", + }) + if err == nil { + t.Fatal("expected error when SessionID is set") + } + if !strings.Contains(err.Error(), "SessionID") { + t.Errorf("error should mention SessionID, got: %v", err) + } +} + +// TestCreateCloudSession_RejectsCallerProvider verifies the SDK rejects a +// caller-supplied Provider. 
+func TestCreateCloudSession_RejectsCallerProvider(t *testing.T) { + client := NewClient(&ClientOptions{Connection: StdioConnection{Path: "/__nonexistent__"}}) + _, err := client.CreateCloudSession(t.Context(), &SessionConfig{ + Cloud: &CloudSessionOptions{}, + Provider: &ProviderConfig{ModelID: "gpt-4"}, + }) + if err == nil { + t.Fatal("expected error when Provider is set") + } + if !strings.Contains(err.Error(), "Provider") { + t.Errorf("error should mention Provider, got: %v", err) + } +} + +// TestCreateCloudSession_RequiresCloud verifies the SDK rejects configs without +// Cloud set. +func TestCreateCloudSession_RequiresCloud(t *testing.T) { + client := NewClient(&ClientOptions{Connection: StdioConnection{Path: "/__nonexistent__"}}) + _, err := client.CreateCloudSession(t.Context(), &SessionConfig{}) + if err == nil { + t.Fatal("expected error when Cloud is nil") + } + if !strings.Contains(err.Error(), "Cloud") { + t.Errorf("error should mention Cloud, got: %v", err) + } +} + +// TestCreateCloudSession_WirePayload verifies that the session.create wire +// payload includes the cloud field and omits sessionId when built by the cloud +// path. 
+func TestCreateCloudSession_WirePayload(t *testing.T) { + req := createSessionRequest{ + Cloud: &CloudSessionOptions{ + Repository: &CloudSessionRepository{Owner: "github", Name: "copilot-sdk"}, + }, + // SessionID intentionally left empty + } + + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("marshal error: %v", err) + } + + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + if _, ok := m["sessionId"]; ok { + t.Error("sessionId must be omitted from the cloud session.create wire payload") + } + + cloud, ok := m["cloud"] + if !ok { + t.Fatal("cloud field must be present in the wire payload") + } + cloudMap, ok := cloud.(map[string]any) + if !ok { + t.Fatalf("cloud field should be a map, got %T", cloud) + } + repo, ok := cloudMap["repository"].(map[string]any) + if !ok { + t.Fatal("cloud.repository should be a map") + } + if repo["owner"] != "github" || repo["name"] != "copilot-sdk" { + t.Errorf("unexpected cloud.repository: %v", repo) + } +} + +// TestPendingRouting_BuffersEarlyNotifications verifies that session.event +// notifications arriving before the session is registered are buffered and +// replayed when flushPendingForSession is called. +func TestPendingRouting_BuffersEarlyNotifications(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + defer dispose() + + const pendingID = "runtime-assigned-id" + + // Simulate two session.event notifications arriving before the session is + // registered. + client.handleSessionEvent(sessionEventRequest{ + SessionID: pendingID, + Event: SessionEvent{Data: &SessionIdleData{}}, + }) + client.handleSessionEvent(sessionEventRequest{ + SessionID: pendingID, + Event: SessionEvent{Data: &SessionIdleData{}}, + }) + + // Verify they are buffered. 
+ client.pending.mu.Lock() + bufLen := len(client.pending.events[pendingID]) + client.pending.mu.Unlock() + if bufLen != 2 { + t.Fatalf("expected 2 buffered events, got %d", bufLen) + } + + // Now register the session and flush. + session, cleanup := newTestSession() + defer cleanup() + session.SessionID = pendingID + + var received []SessionEvent + var mu sync.Mutex + var wg sync.WaitGroup + wg.Add(2) + session.On(func(event SessionEvent) { + mu.Lock() + received = append(received, event) + mu.Unlock() + wg.Done() + }) + + client.sessionsMux.Lock() + client.sessions[pendingID] = session + client.sessionsMux.Unlock() + + client.flushPendingForSession(pendingID, session) + + // Wait for the event handler goroutine to process. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for buffered events to be dispatched") + } + + mu.Lock() + got := len(received) + mu.Unlock() + if got != 2 { + t.Errorf("expected 2 events replayed, got %d", got) + } + + // Buffer should be cleared after flush. + client.pending.mu.Lock() + remaining := len(client.pending.events[pendingID]) + client.pending.mu.Unlock() + if remaining != 0 { + t.Errorf("buffer should be empty after flush, got %d", remaining) + } +} + +// TestPendingRouting_ParksInboundRequests verifies that inbound request handlers +// (e.g. userInput.request) park until the session is registered when pending +// routing is active. +func TestPendingRouting_ParksInboundRequests(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + + const pendingID = "runtime-assigned-id-2" + + // Launch a goroutine that simulates an inbound userInput.request arriving + // before the session is registered. 
+ type result struct { + resp *userInputResponse + err *jsonrpcError + } + resultCh := make(chan result, 1) + go func() { + resp, rpcErr := client.handleUserInputRequest(userInputRequest{ + SessionID: pendingID, + Question: "Proceed?", + }) + resultCh <- result{resp, rpcErr} + }() + + // Give the goroutine time to park. + time.Sleep(20 * time.Millisecond) + + // Register the session. + session, cleanup := newTestSession() + defer cleanup() + session.SessionID = pendingID + session.registerUserInputHandler(func(req UserInputRequest, _ UserInputInvocation) (UserInputResponse, error) { + return UserInputResponse{Answer: "yes"}, nil + }) + + client.sessionsMux.Lock() + client.sessions[pendingID] = session + client.sessionsMux.Unlock() + + client.flushPendingForSession(pendingID, session) + dispose() + + select { + case r := <-resultCh: + if r.err != nil { + t.Fatalf("expected success, got rpc error: %v", r.err) + } + if r.resp == nil || r.resp.Answer != "yes" { + t.Errorf("expected answer 'yes', got %+v", r.resp) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for parked request to be resolved") + } +} + +// TestPendingRouting_DropOldestWhenBufferFull verifies that the notification +// buffer stays capped at pendingSessionBufferLimit when more events arrive (eviction order is not asserted). +func TestPendingRouting_DropOldestWhenBufferFull(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + defer dispose() + + const pendingID = "overflow-session" + + // Fill buffer beyond the limit. + for i := range pendingSessionBufferLimit + 5 { + client.handleSessionEvent(sessionEventRequest{ + SessionID: pendingID, + Event: SessionEvent{ + // Every event carries the same payload; only the buffer length is checked afterwards. 
+ Data: &SessionIdleData{}, + }, + }) + _ = i + } + + client.pending.mu.Lock() + bufLen := len(client.pending.events[pendingID]) + client.pending.mu.Unlock() + + if bufLen != pendingSessionBufferLimit { + t.Errorf("expected buffer capped at %d, got %d", pendingSessionBufferLimit, bufLen) + } +} + +// TestPendingRouting_RejectsWaitersOnDispose verifies that waiters are +// rejected with an error when pending mode ends without registration. +func TestPendingRouting_RejectsWaitersOnDispose(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + + const pendingID = "never-registered" + + resultCh := make(chan *jsonrpcError, 1) + go func() { + _, rpcErr := client.handleUserInputRequest(userInputRequest{ + SessionID: pendingID, + Question: "Proceed?", + }) + resultCh <- rpcErr + }() + + // Give the goroutine time to park. + time.Sleep(20 * time.Millisecond) + + // Dispose without registering the session. + dispose() + + select { + case rpcErr := <-resultCh: + if rpcErr == nil { + t.Fatal("expected an rpc error after dispose without registration") + } + if !strings.Contains(rpcErr.Message, "routing ended before session was registered") { + t.Errorf("expected routing-ended message, got: %s", rpcErr.Message) + } + if rpcErr.Code != -32603 { + t.Errorf("expected code -32603, got: %d", rpcErr.Code) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for rejected waiter") + } +} + +// TestPendingRouting_OverflowEmitsError verifies that when the parked-waiter +// buffer reaches its cap, the oldest waiter receives the overflow error response +// and the remaining 128 waiters resolve normally after registration. 
+func TestPendingRouting_OverflowEmitsError(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + + const pendingID = "overflow-request-session" + const total = pendingSessionBufferLimit + 1 // 129 + + type result struct { + resp *userInputResponse + err *jsonrpcError + } + + // Register a user-input handler so the session resolves successfully. + session, cleanup := newTestSession() + defer cleanup() + session.SessionID = pendingID + session.registerUserInputHandler(func(req UserInputRequest, _ UserInputInvocation) (UserInputResponse, error) { + return UserInputResponse{Answer: "yes"}, nil + }) + + results := make([]chan result, total) + for i := range total { + results[i] = make(chan result, 1) + go func(ch chan result) { + resp, rpcErr := client.handleUserInputRequest(userInputRequest{ + SessionID: pendingID, + Question: "Proceed?", + }) + ch <- result{resp, rpcErr} + }(results[i]) + } + + // Give goroutines time to park. + time.Sleep(50 * time.Millisecond) + + // Register the session and flush — this resolves the 128 remaining waiters. + client.sessionsMux.Lock() + client.sessions[pendingID] = session + client.sessionsMux.Unlock() + client.flushPendingForSession(pendingID, session) + dispose() + + // Collect all results with a timeout. 
+ var gotOverflow int + var gotSuccess int + deadline := time.After(2 * time.Second) + for _, ch := range results { + select { + case r := <-ch: + if r.err != nil { + if !strings.Contains(r.err.Message, "pending session buffer overflow") { + t.Errorf("unexpected error message: %s", r.err.Message) + } + if r.err.Code != -32603 { + t.Errorf("expected code -32603 for overflow, got: %d", r.err.Code) + } + gotOverflow++ + } else { + gotSuccess++ + } + case <-deadline: + t.Fatalf("timed out: overflow=%d success=%d", gotOverflow, gotSuccess) + } + } + + if gotOverflow != 1 { + t.Errorf("expected exactly 1 overflow rejection, got %d", gotOverflow) + } + if gotSuccess != pendingSessionBufferLimit { + t.Errorf("expected %d successful resolutions, got %d", pendingSessionBufferLimit, gotSuccess) + } +} + +// TestPendingRouting_GuardDropDistinctMessage verifies that when the last +// pending-routing guard drops without registration, parked waiters receive the +// distinct routing-ended error (not the overflow message) so the two paths are +// distinguishable in logs and debugging. +func TestPendingRouting_GuardDropDistinctMessage(t *testing.T) { + client := newCloudTestClient() + dispose := client.beginPendingSessionRouting() + + const pendingID = "guard-drop-session" + + resultCh := make(chan *jsonrpcError, 1) + go func() { + _, rpcErr := client.handleUserInputRequest(userInputRequest{ + SessionID: pendingID, + Question: "Proceed?", + }) + resultCh <- rpcErr + }() + + // Give the goroutine time to park. + time.Sleep(20 * time.Millisecond) + + // Drop the guard without registering — simulates session.create failing. 
+ dispose() + + select { + case rpcErr := <-resultCh: + if rpcErr == nil { + t.Fatal("expected an rpc error after guard drop without registration") + } + const want = "pending session routing ended before session was registered" + if rpcErr.Message != want { + t.Errorf("expected exact message %q, got %q", want, rpcErr.Message) + } + if rpcErr.Code != -32603 { + t.Errorf("expected code -32603, got: %d", rpcErr.Code) + } + // Must NOT contain the overflow message. + if strings.Contains(rpcErr.Message, "buffer overflow") { + t.Errorf("guard-drop path must not use overflow message, got: %s", rpcErr.Message) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for rejected waiter") + } +} + +// jsonrpcError is a local alias for jsonrpc2.Error used in test assertions. +type jsonrpcError = jsonrpc2.Error