diff --git a/go.mod b/go.mod index 723c257..34eeedc 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( charm.land/bubbletea/v2 v2.0.6 github.com/BurntSushi/toml v1.6.0 github.com/aymanbagabas/go-udiff v0.4.1 + github.com/cenkalti/backoff/v4 v4.3.0 github.com/coder/websocket v1.8.14 github.com/dustinkirkland/golang-petname v0.0.0-20260215035315-f0c533e9ce9b github.com/gin-gonic/gin v1.12.0 diff --git a/go.sum b/go.sum index 230f936..cd054a9 100644 --- a/go.sum +++ b/go.sum @@ -140,6 +140,8 @@ github.com/catenacyber/perfsprint v0.10.1 h1:u7Riei30bk46XsG8nknMhKLXG9BcXz3+3tl github.com/catenacyber/perfsprint v0.10.1/go.mod h1:DJTGsi/Zufpuus6XPGJyKOTMELe347o6akPvWG9Zcsc= github.com/ccojocar/zxcvbn-go v1.0.4 h1:FWnCIRMXPj43ukfX000kvBZvV6raSxakYr1nzyNrUcc= github.com/ccojocar/zxcvbn-go v1.0.4/go.mod h1:3GxGX+rHmueTUMvm5ium7irpyjmm7ikxYFOSJB21Das= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= diff --git a/internal/sidecar/ssh.go b/internal/sidecar/ssh.go index 02afe87..36f490b 100644 --- a/internal/sidecar/ssh.go +++ b/internal/sidecar/ssh.go @@ -16,12 +16,16 @@ import ( "path/filepath" "sort" "strings" + "syscall" + "time" + "github.com/cenkalti/backoff/v4" "github.com/coder/websocket" "golang.org/x/crypto/ssh" "golang.org/x/term" "github.com/CircleCI-Public/chunk-cli/internal/closer" + "github.com/CircleCI-Public/chunk-cli/internal/iostream" ) // ExecResult holds the output of a command executed over SSH. @@ -336,3 +340,62 @@ func tofuHostKeyCallback(knownHostsPath, host string) ssh.HostKeyCallback { return err } } + +// waitForSSHReady probes the SSH connection with exponential backoff, retrying +// on transient errors so a newly-created sidecar has time to finish booting +// before its SSH service accepts connections. +func waitForSSHReady(ctx context.Context, session *Session, status iostream.StatusFunc) error { + return waitForSSHReadyWithDial(ctx, session, status, dialSSH) +} + +// waitForSSHReadyWithDial is the underlying implementation; dialFn is injectable +// for testing the retry control flow without a live network. +func waitForSSHReadyWithDial( + ctx context.Context, + session *Session, + status iostream.StatusFunc, + dialFn func(context.Context, *Session) (*sshConn, error), +) error { + b := backoff.NewExponentialBackOff() + b.InitialInterval = 2 * time.Second + b.MaxInterval = 15 * time.Second + b.MaxElapsedTime = 90 * time.Second + + notified := false + return backoff.RetryNotify( + func() error { + conn, err := dialFn(ctx, session) + if err != nil { + if !isTransientSSHError(err) { + return backoff.Permanent(err) + } + return err + } + if conn != nil { + _ = conn.Close() + } + return nil + }, + backoff.WithContext(b, ctx), + func(_ error, _ time.Duration) { + if !notified { + status(iostream.LevelInfo, "Waiting for sidecar SSH to become available...") + notified = true + } + }, + ) +} + +// isTransientSSHError reports whether err is a network-level error worth +// retrying when dialling the sidecar SSH service — specifically connection +// refused and timeouts that indicate the daemon is not yet ready. +func isTransientSSHError(err error) bool { + var netErr net.Error + if !errors.As(err, &netErr) { + return false + } + if netErr.Timeout() { + return true + } + return errors.Is(err, syscall.ECONNREFUSED) +} diff --git a/internal/sidecar/sync.go b/internal/sidecar/sync.go index 5d491c9..6891eee 100644 --- a/internal/sidecar/sync.go +++ b/internal/sidecar/sync.go @@ -56,6 +56,9 @@ func Sync(ctx context.Context, if err != nil { return err } + if err := waitForSSHReady(ctx, session, status); err != nil { + return err + } cwd, err := os.Getwd() if err != nil { diff --git a/internal/sidecar/sync_whitebox_test.go b/internal/sidecar/sync_whitebox_test.go new file mode 100644 index 0000000..901c87e --- /dev/null +++ b/internal/sidecar/sync_whitebox_test.go @@ -0,0 +1,174 @@ +package sidecar + +import ( + "context" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "syscall" + "testing" + + "gotest.tools/v3/assert" + + "github.com/CircleCI-Public/chunk-cli/internal/circleci" + "github.com/CircleCI-Public/chunk-cli/internal/iostream" + "github.com/CircleCI-Public/chunk-cli/internal/testing/fakes" +) + +func TestIsTransientSSHError(t *testing.T) { + t.Run("timeout is transient", func(t *testing.T) { + err := &net.OpError{Op: "dial", Err: &timeoutError{}} + assert.Equal(t, isTransientSSHError(err), true) + }) + + t.Run("connection refused is transient", func(t *testing.T) { + err := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}, + } + assert.Equal(t, isTransientSSHError(err), true) + }) + + t.Run("connection refused wrapped with fmt.Errorf is transient", func(t *testing.T) { + inner := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}, + } + err := fmt.Errorf("websocket connect: %w", inner) + assert.Equal(t, isTransientSSHError(err), true) + }) + + t.Run("timeout wrapped with fmt.Errorf is transient", func(t *testing.T) { + inner := &net.OpError{Op: "dial", Err: &timeoutError{}} + err := fmt.Errorf("register SSH key: %w", inner) + assert.Equal(t, isTransientSSHError(err), true) + }) + + t.Run("unreachable host is not transient", func(t *testing.T) { + err := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: &os.SyscallError{Syscall: "connect", Err: syscall.EHOSTUNREACH}, + } + assert.Equal(t, isTransientSSHError(err), false) + }) + + t.Run("ErrNotAuthorized is not transient", func(t *testing.T) { + err := fmt.Errorf("add ssh key: %w", circleci.ErrNotAuthorized) + assert.Equal(t, isTransientSSHError(err), false) + }) + + t.Run("StatusError is not transient", func(t *testing.T) { + err := &circleci.StatusError{Op: "add ssh key", StatusCode: 503} + assert.Equal(t, isTransientSSHError(err), false) + }) + + t.Run("KeyNotFoundError is not transient", func(t *testing.T) { + err := &KeyNotFoundError{Path: "/home/user/.ssh/chunk_ai"} + assert.Equal(t, isTransientSSHError(err), false) + }) + + t.Run("PublicKeyNotFoundError is not transient", func(t *testing.T) { + err := &PublicKeyNotFoundError{KeyPath: "/home/user/.ssh/chunk_ai.pub"} + assert.Equal(t, isTransientSSHError(err), false) + }) + + t.Run("generic error is not transient", func(t *testing.T) { + err := fmt.Errorf("resolve home directory: permission denied") + assert.Equal(t, isTransientSSHError(err), false) + }) +} + +func TestWaitForSSHReady(t *testing.T) { + t.Run("succeeds immediately when SSH server is ready", func(t *testing.T) { + keyFile, pubKey := fakes.GenerateSSHKeypair(t) + sshSrv := fakes.NewSSHServer(t, pubKey) + + session := &Session{ + URL: sshSrv.Addr(), + IdentityFile: keyFile, + KnownHosts: filepath.Join(t.TempDir(), "known_hosts"), + } + + var notified bool + statusFn := iostream.StatusFunc(func(_ iostream.Level, _ string) { notified = true }) + + err := waitForSSHReady(context.Background(), session, statusFn) + assert.NilError(t, err) + assert.Equal(t, notified, false, "no retry should be needed when SSH is already ready") + }) + + t.Run("permanent error returns immediately without notifying", func(t *testing.T) { + permanentErr := errors.New("ssh handshake: auth failed") + + var notifications int + statusFn := iostream.StatusFunc(func(_ iostream.Level, _ string) { notifications++ }) + + err := waitForSSHReadyWithDial(context.Background(), &Session{}, statusFn, + func(_ context.Context, _ *Session) (*sshConn, error) { + return nil, permanentErr // not a net.Error → permanent + }, + ) + assert.ErrorIs(t, err, permanentErr) + assert.Equal(t, notifications, 0, "permanent error should not trigger retry notification") + }) + + t.Run("retries on transient error and notifies exactly once", func(t *testing.T) { + transientErr := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var notifications int + statusFn := iostream.StatusFunc(func(_ iostream.Level, _ string) { + notifications++ + cancel() // stop retrying after the first notification + }) + + err := waitForSSHReadyWithDial(ctx, &Session{}, statusFn, + func(_ context.Context, _ *Session) (*sshConn, error) { + return nil, transientErr + }, + ) + assert.Assert(t, err != nil, "should return an error when retries are stopped") + assert.Equal(t, notifications, 1, "status should be notified exactly once regardless of retry count") + }) + + t.Run("succeeds after transient errors resolve", func(t *testing.T) { + transientErr := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED}, + } + + attempts := 0 + statusFn := iostream.StatusFunc(func(_ iostream.Level, _ string) {}) + + err := waitForSSHReadyWithDial(context.Background(), &Session{}, statusFn, + func(_ context.Context, _ *Session) (*sshConn, error) { + attempts++ + if attempts < 3 { + return nil, transientErr + } + return nil, nil // success: SSH is now ready + }, + ) + assert.NilError(t, err) + assert.Equal(t, attempts, 3, "should have retried until success") + }) +} + +// timeoutError is a net.Error that reports Timeout() == true. +type timeoutError struct{} + +func (timeoutError) Error() string { return "i/o timeout" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true }