diff --git a/internal/cmd/auth.go b/internal/cmd/auth.go index f896562..36fbfd8 100644 --- a/internal/cmd/auth.go +++ b/internal/cmd/auth.go @@ -12,6 +12,7 @@ import ( "github.com/CircleCI-Public/chunk-cli/internal/config" "github.com/CircleCI-Public/chunk-cli/internal/iostream" "github.com/CircleCI-Public/chunk-cli/internal/keyring" + "github.com/CircleCI-Public/chunk-cli/internal/oauth" "github.com/CircleCI-Public/chunk-cli/internal/tui" "github.com/CircleCI-Public/chunk-cli/internal/ui" ) @@ -29,12 +30,53 @@ func newAuthCmd() *cobra.Command { RunE: groupRunE, FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true}, } + cmd.AddCommand(newAuthLoginCmd()) cmd.AddCommand(newAuthSetCmd()) cmd.AddCommand(newAuthStatusCmd()) cmd.AddCommand(newAuthRemoveCmd()) return cmd } +func newAuthLoginCmd() *cobra.Command { + var noBrowser bool + var signup bool + cmd := &cobra.Command{ + Use: "login", + Short: "Log in to CircleCI via browser (recommended)", + Long: "Authenticate with CircleCI using OAuth. Opens your browser for a secure login flow.", + RunE: func(cmd *cobra.Command, _ []string) error { + insecureStorage, _ := cmd.Flags().GetBool("insecure-storage") + rc, _ := config.Resolve("", "", insecureStorage) + io := iostream.FromCmd(cmd) + return authLogin(cmd.Context(), io, rc.CircleCIBaseURL, noBrowser, signup, insecureStorage) + }, + } + cmd.Flags().BoolVar(&noBrowser, "no-browser", false, "Print the login URL instead of opening a browser") + cmd.Flags().BoolVar(&signup, "signup", false, "Route to the signup page instead of login") + return cmd +} + +func authLogin(ctx context.Context, streams iostream.Streams, baseURL string, noBrowser, signup, insecureStorage bool) error { + streams.Println("") + streams.Println(ui.Bold("Chunk CLI - CircleCI Login")) + streams.Println("") + + token, err := oauth.Login(ctx, oauth.LoginConfig{ + BaseURL: baseURL, + NoBrowser: noBrowser, + Signup: signup, + }, streams.Err) + if err != nil { + return &userError{ + msg: "Login failed.", + suggestion: "Try again or use `chunk auth set circleci` to set a token manually.", + err: fmt.Errorf("oauth login: %w", err), + } + } + + return saveCircleCIToken(ctx, token, streams, baseURL, insecureStorage) +} + func newAuthSetCmd() *cobra.Command { var force bool cmd := &cobra.Command{ diff --git a/internal/cmd/authhelper.go b/internal/cmd/authhelper.go index 5f32f67..82c4b53 100644 --- a/internal/cmd/authhelper.go +++ b/internal/cmd/authhelper.go @@ -16,12 +16,13 @@ import ( "github.com/CircleCI-Public/chunk-cli/internal/github" hc "github.com/CircleCI-Public/chunk-cli/internal/httpcl" "github.com/CircleCI-Public/chunk-cli/internal/iostream" + "github.com/CircleCI-Public/chunk-cli/internal/oauth" "github.com/CircleCI-Public/chunk-cli/internal/tui" "github.com/CircleCI-Public/chunk-cli/internal/ui" ) const ( - suggestionCircleCIAuth = "Set " + config.EnvCircleToken + " or run 'chunk auth set circleci'." + suggestionCircleCIAuth = "Set " + config.EnvCircleToken + " or run 'chunk auth login'." suggestionAnthropicAuth = "Set " + config.EnvAnthropicAPIKey + " or run 'chunk auth set anthropic'." suggestionGitHubAuth = "Set " + config.EnvGitHubToken + " or run 'chunk auth set github'." ) @@ -64,30 +65,56 @@ func ensureCircleCIClient(ctx context.Context, cmd *cobra.Command, rc config.Res } streams.ErrPrintln("") - streams.ErrPrintln(ui.Bold("CircleCI token required")) - streams.ErrPrintln("Create a token at https://app.circleci.com/settings/user/tokens") - streams.ErrPrintln("Don't have an account? Sign up at https://app.circleci.com/signup") + streams.ErrPrintln(ui.Bold("CircleCI authentication required")) printSaveHint(streams, "Token", insecureStorage) streams.ErrPrintln("") - token, err := prompter("CircleCI Token") - if err != nil { - if errors.Is(err, tui.ErrNoTTY) { + choice, selectErr := tui.SelectFromList("How would you like to authenticate?", []string{ + "Log in via browser (recommended)", + "Enter a token manually", + }) + if selectErr != nil { + if errors.Is(selectErr, tui.ErrNoTTY) { return nil, newUserError("CircleCI token required."). withCode("auth.circleci_token_required"). withSuggestion(suggestionCircleCIAuth). withExitCode(ExitAuthError). - wrap(err) + wrap(selectErr) } - return nil, err + return nil, selectErr } - token = strings.TrimSpace(token) - if token == "" { - return nil, newUserError("CircleCI token required."). - withCode("auth.circleci_token_required"). - withSuggestion(suggestionCircleCIAuth). - withExitCode(ExitAuthError). - wrapMsg("empty token entered") + + var token string + switch choice { + case 0: + token, err = oauth.Login(ctx, oauth.LoginConfig{ + BaseURL: rc.CircleCIBaseURL, + }, streams.Err) + if err != nil { + return nil, fmt.Errorf("oauth login: %w", err) + } + case 1: + streams.ErrPrintln("Create a token at https://app.circleci.com/settings/user/tokens") + streams.ErrPrintln("") + token, err = prompter("CircleCI Token") + if err != nil { + if errors.Is(err, tui.ErrNoTTY) { + return nil, newUserError("CircleCI token required."). + withCode("auth.circleci_token_required"). + withSuggestion(suggestionCircleCIAuth). + withExitCode(ExitAuthError). + wrap(err) + } + return nil, err + } + token = strings.TrimSpace(token) + if token == "" { + return nil, newUserError("CircleCI token required."). + withCode("auth.circleci_token_required"). + withSuggestion(suggestionCircleCIAuth). + withExitCode(ExitAuthError). + wrapMsg("empty token entered") + } } streams.ErrPrintln(ui.Dim("Validating CircleCI token...")) diff --git a/internal/cmd/validate_test.go b/internal/cmd/validate_test.go index 86440a0..7d5ceea 100644 --- a/internal/cmd/validate_test.go +++ b/internal/cmd/validate_test.go @@ -55,7 +55,7 @@ func TestValidateHookExitsOneWhenCircleCITokenMissingAndRemoteCommands(t *testin assert.Equal(t, ec.ExitCode(), 1) assert.Assert(t, strings.Contains(stderr, "CircleCI auth is not configured"), "expected auth message in stderr, got: %q", stderr) - assert.Assert(t, strings.Contains(stderr, "chunk auth set circleci"), + assert.Assert(t, strings.Contains(stderr, "chunk auth login"), "expected auth hint in stderr, got: %q", stderr) } diff --git a/internal/oauth/browser.go b/internal/oauth/browser.go new file mode 100644 index 0000000..21583a2 --- /dev/null +++ b/internal/oauth/browser.go @@ -0,0 +1,20 @@ +package oauth + +import ( + "fmt" + "os/exec" + "runtime" +) + +func OpenBrowser(url string) error { + switch runtime.GOOS { + case "darwin": + return exec.Command("open", url).Start() + case "linux": + return exec.Command("xdg-open", url).Start() + case "windows": + return exec.Command("cmd", "/c", "start", url).Start() + default: + return fmt.Errorf("unsupported platform %s", runtime.GOOS) + } +} diff --git a/internal/oauth/callback.go b/internal/oauth/callback.go new file mode 100644 index 0000000..10c7bc2 --- /dev/null +++ b/internal/oauth/callback.go @@ -0,0 +1,63 @@ +package oauth + +import ( + "context" + "net" + "net/http" + "time" +) + +type CallbackResult struct { + Code string + State string + Error string +} + +func ListenForCallback(ctx context.Context) (port int, result <-chan CallbackResult, cleanup func(), err error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return 0, nil, nil, err + } + port = listener.Addr().(*net.TCPAddr).Port + + ch := make(chan CallbackResult, 1) + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + res := CallbackResult{ + Code: q.Get("code"), + State: q.Get("state"), + Error: q.Get("error"), + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if res.Error != "" { + _, _ = w.Write([]byte("

Login was denied.

You can close this tab.

")) + } else { + _, _ = w.Write([]byte("

Login successful!

You can close this tab and return to the terminal.

")) + } + + select { + case ch <- res: + default: + } + }) + + srv := &http.Server{ + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + BaseContext: func(_ net.Listener) context.Context { + return ctx + }, + } + + go func() { + _ = srv.Serve(listener) + }() + + cleanup = func() { + _ = srv.Shutdown(context.Background()) + } + + return port, ch, cleanup, nil +} diff --git a/internal/oauth/callback_test.go b/internal/oauth/callback_test.go new file mode 100644 index 0000000..fb32fab --- /dev/null +++ b/internal/oauth/callback_test.go @@ -0,0 +1,61 @@ +package oauth + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" + + "gotest.tools/v3/assert" +) + +func TestListenForCallback_Success(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + port, resultCh, cleanup, err := ListenForCallback(ctx) + assert.NilError(t, err) + defer cleanup() + + url := fmt.Sprintf("http://127.0.0.1:%d/callback?code=test-code&state=test-state", port) + resp, err := http.Get(url) + assert.NilError(t, err) + resp.Body.Close() + assert.Equal(t, resp.StatusCode, http.StatusOK) + + res := <-resultCh + assert.Equal(t, res.Code, "test-code") + assert.Equal(t, res.State, "test-state") + assert.Equal(t, res.Error, "") +} + +func TestListenForCallback_Error(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + port, resultCh, cleanup, err := ListenForCallback(ctx) + assert.NilError(t, err) + defer cleanup() + + url := fmt.Sprintf("http://127.0.0.1:%d/callback?error=access_denied&state=test-state", port) + resp, err := http.Get(url) + assert.NilError(t, err) + resp.Body.Close() + + res := <-resultCh + assert.Equal(t, res.Error, "access_denied") + assert.Equal(t, res.State, "test-state") + assert.Equal(t, res.Code, "") +} + +func TestListenForCallback_ContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + _, _, cleanup, err := ListenForCallback(ctx) + assert.NilError(t, err) + defer cleanup() + + cancel() + // Server should shut down without hanging; cleanup is the verification. +} diff --git a/internal/oauth/deviceid.go b/internal/oauth/deviceid.go new file mode 100644 index 0000000..f946329 --- /dev/null +++ b/internal/oauth/deviceid.go @@ -0,0 +1,53 @@ +package oauth + +import ( + "crypto/rand" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/CircleCI-Public/chunk-cli/internal/config" +) + +const deviceIDFile = "device_id" + +func LoadOrCreateDeviceID() (string, error) { + dir, err := config.AppState() + if err != nil { + return "", fmt.Errorf("resolve state dir: %w", err) + } + path := filepath.Join(dir, deviceIDFile) + + data, err := os.ReadFile(path) + if err == nil { + id := strings.TrimSpace(string(data)) + if id != "" { + return id, nil + } + } + + id, err := generateUUID4() + if err != nil { + return "", fmt.Errorf("generate device id: %w", err) + } + + if err := os.MkdirAll(dir, 0o700); err != nil { + return "", fmt.Errorf("create state dir: %w", err) + } + if err := os.WriteFile(path, []byte(id+"\n"), 0o600); err != nil { + return "", fmt.Errorf("write device id: %w", err) + } + return id, nil +} + +func generateUUID4() (string, error) { + var buf [16]byte + if _, err := rand.Read(buf[:]); err != nil { + return "", err + } + buf[6] = (buf[6] & 0x0f) | 0x40 // version 4 + buf[8] = (buf[8] & 0x3f) | 0x80 // variant 10 + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", + buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:16]), nil +} diff --git a/internal/oauth/deviceid_test.go b/internal/oauth/deviceid_test.go new file mode 100644 index 0000000..b3ae6f2 --- /dev/null +++ b/internal/oauth/deviceid_test.go @@ -0,0 +1,32 @@ +package oauth + +import ( + "regexp" + "testing" + + "gotest.tools/v3/assert" + + "github.com/CircleCI-Public/chunk-cli/internal/config" +) + +var uuidPattern = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`) + +func TestLoadOrCreateDeviceID_Creates(t *testing.T) { + t.Setenv(config.EnvXDGStateHome, t.TempDir()) + + id, err := LoadOrCreateDeviceID() + assert.NilError(t, err) + assert.Assert(t, uuidPattern.MatchString(id), "expected UUID v4, got %q", id) +} + +func TestLoadOrCreateDeviceID_Reuses(t *testing.T) { + t.Setenv(config.EnvXDGStateHome, t.TempDir()) + + id1, err := LoadOrCreateDeviceID() + assert.NilError(t, err) + + id2, err := LoadOrCreateDeviceID() + assert.NilError(t, err) + + assert.Equal(t, id1, id2) +} diff --git a/internal/oauth/exchange.go b/internal/oauth/exchange.go new file mode 100644 index 0000000..7f10c24 --- /dev/null +++ b/internal/oauth/exchange.go @@ -0,0 +1,59 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" +) + +type TokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` +} + +func exchangeToken(ctx context.Context, baseURL, code, redirectURI, verifier string) (*TokenResponse, error) { + form := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + "client_id": {ClientID}, + "code_verifier": {verifier}, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + strings.TrimRight(baseURL, "/")+"/oauth/token", + strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("build token request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("token request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read token response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token endpoint returned %d: %s", resp.StatusCode, string(body)) + } + + var tok TokenResponse + if err := json.Unmarshal(body, &tok); err != nil { + return nil, fmt.Errorf("decode token response: %w", err) + } + if tok.AccessToken == "" { + return nil, fmt.Errorf("token endpoint returned empty access_token") + } + return &tok, nil +} diff --git a/internal/oauth/exchange_test.go b/internal/oauth/exchange_test.go new file mode 100644 index 0000000..8edb76c --- /dev/null +++ b/internal/oauth/exchange_test.go @@ -0,0 +1,52 @@ +package oauth + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "gotest.tools/v3/assert" +) + +func TestExchangeToken_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Method, http.MethodPost) + assert.Equal(t, r.URL.Path, "/oauth/token") + assert.Equal(t, r.Header.Get("Content-Type"), "application/x-www-form-urlencoded") + + assert.NilError(t, r.ParseForm()) + assert.Equal(t, r.FormValue("grant_type"), "authorization_code") + assert.Equal(t, r.FormValue("code"), "test-code") + assert.Equal(t, r.FormValue("client_id"), ClientID) + assert.Equal(t, r.FormValue("code_verifier"), "test-verifier") + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"CCIPAT_test_token","token_type":"Bearer","expires_in":7776000}`)) + })) + defer srv.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + tok, err := exchangeToken(ctx, srv.URL, "test-code", "http://127.0.0.1:12345/callback", "test-verifier") + assert.NilError(t, err) + assert.Equal(t, tok.AccessToken, "CCIPAT_test_token") + assert.Equal(t, tok.TokenType, "Bearer") + assert.Equal(t, tok.ExpiresIn, 7776000) +} + +func TestExchangeToken_ErrorResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant"}`)) + })) + defer srv.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err := exchangeToken(ctx, srv.URL, "bad-code", "http://127.0.0.1:12345/callback", "test-verifier") + assert.ErrorContains(t, err, "400") +} diff --git a/internal/oauth/login.go b/internal/oauth/login.go new file mode 100644 index 0000000..74ac749 --- /dev/null +++ b/internal/oauth/login.go @@ -0,0 +1,112 @@ +package oauth + +import ( + "context" + "fmt" + "io" + "net/url" + "runtime" + "strings" + "time" +) + +const ( + // ClientID is the OAuth client identifier registered with CircleCI. + ClientID = "chunk-cli" + // CallbackTimeout is how long to wait for the browser callback before giving up. + CallbackTimeout = 5 * time.Minute + callbackPath = "/callback" +) + +type LoginConfig struct { + BaseURL string + NoBrowser bool + Signup bool +} + +func Login(ctx context.Context, cfg LoginConfig, status io.Writer) (string, error) { + deviceID, err := LoadOrCreateDeviceID() + if err != nil { + return "", fmt.Errorf("device id: %w", err) + } + + verifier, err := GenerateVerifier() + if err != nil { + return "", fmt.Errorf("pkce verifier: %w", err) + } + challenge := S256Challenge(verifier) + + state, err := GenerateState() + if err != nil { + return "", fmt.Errorf("state: %w", err) + } + + port, resultCh, cleanup, err := ListenForCallback(ctx) + if err != nil { + return "", err + } + defer cleanup() + + redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, callbackPath) + authorizeURL := buildAuthorizeURL(cfg.BaseURL, redirectURI, challenge, state, deviceID, cfg.Signup) + + w := func(s string) { _, _ = fmt.Fprintln(status, s) } + + if cfg.NoBrowser { + w("Open this URL in your browser to log in:") + w("") + w(" " + authorizeURL) + w("") + } else { + w("Opening browser for CircleCI login...") + if err := OpenBrowser(authorizeURL); err != nil { + w("Could not open browser. Open this URL manually:") + w("") + w(" " + authorizeURL) + w("") + } + } + w("Waiting for login (up to 5 minutes)...") + + var res CallbackResult + select { + case res = <-resultCh: + case <-time.After(CallbackTimeout): + return "", fmt.Errorf("timed out waiting for browser callback after %s", CallbackTimeout) + case <-ctx.Done(): + return "", ctx.Err() + } + + if res.Error != "" { + return "", fmt.Errorf("authorization denied: %s", res.Error) + } + if res.State != state { + return "", fmt.Errorf("state mismatch (possible CSRF)") + } + if res.Code == "" { + return "", fmt.Errorf("callback contained no authorization code") + } + + tok, err := exchangeToken(ctx, cfg.BaseURL, res.Code, redirectURI, verifier) + if err != nil { + return "", err + } + return tok.AccessToken, nil +} + +func buildAuthorizeURL(baseURL, redirectURI, challenge, state, deviceID string, signup bool) string { + params := url.Values{ + "client_id": {ClientID}, + "response_type": {"code"}, + "redirect_uri": {redirectURI}, + "code_challenge": {challenge}, + "code_challenge_method": {"S256"}, + "state": {state}, + "os": {runtime.GOOS}, + "device_id": {deviceID}, + } + if signup { + params.Set("signup", "true") + } + return strings.TrimRight(baseURL, "/") + "/oauth/authorize?" + params.Encode() +} diff --git a/internal/oauth/pkce.go b/internal/oauth/pkce.go new file mode 100644 index 0000000..8d00ac7 --- /dev/null +++ b/internal/oauth/pkce.go @@ -0,0 +1,20 @@ +package oauth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" +) + +func GenerateVerifier() (string, error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func S256Challenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} diff --git a/internal/oauth/pkce_test.go b/internal/oauth/pkce_test.go new file mode 100644 index 0000000..859b717 --- /dev/null +++ b/internal/oauth/pkce_test.go @@ -0,0 +1,31 @@ +package oauth + +import ( + "crypto/sha256" + "encoding/base64" + "testing" + + "gotest.tools/v3/assert" +) + +func TestGenerateVerifier_Length(t *testing.T) { + v, err := GenerateVerifier() + assert.NilError(t, err) + assert.Equal(t, len(v), 43) // 32 bytes base64url = 43 chars +} + +func TestGenerateVerifier_Uniqueness(t *testing.T) { + v1, err := GenerateVerifier() + assert.NilError(t, err) + v2, err := GenerateVerifier() + assert.NilError(t, err) + assert.Assert(t, v1 != v2) +} + +func TestS256Challenge_KnownVector(t *testing.T) { + // RFC 7636 Appendix B test vector + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + h := sha256.Sum256([]byte(verifier)) + expected := base64.RawURLEncoding.EncodeToString(h[:]) + assert.Equal(t, S256Challenge(verifier), expected) +} diff --git a/internal/oauth/state.go b/internal/oauth/state.go new file mode 100644 index 0000000..1f4f727 --- /dev/null +++ b/internal/oauth/state.go @@ -0,0 +1,14 @@ +package oauth + +import ( + "crypto/rand" + "encoding/hex" +) + +func GenerateState() (string, error) { + buf := make([]byte, 16) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} diff --git a/internal/oauth/state_test.go b/internal/oauth/state_test.go new file mode 100644 index 0000000..f53c5de --- /dev/null +++ b/internal/oauth/state_test.go @@ -0,0 +1,21 @@ +package oauth + +import ( + "testing" + + "gotest.tools/v3/assert" +) + +func TestGenerateState_Length(t *testing.T) { + s, err := GenerateState() + assert.NilError(t, err) + assert.Equal(t, len(s), 32) // 16 bytes hex = 32 chars +} + +func TestGenerateState_Uniqueness(t *testing.T) { + s1, err := GenerateState() + assert.NilError(t, err) + s2, err := GenerateState() + assert.NilError(t, err) + assert.Assert(t, s1 != s2) +}