Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 76 additions & 50 deletions go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,23 +417,15 @@ func (c *Client) Start(ctx context.Context) error {
func (c *Client) Stop() error {
var errs []error

// Disconnect all active sessions
c.sessionsMux.Lock()
sessions := make([]*Session, 0, len(c.sessions))
for _, session := range c.sessions {
sessions = append(sessions, session)
}
c.sessionsMux.Unlock()

for _, session := range sessions {
for _, session := range c.snapshotSessions() {
if err := session.Disconnect(); err != nil {
errs = append(errs, fmt.Errorf("failed to disconnect session %s: %w", session.SessionID, err))
}
}

c.sessionsMux.Lock()
c.sessions = make(map[string]*Session)
c.sessionsMux.Unlock()
for _, session := range c.clearSessions() {
session.markDisconnected()
}

c.startStopMux.Lock()
defer c.startStopMux.Unlock()
Expand Down Expand Up @@ -504,10 +496,9 @@ func (c *Client) ForceStop() {
p.Kill()
}

// Clear sessions immediately without trying to destroy them
c.sessionsMux.Lock()
c.sessions = make(map[string]*Session)
c.sessionsMux.Unlock()
for _, session := range c.clearSessions() {
session.markDisconnected()
}

c.startStopMux.Lock()
defer c.startStopMux.Unlock()
Expand Down Expand Up @@ -556,6 +547,45 @@ func (c *Client) ensureConnected(ctx context.Context) error {
return fmt.Errorf("client not connected. Call Start() first")
}

func (c *Client) registerSession(session *Session) error {
c.sessionsMux.Lock()
defer c.sessionsMux.Unlock()
if existing := c.sessions[session.SessionID]; existing != nil && existing != session {
return fmt.Errorf("session %s is already active", session.SessionID)
}
c.sessions[session.SessionID] = session
return nil
}

func (c *Client) unregisterSession(session *Session) {
c.sessionsMux.Lock()
defer c.sessionsMux.Unlock()
if c.sessions[session.SessionID] == session {
delete(c.sessions, session.SessionID)
}
}

func (c *Client) snapshotSessions() []*Session {
c.sessionsMux.Lock()
defer c.sessionsMux.Unlock()
sessions := make([]*Session, 0, len(c.sessions))
for _, session := range c.sessions {
sessions = append(sessions, session)
}
return sessions
}

func (c *Client) clearSessions() []*Session {
c.sessionsMux.Lock()
defer c.sessionsMux.Unlock()
sessions := make([]*Session, 0, len(c.sessions))
for _, session := range c.sessions {
sessions = append(sessions, session)
}
c.sessions = make(map[string]*Session)
return sessions
}

// CreateSession creates a new conversation session with the Copilot CLI.
//
// Sessions maintain conversation state, handle events, and manage tool execution.
Expand Down Expand Up @@ -704,7 +734,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses

// Create and register the session before issuing the RPC so that
// events emitted by the CLI (e.g. session.start) are not dropped.
session := newSession(sessionID, c.client, "")
session := newSession(sessionID, c.client, "", c)

session.registerTools(config.Tools)
session.registerPermissionHandler(config.OnPermissionRequest)
Expand Down Expand Up @@ -733,23 +763,22 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses
session.registerAutoModeSwitchHandler(config.OnAutoModeSwitch)
}

c.sessionsMux.Lock()
c.sessions[sessionID] = session
c.sessionsMux.Unlock()
if err := c.registerSession(session); err != nil {
session.markDisconnected()
return nil, err
}

if c.options.SessionFs != nil {
if config.CreateSessionFsHandler == nil {
c.sessionsMux.Lock()
delete(c.sessions, sessionID)
c.sessionsMux.Unlock()
c.unregisterSession(session)
session.markDisconnected()
return nil, fmt.Errorf("CreateSessionFsHandler is required in session config when SessionFs is enabled in client options")
}
provider := config.CreateSessionFsHandler(session)
if c.options.SessionFs.Capabilities != nil && c.options.SessionFs.Capabilities.Sqlite {
if _, ok := provider.(SessionFsSqliteProvider); !ok {
c.sessionsMux.Lock()
delete(c.sessions, sessionID)
c.sessionsMux.Unlock()
c.unregisterSession(session)
session.markDisconnected()
return nil, fmt.Errorf("SessionFs capabilities declare SQLite support but the provider does not implement SessionFsSqliteProvider")
}
}
Expand All @@ -758,17 +787,15 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses

result, err := c.client.Request("session.create", req)
if err != nil {
c.sessionsMux.Lock()
delete(c.sessions, sessionID)
c.sessionsMux.Unlock()
c.unregisterSession(session)
session.markDisconnected()
return nil, fmt.Errorf("failed to create session: %w", err)
}

var response createSessionResponse
if err := json.Unmarshal(result, &response); err != nil {
c.sessionsMux.Lock()
delete(c.sessions, sessionID)
c.sessionsMux.Unlock()
c.unregisterSession(session)
session.markDisconnected()
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}

Expand Down Expand Up @@ -889,7 +916,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string,

// Create and register the session before issuing the RPC so that
// events emitted by the CLI (e.g. session.start) are not dropped.
session := newSession(sessionID, c.client, "")
session := newSession(sessionID, c.client, "", c)

session.registerTools(config.Tools)
session.registerPermissionHandler(config.OnPermissionRequest)
Expand Down Expand Up @@ -918,23 +945,22 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string,
session.registerAutoModeSwitchHandler(config.OnAutoModeSwitch)
}

c.sessionsMux.Lock()
c.sessions[sessionID] = session
c.sessionsMux.Unlock()
if err := c.registerSession(session); err != nil {
session.markDisconnected()
return nil, err
}

if c.options.SessionFs != nil {
if config.CreateSessionFsHandler == nil {
c.sessionsMux.Lock()
delete(c.sessions, sessionID)
c.sessionsMux.Unlock()
c.unregisterSession(session)
session.markDisconnected()
return nil, fmt.Errorf("CreateSessionFsHandler is required in session config when SessionFs is enabled in client options")
}
provider := config.CreateSessionFsHandler(session)
if c.options.SessionFs.Capabilities != nil && c.options.SessionFs.Capabilities.Sqlite {
if _, ok := provider.(SessionFsSqliteProvider); !ok {
c.sessionsMux.Lock()
delete(c.sessions, sessionID)
c.sessionsMux.Unlock()
c.unregisterSession(session)
session.markDisconnected()
return nil, fmt.Errorf("SessionFs capabilities declare SQLite support but the provider does not implement SessionFsSqliteProvider")
}
}
Expand All @@ -943,17 +969,15 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string,

result, err := c.client.Request("session.resume", req)
if err != nil {
c.sessionsMux.Lock()
delete(c.sessions, sessionID)
c.sessionsMux.Unlock()
c.unregisterSession(session)
session.markDisconnected()
return nil, fmt.Errorf("failed to resume session: %w", err)
}

var response resumeSessionResponse
if err := json.Unmarshal(result, &response); err != nil {
c.sessionsMux.Lock()
delete(c.sessions, sessionID)
c.sessionsMux.Unlock()
c.unregisterSession(session)
session.markDisconnected()
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}

Expand Down Expand Up @@ -1073,10 +1097,12 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error {
return fmt.Errorf("failed to delete session %s: %s", sessionID, errorMsg)
}

// Remove from local sessions map if present
c.sessionsMux.Lock()
delete(c.sessions, sessionID)
session := c.sessions[sessionID]
c.sessionsMux.Unlock()
if session != nil {
session.markDisconnected()
}

return nil
}
Expand Down
4 changes: 3 additions & 1 deletion go/internal/e2e/commands_and_elicitation_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,10 @@ func TestCommandsE2E(t *testing.T) {
sessionID := session1.SessionID
t.Cleanup(func() { _ = session1.Disconnect() })

session2, err := client1.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{
resumeClient := newResumeClient(t, client1)
session2, err := resumeClient.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{
OnPermissionRequest: copilot.PermissionHandler.ApproveAll,
DisableResume: true,
Commands: []copilot.CommandDefinition{
{Name: "deploy", Description: "Deploy", Handler: func(_ copilot.CommandContext) error { return nil }},
},
Expand Down
47 changes: 14 additions & 33 deletions go/internal/e2e/mcp_and_agents_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ import (

func TestMCPServersE2E(t *testing.T) {
ctx := testharness.NewTestContext(t)
client := ctx.NewClient()
client := ctx.NewClient(func(opts *copilot.ClientOptions) {
opts.UseStdio = copilot.Bool(false)
opts.TCPConnectionToken = sharedTcpToken
})
t.Cleanup(func() { client.ForceStop() })

t.Run("accept MCP server config on create", func(t *testing.T) {
Expand Down Expand Up @@ -44,7 +47,6 @@ func TestMCPServersE2E(t *testing.T) {
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}

message, err := testharness.GetFinalAssistantMessage(t.Context(), session)
if err != nil {
t.Fatalf("Failed to get final message: %v", err)
Expand All @@ -67,11 +69,6 @@ func TestMCPServersE2E(t *testing.T) {
}
sessionID := session1.SessionID

_, err = session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"})
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}

// Resume with MCP servers
mcpServers := map[string]copilot.MCPServerConfig{
"test-server": copilot.MCPStdioServerConfig{
Expand All @@ -81,8 +78,10 @@ func TestMCPServersE2E(t *testing.T) {
},
}

session2, err := client.ResumeSessionWithOptions(t.Context(), sessionID, &copilot.ResumeSessionConfig{
resumeClient := newResumeClient(t, client)
session2, err := resumeClient.ResumeSessionWithOptions(t.Context(), sessionID, &copilot.ResumeSessionConfig{
OnPermissionRequest: copilot.PermissionHandler.ApproveAll,
DisableResume: true,
MCPServers: mcpServers,
})
if err != nil {
Expand All @@ -93,15 +92,6 @@ func TestMCPServersE2E(t *testing.T) {
t.Errorf("Expected session ID %s, got %s", sessionID, session2.SessionID)
}

message, err := session2.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 3+3?"})
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}

if md, ok := message.Data.(*copilot.AssistantMessageData); !ok || !strings.Contains(md.Content, "6") {
t.Errorf("Expected message to contain '6', got: %v", message.Data)
}

session2.Disconnect()
})

Expand Down Expand Up @@ -184,7 +174,10 @@ func TestMCPServersE2E(t *testing.T) {

func TestCustomAgentsE2E(t *testing.T) {
ctx := testharness.NewTestContext(t)
client := ctx.NewClient()
client := ctx.NewClient(func(opts *copilot.ClientOptions) {
opts.UseStdio = copilot.Bool(false)
opts.TCPConnectionToken = sharedTcpToken
})
t.Cleanup(func() { client.ForceStop() })

t.Run("accept custom agent config on create", func(t *testing.T) {
Expand Down Expand Up @@ -243,11 +236,6 @@ func TestCustomAgentsE2E(t *testing.T) {
}
sessionID := session1.SessionID

_, err = session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"})
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}

// Resume with custom agents
customAgents := []copilot.CustomAgentConfig{
{
Expand All @@ -258,8 +246,10 @@ func TestCustomAgentsE2E(t *testing.T) {
},
}

session2, err := client.ResumeSessionWithOptions(t.Context(), sessionID, &copilot.ResumeSessionConfig{
resumeClient := newResumeClient(t, client)
session2, err := resumeClient.ResumeSessionWithOptions(t.Context(), sessionID, &copilot.ResumeSessionConfig{
OnPermissionRequest: copilot.PermissionHandler.ApproveAll,
DisableResume: true,
CustomAgents: customAgents,
})
if err != nil {
Expand All @@ -270,15 +260,6 @@ func TestCustomAgentsE2E(t *testing.T) {
t.Errorf("Expected session ID %s, got %s", sessionID, session2.SessionID)
}

message, err := session2.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 6+6?"})
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}

if md, ok := message.Data.(*copilot.AssistantMessageData); !ok || !strings.Contains(md.Content, "12") {
t.Errorf("Expected message to contain '12', got: %v", message.Data)
}

session2.Disconnect()
})

Expand Down
Loading
Loading