diff --git a/cmd/cone/cmd.go b/cmd/cone/cmd.go index 0f363248..b2d7aabc 100644 --- a/cmd/cone/cmd.go +++ b/cmd/cone/cmd.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "os" "github.com/conductorone/cone/pkg/client" "github.com/spf13/cobra" @@ -17,6 +18,36 @@ func cmdContext(cmd *cobra.Command) (context.Context, client.C1Client, *viper.Vi return nil, nil, nil, err } + // Priority 1: CONDUCTORONE_ACCESS_TOKEN -- pre-exchanged bearer token + if accessToken := os.Getenv(client.EnvAccessToken); accessToken != "" { + clientID := v.GetString("client-id") + if clientID == "" { + clientID = os.Getenv(client.EnvClientID) + } + c, err := client.NewWithAccessToken(ctx, accessToken, clientID, v, getCmdName(cmd)) + if err != nil { + return nil, nil, nil, err + } + return ctx, c, v, nil + } + + // Priority 2: CONDUCTORONE_OIDC_TOKEN -- RFC 8693 token exchange + if oidcToken := os.Getenv(client.EnvOIDCToken); oidcToken != "" { + clientID := v.GetString("client-id") + if clientID == "" { + clientID = os.Getenv(client.EnvClientID) + } + if clientID == "" { + return nil, nil, nil, fmt.Errorf("%s requires --client-id, CONE_CLIENT_ID, or %s", client.EnvOIDCToken, client.EnvClientID) + } + c, err := client.NewWithOIDCToken(ctx, oidcToken, clientID, v, getCmdName(cmd)) + if err != nil { + return nil, nil, nil, err + } + return ctx, c, v, nil + } + + // Priority 3: existing client-id + client-secret flow clientId, clientSecret, err := getCredentials(v) if err != nil { return nil, nil, nil, err diff --git a/cmd/cone/token.go b/cmd/cone/token.go index 3e3afa65..6a2bdeb3 100644 --- a/cmd/cone/token.go +++ b/cmd/cone/token.go @@ -40,9 +40,14 @@ func tokenRun(cmd *cobra.Command, args []string) error { return err } - tokenSrc, _, _, err := client.NewC1TokenSource(ctx, + _, tokenHost, err := client.ResolveServerHost(clientId, v) + if err != nil { + return err + } + + tokenSrc, err := client.NewC1TokenSource(ctx, clientId, clientSecret, - v.GetString("api-endpoint"), v.GetBool("debug"), + tokenHost, v.GetBool("debug"), ) if err != nil { return err diff --git a/pkg/client/client.go b/pkg/client/client.go index fd1dccd5..567ef1b0 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -5,9 +5,12 @@ import ( "fmt" "net/http" "net/url" + "os" "strconv" + "strings" "github.com/spf13/viper" + "golang.org/x/oauth2" sdk "github.com/conductorone/conductorone-sdk-go" "github.com/conductorone/conductorone-sdk-go/pkg/models/shared" @@ -17,6 +20,95 @@ import ( const ConeClientID = "2RGdOS94VDferT9e80mdgntl36K" +// Environment variable names for ConductorOne authentication. +// These match the constants in conductorone-sdk-go and terraform-provider-conductorone. +const ( + EnvAccessToken = "CONDUCTORONE_ACCESS_TOKEN" + EnvOIDCToken = "CONDUCTORONE_OIDC_TOKEN" + EnvClientID = "CONDUCTORONE_CLIENT_ID" + EnvClientSecret = "CONDUCTORONE_CLIENT_SECRET" + EnvServerURL = "CONDUCTORONE_SERVER_URL" +) + +// normalizeHost extracts the host (with port) from a value that may be a +// full URL ("https://host:port/"), a bare hostname ("host"), or host:port. +func normalizeHost(input string) string { + input = strings.TrimSpace(input) + if input == "" { + return "" + } + if strings.Contains(input, "://") { + u, err := url.Parse(input) + if err == nil && u.Host != "" { + return u.Host + } + } + return strings.TrimRight(input, "/") +} + +// ResolveServerHost determines the API server host using a consistent priority: +// 1. --api-endpoint flag (via viper) +// 2. CONDUCTORONE_SERVER_URL env var +// 3. CONE_API_ENDPOINT env var +// 4. Parsed from clientID (e.g. "name@host/suffix" -> "host") +// +// Returns (clientName, host, error). clientName is empty if no clientID provided. +func ResolveServerHost(clientID string, v *viper.Viper) (string, string, error) { + // Check explicit overrides first + if h := normalizeHost(v.GetString("api-endpoint")); h != "" { + clientName, _, err := parseClientIDName(clientID) + if err != nil && clientID != "" { + return "", "", err + } + return clientName, h, nil + } + if h := normalizeHost(os.Getenv(EnvServerURL)); h != "" { + clientName, _, err := parseClientIDName(clientID) + if err != nil && clientID != "" { + return "", "", err + } + return clientName, h, nil + } + if h := normalizeHost(os.Getenv("CONE_API_ENDPOINT")); h != "" { + clientName, _, err := parseClientIDName(clientID) + if err != nil && clientID != "" { + return "", "", err + } + return clientName, h, nil + } + + // Fall back to parsing host from client ID + if clientID != "" { + clientName, host, err := parseClientIDName(clientID) + if err != nil { + return "", "", err + } + return clientName, host, nil + } + + return "", "", nil +} + +// parseClientIDName splits a client ID into (cutename, host, error). +// Client IDs have the format "cutename@host/suffix". +func parseClientIDName(input string) (string, string, error) { + if input == "" { + return "", "", nil + } + items := strings.SplitN(input, "@", 2) + if len(items) != 2 { + return "", "", ErrInvalidClientID + } + clientName := items[0] + + parts := strings.SplitN(items[1], "/", 2) + if len(parts) != 2 { + return "", "", ErrInvalidClientID + } + + return clientName, parts[0], nil +} + type contextKey string const VersionKey contextKey = "version" @@ -111,15 +203,81 @@ func New( v *viper.Viper, cmdName string, ) (C1Client, error) { - tokenSrc, clientName, tokenHost, err := NewC1TokenSource(ctx, - clientId, clientSecret, - v.GetString("api-endpoint"), - v.GetBool("debug"), - ) + clientName, tokenHost, err := ResolveServerHost(clientId, v) + if err != nil { + return nil, err + } + + tokenSrc, err := NewC1TokenSource(ctx, clientId, clientSecret, tokenHost, v.GetBool("debug")) if err != nil { return nil, err } + return newClientWithTokenSource(ctx, tokenSrc, clientName, tokenHost, v, cmdName) +} + +// NewWithAccessToken creates a client using a pre-exchanged bearer token. +func NewWithAccessToken( + ctx context.Context, + accessToken string, + clientID string, + v *viper.Viper, + cmdName string, +) (C1Client, error) { + clientName, tokenHost, err := ResolveServerHost(clientID, v) + if err != nil { + return nil, err + } + if tokenHost == "" { + return nil, fmt.Errorf("%s requires --client-id, %s, or --api-endpoint to determine the server", EnvAccessToken, EnvClientID) + } + + tokenSrc := oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: accessToken, + }) + + return newClientWithTokenSource(ctx, tokenSrc, clientName, tokenHost, v, cmdName) +} + +// NewWithOIDCToken creates a client that exchanges an OIDC token for a C1 access token. +func NewWithOIDCToken( + ctx context.Context, + oidcToken string, + clientID string, + v *viper.Viper, + cmdName string, +) (C1Client, error) { + if oidcToken == "" { + return nil, fmt.Errorf("NewWithOIDCToken: oidcToken must be non-empty; set %s or pass --oidc-token", EnvOIDCToken) + } + if clientID == "" { + return nil, fmt.Errorf("NewWithOIDCToken: clientID must be non-empty; set %s or pass --client-id", EnvClientID) + } + + clientName, tokenHost, err := ResolveServerHost(clientID, v) + if err != nil { + return nil, err + } + if tokenHost == "" { + return nil, fmt.Errorf("NewWithOIDCToken: could not determine server host from clientID or --api-endpoint; parseClientID requires a clientID in the form \"name@host/suffix\"") + } + + tokenSrc, err := NewTokenExchangeSource(ctx, oidcToken, clientID, tokenHost, v.GetBool("debug")) + if err != nil { + return nil, err + } + + return newClientWithTokenSource(ctx, tokenSrc, clientName, tokenHost, v, cmdName) +} + +func newClientWithTokenSource( + ctx context.Context, + tokenSrc oauth2.TokenSource, + clientName string, + tokenHost string, + v *viper.Viper, + cmdName string, +) (C1Client, error) { uclient, err := uhttp.NewClient( ctx, uhttp.WithTokenSource(tokenSrc), diff --git a/pkg/client/token_exchange.go b/pkg/client/token_exchange.go new file mode 100644 index 00000000..b611ad3e --- /dev/null +++ b/pkg/client/token_exchange.go @@ -0,0 +1,98 @@ +package client + +import ( + "context" + "net/http" + "net/url" + "strings" + "time" + + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "golang.org/x/oauth2" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "gopkg.in/square/go-jose.v2/json" + + "github.com/conductorone/cone/pkg/uhttp" +) + +const ( + grantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange" //nolint:gosec // OAuth2 grant type URI, not a credential + subjectTokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt" //nolint:gosec // OAuth2 token type URI, not a credential +) + +// tokenExchangeSource implements oauth2.TokenSource by exchanging an external +// OIDC JWT for a ConductorOne access token via RFC 8693 token exchange. +type tokenExchangeSource struct { + oidcToken string + clientID string + tokenHost string + httpClient *http.Client +} + +func (t *tokenExchangeSource) Token() (*oauth2.Token, error) { + body := url.Values{ + "grant_type": []string{grantTypeTokenExchange}, + "subject_token": []string{t.oidcToken}, + "subject_token_type": []string{subjectTokenTypeJWT}, + "client_id": []string{t.clientID}, + } + + tokenURL := url.URL{ + Scheme: "https", + Host: t.tokenHost, + Path: "auth/v1/token", + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, tokenURL.String(), strings.NewReader(body.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := t.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, status.Errorf(codes.Unauthenticated, "token exchange failed: %s", resp.Status) + } + + c1t := &c1Token{} + err = json.NewDecoder(resp.Body).Decode(c1t) + if err != nil { + return nil, err + } + + if c1t.AccessToken == "" { + return nil, status.Errorf(codes.Unauthenticated, "token exchange failed: empty access token") + } + + return &oauth2.Token{ + AccessToken: c1t.AccessToken, + TokenType: c1t.TokenType, + Expiry: time.Now().Add(time.Duration(c1t.Expiry) * time.Second), + }, nil +} + +// NewTokenExchangeSource creates an oauth2.TokenSource that exchanges an external +// OIDC token for a ConductorOne access token via RFC 8693 token exchange. +func NewTokenExchangeSource(ctx context.Context, oidcToken, clientID, tokenHost string, debug bool) (oauth2.TokenSource, error) { + httpClient, err := uhttp.NewClient(ctx, + uhttp.WithLogger(true, ctxzap.Extract(ctx)), + uhttp.WithUserAgent("cone-wfe-credential-provider"), + uhttp.WithDebug(debug), + ) + if err != nil { + return nil, err + } + + return oauth2.ReuseTokenSource(nil, &tokenExchangeSource{ + oidcToken: oidcToken, + clientID: clientID, + tokenHost: tokenHost, + httpClient: httpClient, + }), nil +} diff --git a/pkg/client/token_source.go b/pkg/client/token_source.go index 4edac5ba..776e41b8 100644 --- a/pkg/client/token_source.go +++ b/pkg/client/token_source.go @@ -49,31 +49,6 @@ type c1TokenSource struct { httpClient *http.Client } -func parseClientID(input string, forceTokenHost string) (string, string, error) { - // split the input into 2 parts by @ - items := strings.SplitN(input, "@", 2) - if len(items) != 2 { - return "", "", ErrInvalidClientID - } - clientName := items[0] - - // split the right part into 2 parts by / - items = strings.SplitN(items[1], "/", 2) - if len(items) != 2 { - return "", "", ErrInvalidClientID - } - - if forceTokenHost != "" { - return clientName, forceTokenHost, nil - } - - if envHost, ok := os.LookupEnv("CONE_API_ENDPOINT"); ok && envHost != "" { - return clientName, envHost, nil - } - - return clientName, items[0], nil -} - func ParseSecret(input []byte) (*jose.JSONWebKey, error) { items := bytes.SplitN(input, []byte(":"), 4) if len(items) != 4 { @@ -203,17 +178,12 @@ func NewC1TokenSource( ctx context.Context, clientID string, clientSecret string, - forceTokenHost string, + tokenHost string, debug bool, -) (oauth2.TokenSource, string, string, error) { - clientName, tokenHost, err := parseClientID(clientID, forceTokenHost) - if err != nil { - return nil, "", "", err - } - +) (oauth2.TokenSource, error) { secret, err := ParseSecret([]byte(clientSecret)) if err != nil { - return nil, "", "", err + return nil, err } httpClient, err := uhttp.NewClient(ctx, @@ -222,12 +192,12 @@ func NewC1TokenSource( uhttp.WithDebug(debug), ) if err != nil { - return nil, "", "", err + return nil, err } return oauth2.ReuseTokenSource(nil, &c1TokenSource{ clientID: clientID, clientSecret: secret, tokenHost: tokenHost, httpClient: httpClient, - }), clientName, tokenHost, nil + }), nil }