From bb9295c27af966938bec3363efcd0732f18ab5a3 Mon Sep 17 00:00:00 2001 From: Marcus Pasell <3690498+rickyrombo@users.noreply.github.com> Date: Tue, 3 Mar 2026 19:18:35 -0800 Subject: [PATCH 1/5] Add OAuth endpoints --- api/auth_middleware.go | 29 +- api/request_helpers.go | 34 ++ api/server.go | 16 + api/v1_oauth.go | 618 ++++++++++++++++++++++++++++++ api/v1_oauth_test.go | 826 +++++++++++++++++++++++++++++++++++++++++ database/seed.go | 28 ++ 6 files changed, 1542 insertions(+), 9 deletions(-) create mode 100644 api/v1_oauth.go create mode 100644 api/v1_oauth_test.go diff --git a/api/auth_middleware.go b/api/auth_middleware.go index 90748081..7db7e14e 100644 --- a/api/auth_middleware.go +++ b/api/auth_middleware.go @@ -239,16 +239,27 @@ func (app *ApiServer) authMiddleware(c *fiber.Ctx) error { wallet = strings.ToLower(signer.Address) } else { wallet = app.recoverAuthorityFromSignatureHeaders(c) + // Extract Bearer token once for the fallback checks below + var bearerToken string + if authHeader := c.Get("Authorization"); authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") { + bearerToken = strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer ")) + } + // OAuth JWT fallback: when Bearer token is not api_access_key, try as OAuth JWT (Plans app) - if wallet == "" && myId != 0 { - if authHeader := c.Get("Authorization"); authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") { - token := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer ")) - if token != "" { - if oauthWallet, jwtUserId, err := app.validateOAuthJWTTokenToWalletAndUserId(c.Context(), token); err == nil { - if int32(jwtUserId) == myId { - wallet = oauthWallet - } - } + if wallet == "" && myId != 0 && bearerToken != "" { + if oauthWallet, jwtUserId, err := app.validateOAuthJWTTokenToWalletAndUserId(c.Context(), bearerToken); err == nil { + if int32(jwtUserId) == myId { + wallet = oauthWallet + } + } + } + // PKCE token fallback: resolve opaque Bearer token from oauth_tokens + if wallet == "" && bearerToken != "" { + if entry, ok := app.lookupOAuthAccessToken(c, bearerToken); ok { + wallet = strings.ToLower(entry.ClientID) + if myId == 0 { + myId = entry.UserID + c.Locals("myId", int(entry.UserID)) } } } diff --git a/api/request_helpers.go b/api/request_helpers.go index ed507f14..f8561600 100644 --- a/api/request_helpers.go +++ b/api/request_helpers.go @@ -57,6 +57,12 @@ func (app *ApiServer) getApiSigner(c *fiber.Ctx) (*Signer, error) { return signer, nil } } + // Try PKCE token → look up client_id → get api_secret from api_keys → return Signer + if app.writePool != nil { + if signer := app.getSignerFromOAuthToken(c, token); signer != nil { + return signer, nil + } + } // If authMiddleware already validated a JWT and set authedWallet, // use AudiusApiSecret to sign on behalf of the authenticated user. if wallet, _ := c.Locals("authedWallet").(string); wallet != "" && app.config.AudiusApiSecret != "" { @@ -158,3 +164,31 @@ func (app *ApiServer) getSignerFromApiAccessKey(ctx context.Context, apiAccessKe PrivateKey: privateKey, } } + +// getSignerFromOAuthToken looks up a PKCE access token, resolves the client_id to an api_key, +// then gets the api_secret to build a Signer. This allows writes (ManageEntity signing) +// to work for PKCE-authenticated requests. +func (app *ApiServer) getSignerFromOAuthToken(c *fiber.Ctx, token string) *Signer { + entry, ok := app.lookupOAuthAccessToken(c, token) + if !ok { + return nil + } + + // Look up api_secret for the client_id (developer app address = api_key) + var apiSecret string + err := app.writePool.QueryRow(c.Context(), ` + SELECT api_secret FROM api_keys WHERE LOWER(api_key) = LOWER($1) + `, entry.ClientID).Scan(&apiSecret) + if err != nil || apiSecret == "" { + return nil + } + + privateKey, err := crypto.HexToECDSA(strings.TrimPrefix(apiSecret, "0x")) + if err != nil { + return nil + } + return &Signer{ + Address: strings.ToLower(entry.ClientID), + PrivateKey: privateKey, + } +} diff --git a/api/server.go b/api/server.go index 1dfaa703..609de0a4 100644 --- a/api/server.go +++ b/api/server.go @@ -129,6 +129,14 @@ func NewApiServer(config config.Config) *ApiServer { panic(err) } + oauthTokenCache, err := otter.MustBuilder[string, oauthTokenCacheEntry](10_000). + WithTTL(60 * time.Second). + CollectStats(). + Build() + if err != nil { + panic(err) + } + privateKey, err := crypto.HexToECDSA(config.DelegatePrivateKey) if err != nil { panic(err) @@ -233,6 +241,7 @@ func NewApiServer(config config.Config) *ApiServer { resolveGrantCache: &resolveGrantCache, resolveWalletCache: &resolveWalletCache, apiAccessKeySignerCache: &apiAccessKeySignerCache, + oauthTokenCache: &oauthTokenCache, requestValidator: requestValidator, rewardAttester: rewardAttester, transactionSender: transactionSender, @@ -541,6 +550,12 @@ func NewApiServer(config config.Config) *ApiServer { g.Post("/developer_apps/:address/access-keys", app.postV1UsersDeveloperAppAccessKey) g.Post("/developer-apps/:address/access-keys", app.postV1UsersDeveloperAppAccessKey) + // OAuth2 PKCE + g.Post("/oauth/authorize", app.v1OAuthAuthorize) + g.Post("/oauth/token", app.v1OAuthToken) + g.Post("/oauth/revoke", app.v1OAuthRevoke) + g.Get("/oauth/me", app.requireAuthMiddleware, app.v1OAuthMe) + // Rewards g.Post("/rewards/claim", app.v1ClaimRewards) g.Post("/rewards/code", app.v1CreateRewardCode) @@ -737,6 +752,7 @@ type ApiServer struct { resolveGrantCache *otter.Cache[string, bool] resolveWalletCache *otter.Cache[string, int] apiAccessKeySignerCache *otter.Cache[string, apiAccessKeySignerEntry] + oauthTokenCache *otter.Cache[string, oauthTokenCacheEntry] requestValidator *RequestValidator rewardManagerClient *reward_manager.RewardManagerClient claimableTokensClient *claimable_tokens.ClaimableTokensClient diff --git a/api/v1_oauth.go b/api/v1_oauth.go new file mode 100644 index 00000000..b049d06e --- /dev/null +++ b/api/v1_oauth.go @@ -0,0 +1,618 @@ +package api + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "time" + + "api.audius.co/trashid" + "github.com/gofiber/fiber/v2" + "github.com/jackc/pgx/v5" + "go.uber.org/zap" +) + +// oauthError returns an RFC 6749 §5.2 error response. +func oauthError(c *fiber.Ctx, status int, errCode, description string) error { + return c.Status(status).JSON(fiber.Map{ + "error": errCode, + "error_description": description, + }) +} + +// generateRandomToken generates a cryptographically random URL-safe base64 string of n bytes. +func generateRandomToken(nBytes int) (string, error) { + b := make([]byte, nBytes) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// --- Request body structs --- + +type oauthAuthorizeBody struct { + Token string `json:"token"` + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` + Scope string `json:"scope"` +} + +type oauthTokenBody struct { + GrantType string `json:"grant_type"` + Code string `json:"code"` + CodeVerifier string `json:"code_verifier"` + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + RefreshToken string `json:"refresh_token"` +} + +type oauthRevokeBody struct { + Token string `json:"token"` + ClientID string `json:"client_id"` +} + +// --- Cache entry for PKCE tokens --- + +type oauthTokenCacheEntry struct { + UserID int32 + ClientID string + Scope string +} + +// --- Handlers --- + +// v1OAuthAuthorize handles POST /v1/oauth/authorize +// Called by the audius.co consent screen after the user authenticates. +func (app *ApiServer) v1OAuthAuthorize(c *fiber.Ctx) error { + var body oauthAuthorizeBody + if err := c.BodyParser(&body); err != nil { + return oauthError(c, fiber.StatusBadRequest, "invalid_request", "Invalid request body") + } + + // Validate required fields + if body.Token == "" || body.ClientID == "" || body.RedirectURI == "" || body.CodeChallenge == "" || body.Scope == "" { + return oauthError(c, fiber.StatusBadRequest, "invalid_request", "Missing required parameters") + } + + if body.CodeChallengeMethod != "S256" { + return oauthError(c, fiber.StatusBadRequest, "invalid_request", "code_challenge_method must be S256") + } + + if body.Scope != "read" && body.Scope != "write" { + return oauthError(c, fiber.StatusBadRequest, "invalid_request", "scope must be 'read' or 'write'") + } + + // 1. Validate JWT via existing logic and also check iat + userId, err := app.validateOAuthJWTTokenToUserId(c.Context(), body.Token) + if err != nil { + return oauthError(c, fiber.StatusUnauthorized, "access_denied", "Invalid JWT token") + } + + // Validate iat (issued-at) — reject if more than 5 min in the past or in the future + if err := app.validateJWTIat(body.Token); err != nil { + return oauthError(c, fiber.StatusUnauthorized, "access_denied", err.Error()) + } + + // 2. Validate client_id exists in developer_apps + clientID := strings.ToLower(body.ClientID) + var appExists bool + err = app.pool.QueryRow(c.Context(), ` + SELECT EXISTS ( + SELECT 1 FROM developer_apps + WHERE LOWER(address) = $1 AND is_current = true AND NOT is_delete + ) + `, clientID).Scan(&appExists) + if err != nil || !appExists { + return oauthError(c, fiber.StatusBadRequest, "invalid_client", "Unknown client_id") + } + + // 3. Validate redirect_uri + if !strings.EqualFold(body.RedirectURI, "postmessage") { + var uriRegistered bool + err = app.pool.QueryRow(c.Context(), ` + SELECT EXISTS ( + SELECT 1 FROM oauth_redirect_uris + WHERE LOWER(client_id) = $1 AND redirect_uri = $2 + ) + `, clientID, body.RedirectURI).Scan(&uriRegistered) + if err != nil || !uriRegistered { + return oauthError(c, fiber.StatusBadRequest, "invalid_request", "redirect_uri not registered") + } + } + + // 4. If scope is write, check for existing approved grant + if body.Scope == "write" { + var grantExists bool + err = app.pool.QueryRow(c.Context(), ` + SELECT EXISTS ( + SELECT 1 FROM grants + WHERE user_id = $1 + AND LOWER(grantee_address) = $2 + AND is_current = true + AND is_approved = true + AND is_revoked = false + ) + `, int32(userId), clientID).Scan(&grantExists) + if err != nil || !grantExists { + return oauthError(c, fiber.StatusForbidden, "access_denied", "No approved grant exists for this app. The user must grant permission first.") + } + } + + // 5. Generate random authorization code + code, err := generateRandomToken(32) + if err != nil { + app.logger.Error("Failed to generate auth code", zap.Error(err)) + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to generate authorization code") + } + + // Insert into oauth_authorization_codes + _, err = app.writePool.Exec(c.Context(), ` + INSERT INTO oauth_authorization_codes (code, client_id, user_id, redirect_uri, code_challenge, code_challenge_method, scope) + VALUES ($1, $2, $3, $4, $5, $6, $7) + `, code, clientID, int32(userId), body.RedirectURI, body.CodeChallenge, body.CodeChallengeMethod, body.Scope) + if err != nil { + app.logger.Error("Failed to insert auth code", zap.Error(err)) + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to create authorization code") + } + + // 6. Return the code + return c.JSON(fiber.Map{ + "code": code, + }) +} + +// v1OAuthToken handles POST /v1/oauth/token +// Supports grant_type=authorization_code and grant_type=refresh_token. +func (app *ApiServer) v1OAuthToken(c *fiber.Ctx) error { + var body oauthTokenBody + if err := c.BodyParser(&body); err != nil { + return oauthError(c, fiber.StatusBadRequest, "invalid_request", "Invalid request body") + } + + switch body.GrantType { + case "authorization_code": + return app.oauthTokenAuthorizationCode(c, &body) + case "refresh_token": + return app.oauthTokenRefreshToken(c, &body) + default: + return oauthError(c, fiber.StatusBadRequest, "invalid_request", "grant_type must be 'authorization_code' or 'refresh_token'") + } +} + +func (app *ApiServer) oauthTokenAuthorizationCode(c *fiber.Ctx, body *oauthTokenBody) error { + if body.Code == "" || body.CodeVerifier == "" || body.ClientID == "" || body.RedirectURI == "" { + return oauthError(c, fiber.StatusBadRequest, "invalid_request", "Missing required parameters for authorization_code grant") + } + + clientID := strings.ToLower(body.ClientID) + + // Atomically consume the code + var storedClientID, storedRedirectURI, storedCodeChallenge, storedCodeChallengeMethod, storedScope string + var storedUserID int32 + err := app.writePool.QueryRow(c.Context(), ` + UPDATE oauth_authorization_codes + SET used = true + WHERE code = $1 AND used = false AND expires_at > NOW() + RETURNING client_id, user_id, redirect_uri, code_challenge, code_challenge_method, scope + `, body.Code).Scan(&storedClientID, &storedUserID, &storedRedirectURI, &storedCodeChallenge, &storedCodeChallengeMethod, &storedScope) + if err != nil { + if err == pgx.ErrNoRows { + return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "Authorization code is invalid, expired, or already used") + } + app.logger.Error("Failed to consume auth code", zap.Error(err)) + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to process authorization code") + } + + // Verify client_id matches + if strings.ToLower(storedClientID) != clientID { + return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "client_id mismatch") + } + + // Verify redirect_uri matches + if storedRedirectURI != body.RedirectURI { + return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "redirect_uri mismatch") + } + + // Verify PKCE: base64url(SHA256(code_verifier)) must equal the stored code_challenge + if storedCodeChallengeMethod != "S256" { + return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "Unsupported code_challenge_method") + } + h := sha256.Sum256([]byte(body.CodeVerifier)) + computedChallenge := base64.RawURLEncoding.EncodeToString(h[:]) + if computedChallenge != storedCodeChallenge { + return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "PKCE verification failed") + } + + // Generate family_id, access token, refresh token + familyID, err := generateRandomToken(32) + if err != nil { + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to generate tokens") + } + accessToken, err := generateRandomToken(32) + if err != nil { + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to generate tokens") + } + refreshToken, err := generateRandomToken(32) + if err != nil { + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to generate tokens") + } + + accessExpiresAt := time.Now().Add(1 * time.Hour) + refreshExpiresAt := time.Now().Add(30 * 24 * time.Hour) // 30 days + + // Insert both tokens + _, err = app.writePool.Exec(c.Context(), ` + INSERT INTO oauth_tokens (token, token_type, client_id, user_id, scope, expires_at, family_id) + VALUES ($1, 'access', $2, $3, $4, $5, $6) + `, accessToken, clientID, storedUserID, storedScope, accessExpiresAt, familyID) + if err != nil { + app.logger.Error("Failed to insert access token", zap.Error(err)) + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to create tokens") + } + + _, err = app.writePool.Exec(c.Context(), ` + INSERT INTO oauth_tokens (token, token_type, client_id, user_id, scope, expires_at, family_id, refresh_token_id) + VALUES ($1, 'refresh', $2, $3, $4, $5, $6, $7) + `, refreshToken, clientID, storedUserID, storedScope, refreshExpiresAt, familyID, refreshToken) + if err != nil { + app.logger.Error("Failed to insert refresh token", zap.Error(err)) + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to create tokens") + } + + return c.JSON(fiber.Map{ + "access_token": accessToken, + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": refreshToken, + "scope": storedScope, + }) +} + +func (app *ApiServer) oauthTokenRefreshToken(c *fiber.Ctx, body *oauthTokenBody) error { + if body.RefreshToken == "" || body.ClientID == "" { + return oauthError(c, fiber.StatusBadRequest, "invalid_request", "Missing required parameters for refresh_token grant") + } + + clientID := strings.ToLower(body.ClientID) + + // First, check if the token exists and whether it triggers reuse detection. + // We need a two-phase approach: check for reuse first, then atomically consume. + var storedClientID string + var storedUserID int32 + var storedScope, storedFamilyID string + var storedIsRevoked bool + var storedExpiresAt time.Time + var storedTokenType string + err := app.writePool.QueryRow(c.Context(), ` + SELECT client_id, user_id, scope, family_id, is_revoked, expires_at, token_type + FROM oauth_tokens + WHERE token = $1 + `, body.RefreshToken).Scan(&storedClientID, &storedUserID, &storedScope, &storedFamilyID, &storedIsRevoked, &storedExpiresAt, &storedTokenType) + if err != nil { + if err == pgx.ErrNoRows { + return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "Invalid refresh token") + } + app.logger.Error("Failed to look up refresh token", zap.Error(err)) + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to process refresh token") + } + + // Must be a refresh token type + if storedTokenType != "refresh" { + return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "Invalid refresh token") + } + + // Verify client_id matches + if strings.ToLower(storedClientID) != clientID { + return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "client_id mismatch") + } + + // If the token is already revoked — token reuse detected. Revoke all tokens in the family. + if storedIsRevoked { + _, _ = app.writePool.Exec(c.Context(), ` + UPDATE oauth_tokens SET is_revoked = true WHERE family_id = $1 + `, storedFamilyID) + // Invalidate cache entries for this family + app.invalidateOAuthTokenCacheByFamily(c.Context(), storedFamilyID) + app.logger.Warn("Refresh token reuse detected, revoking family", + zap.String("family_id", storedFamilyID), + zap.Int32("user_id", storedUserID), + ) + return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "Token reuse detected") + } + + // If expired + if time.Now().After(storedExpiresAt) { + return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "Refresh token expired") + } + + // Atomically consume the refresh token to prevent race conditions. + // If two concurrent requests try to use the same token, only one will succeed. + var consumed bool + err = app.writePool.QueryRow(c.Context(), ` + UPDATE oauth_tokens SET is_revoked = true + WHERE token = $1 AND is_revoked = false + RETURNING true + `, body.RefreshToken).Scan(&consumed) + if err != nil { + if err == pgx.ErrNoRows { + // Another concurrent request already consumed this token — treat as reuse + _, _ = app.writePool.Exec(c.Context(), ` + UPDATE oauth_tokens SET is_revoked = true WHERE family_id = $1 + `, storedFamilyID) + app.invalidateOAuthTokenCacheByFamily(c.Context(), storedFamilyID) + app.logger.Warn("Refresh token concurrent reuse detected, revoking family", + zap.String("family_id", storedFamilyID), + zap.Int32("user_id", storedUserID), + ) + return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "Token reuse detected") + } + app.logger.Error("Failed to revoke old refresh token", zap.Error(err)) + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to revoke old refresh token") + } + + // Generate new tokens with the same family_id + accessToken, err := generateRandomToken(32) + if err != nil { + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to generate tokens") + } + refreshToken, err := generateRandomToken(32) + if err != nil { + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to generate tokens") + } + + accessExpiresAt := time.Now().Add(1 * time.Hour) + refreshExpiresAt := time.Now().Add(30 * 24 * time.Hour) + + _, err = app.writePool.Exec(c.Context(), ` + INSERT INTO oauth_tokens (token, token_type, client_id, user_id, scope, expires_at, family_id) + VALUES ($1, 'access', $2, $3, $4, $5, $6) + `, accessToken, clientID, storedUserID, storedScope, accessExpiresAt, storedFamilyID) + if err != nil { + app.logger.Error("Failed to insert new access token", zap.Error(err)) + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to create tokens") + } + + _, err = app.writePool.Exec(c.Context(), ` + INSERT INTO oauth_tokens (token, token_type, client_id, user_id, scope, expires_at, family_id, refresh_token_id) + VALUES ($1, 'refresh', $2, $3, $4, $5, $6, $7) + `, refreshToken, clientID, storedUserID, storedScope, refreshExpiresAt, storedFamilyID, refreshToken) + if err != nil { + app.logger.Error("Failed to insert new refresh token", zap.Error(err)) + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to create tokens") + } + + return c.JSON(fiber.Map{ + "access_token": accessToken, + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": refreshToken, + "scope": storedScope, + }) +} + +// v1OAuthRevoke handles POST /v1/oauth/revoke +// Per RFC 7009 §2.2, always returns 200 regardless of whether the token was valid. +func (app *ApiServer) v1OAuthRevoke(c *fiber.Ctx) error { + var body oauthRevokeBody + if err := c.BodyParser(&body); err != nil { + // Per RFC 7009, even for bad requests we should be lenient, + // but missing token is an actual error + return oauthError(c, fiber.StatusBadRequest, "invalid_request", "Invalid request body") + } + + if body.Token == "" { + return oauthError(c, fiber.StatusBadRequest, "invalid_request", "token is required") + } + + // Look up the token to get family_id for family-wide revocation + var familyID string + err := app.writePool.QueryRow(c.Context(), ` + SELECT family_id FROM oauth_tokens WHERE token = $1 + `, body.Token).Scan(&familyID) + if err != nil { + // Token not found or error — per RFC 7009, return 200 anyway + return c.JSON(fiber.Map{}) + } + + // Revoke all tokens in the family + _, err = app.writePool.Exec(c.Context(), ` + UPDATE oauth_tokens SET is_revoked = true WHERE family_id = $1 + `, familyID) + if err != nil { + app.logger.Error("Failed to revoke token family", zap.Error(err)) + } + + // Invalidate cache + app.invalidateOAuthTokenCacheByFamily(c.Context(), familyID) + + return c.JSON(fiber.Map{}) +} + +// v1OAuthMe handles GET /v1/oauth/me +// Returns the authenticated user's profile based on Bearer access token. +func (app *ApiServer) v1OAuthMe(c *fiber.Ctx) error { + // Extract Bearer token + authHeader := c.Get("Authorization") + if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { + return oauthError(c, fiber.StatusUnauthorized, "invalid_token", "Missing or invalid Authorization header") + } + token := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer ")) + if token == "" { + return oauthError(c, fiber.StatusUnauthorized, "invalid_token", "Bearer token is empty") + } + + // Look up the access token (try cache first) + entry, ok := app.lookupOAuthAccessToken(c, token) + if !ok { + return oauthError(c, fiber.StatusUnauthorized, "invalid_token", "Invalid or expired access token") + } + + // Query user info + encodedUserId, _ := trashid.EncodeHashId(int(entry.UserID)) + + var handle, name string + var verified bool + var profilePicture *string + err := app.pool.QueryRow(c.Context(), ` + SELECT handle, name, is_verified, + CASE WHEN profile_picture_sizes IS NOT NULL + THEN CONCAT($2::text, '/content/', profile_picture_sizes, '/150x150.jpg') + ELSE NULL + END as profile_picture + FROM users + WHERE user_id = $1 AND is_current = true + `, entry.UserID, app.audiusAppUrl).Scan(&handle, &name, &verified, &profilePicture) + if err != nil { + if err == pgx.ErrNoRows { + return oauthError(c, fiber.StatusNotFound, "invalid_token", "User not found") + } + app.logger.Error("Failed to query user for /oauth/me", zap.Error(err)) + return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to get user info") + } + + response := fiber.Map{ + "userId": encodedUserId, + "name": name, + "handle": handle, + "verified": verified, + "sub": encodedUserId, + "iat": time.Now().Unix(), + } + if profilePicture != nil { + response["profilePicture"] = *profilePicture + } + + // Try to get email if available + var email *string + _ = app.pool.QueryRow(c.Context(), ` + SELECT email_address FROM user_emails WHERE user_id = $1 + `, entry.UserID).Scan(&email) + if email != nil { + response["email"] = *email + } + + return c.JSON(response) +} + +// --- Helper methods --- + +// validateJWTIat extracts and validates the iat (issued-at) claim from a JWT. +// Returns an error if iat is missing or more than 5 minutes from now. +func (app *ApiServer) validateJWTIat(token string) error { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return fmt.Errorf("invalid JWT format") + } + + paddedPayload := parts[1] + if len(paddedPayload)%4 != 0 { + paddedPayload += strings.Repeat("=", 4-len(paddedPayload)%4) + } + payloadBytes, err := base64.URLEncoding.DecodeString(paddedPayload) + if err != nil { + return fmt.Errorf("JWT payload could not be decoded") + } + var payload map[string]interface{} + if err := json.Unmarshal(payloadBytes, &payload); err != nil { + return fmt.Errorf("JWT payload could not be unmarshalled") + } + + iatRaw, exists := payload["iat"] + if !exists { + return fmt.Errorf("JWT missing iat claim") + } + + var iat float64 + switch v := iatRaw.(type) { + case float64: + iat = v + case string: + // iat might be a string-encoded number + var parsed float64 + if _, err := fmt.Sscanf(v, "%f", &parsed); err != nil { + return fmt.Errorf("JWT iat is not a valid number") + } + iat = parsed + default: + return fmt.Errorf("JWT iat has unexpected type") + } + + iatTime := time.Unix(int64(iat), 0) + diff := time.Since(iatTime) + if diff < 0 { + diff = -diff + } + if diff > 5*time.Minute { + return fmt.Errorf("JWT iat is too far from current time") + } + + return nil +} + +// lookupOAuthAccessToken checks the cache first, then the database, for a valid access token. +// Returns the cache entry and true if found and valid, false otherwise. +func (app *ApiServer) lookupOAuthAccessToken(c *fiber.Ctx, token string) (oauthTokenCacheEntry, bool) { + // Check cache + if app.oauthTokenCache != nil { + if entry, ok := app.oauthTokenCache.Get(token); ok { + return entry, true + } + } + + // Query database + var userID int32 + var clientID, scope string + var expiresAt time.Time + err := app.pool.QueryRow(c.Context(), ` + SELECT user_id, client_id, scope, expires_at + FROM oauth_tokens + WHERE token = $1 AND token_type = 'access' AND is_revoked = false AND expires_at > NOW() + `, token).Scan(&userID, &clientID, &scope, &expiresAt) + if err != nil { + return oauthTokenCacheEntry{}, false + } + + entry := oauthTokenCacheEntry{ + UserID: userID, + ClientID: clientID, + Scope: scope, + } + + // Cache the result + if app.oauthTokenCache != nil { + app.oauthTokenCache.Set(token, entry) + } + + return entry, true +} + +// invalidateOAuthTokenCacheByFamily removes all cached tokens belonging to a family. +// Since otter doesn't support iteration/invalidation by value, we query the DB for +// all tokens in the family and delete them from cache individually. +func (app *ApiServer) invalidateOAuthTokenCacheByFamily(ctx context.Context, familyID string) { + if app.oauthTokenCache == nil { + return + } + + rows, err := app.writePool.Query(ctx, ` + SELECT token FROM oauth_tokens WHERE family_id = $1 + `, familyID) + if err != nil { + return + } + defer rows.Close() + + for rows.Next() { + var token string + if err := rows.Scan(&token); err == nil { + app.oauthTokenCache.Delete(token) + } + } +} diff --git a/api/v1_oauth_test.go b/api/v1_oauth_test.go new file mode 100644 index 00000000..89e9d621 --- /dev/null +++ b/api/v1_oauth_test.go @@ -0,0 +1,826 @@ +package api + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http/httptest" + "strings" + "testing" + "time" + + "api.audius.co/database" + + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// seedOAuthTestData seeds the database with test data for OAuth tests. +// Uses database.Seed for standard tables and database.SeedTable for OAuth tables. +// Returns the client_id (developer app address). +func seedOAuthTestData(t *testing.T, app *ApiServer) string { + t.Helper() + clientID := "0xaabb000000000000000000000000000000000001" + + database.Seed(app.pool.Replicas[0], database.FixtureMap{ + "users": { + { + "user_id": 100, + "handle": "oauthuser", + "handle_lc": "oauthuser", + "wallet": "0xoauthuserwallet000000000000000000000000", + "name": "OAuth User", + }, + }, + "developer_apps": { + { + "address": clientID, + "user_id": 100, + "name": "Test OAuth App", + "is_delete": false, + }, + }, + "grants": { + { + "user_id": 100, + "grantee_address": clientID, + "is_approved": true, + }, + }, + "oauth_redirect_uris": { + { + "client_id": clientID, + "redirect_uri": "https://example.com/callback", + }, + }, + }) + + return clientID +} + +// insertTestAuthCode inserts a test auth code and returns the code, code_verifier, and code_challenge. +func insertTestAuthCode(t *testing.T, app *ApiServer, clientID string, userID int, scope string) (code, codeVerifier, codeChallenge string) { + t.Helper() + codeVerifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + h := sha256.Sum256([]byte(codeVerifier)) + codeChallenge = base64.RawURLEncoding.EncodeToString(h[:]) + + code = "test-auth-code-" + fmt.Sprintf("%d", time.Now().UnixNano()) + + database.SeedTable(app.pool.Replicas[0], "oauth_authorization_codes", []map[string]any{ + { + "code": code, + "client_id": clientID, + "user_id": userID, + "redirect_uri": "https://example.com/callback", + "code_challenge": codeChallenge, + "scope": scope, + }, + }) + return +} + +// insertTestTokens inserts an access and refresh token pair. +func insertTestTokens(t *testing.T, app *ApiServer, clientID string, userID int, scope, familyID string, accessExpiresIn, refreshExpiresIn time.Duration) (accessToken, refreshToken string) { + t.Helper() + accessToken = "test-access-" + fmt.Sprintf("%d", time.Now().UnixNano()) + refreshToken = "test-refresh-" + fmt.Sprintf("%d", time.Now().UnixNano()) + + database.SeedTable(app.pool.Replicas[0], "oauth_tokens", []map[string]any{ + { + "token": accessToken, + "token_type": "access", + "client_id": clientID, + "user_id": userID, + "scope": scope, + "expires_at": time.Now().Add(accessExpiresIn), + "family_id": familyID, + }, + { + "token": refreshToken, + "token_type": "refresh", + "client_id": clientID, + "user_id": userID, + "scope": scope, + "expires_at": time.Now().Add(refreshExpiresIn), + "family_id": familyID, + "refresh_token_id": refreshToken, + }, + }) + + return +} + +// oauthPostJSON is a helper for posting JSON to OAuth endpoints via the app router. +func oauthPostJSON(t *testing.T, app *ApiServer, path string, body map[string]string) (int, []byte) { + t.Helper() + jsonBytes, err := json.Marshal(body) + require.NoError(t, err) + return testPost(t, app, path, jsonBytes, map[string]string{"Content-Type": "application/json"}) +} + +// oauthGetWithBearer is a helper for GET requests with a Bearer token. +func oauthGetWithBearer(t *testing.T, app *ApiServer, path, token string) (int, []byte) { + t.Helper() + req := httptest.NewRequest("GET", path, nil) + req.Header.Set("Authorization", "Bearer "+token) + res, err := app.Test(req, -1) + require.NoError(t, err) + body, _ := io.ReadAll(res.Body) + return res.StatusCode, body +} + +// --- /oauth/token (authorization_code grant) --- + +func TestOAuthTokenAuthorizationCode(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + code, codeVerifier, _ := insertTestAuthCode(t, app, clientID, 100, "write") + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "authorization_code", + "code": code, + "code_verifier": codeVerifier, + "client_id": clientID, + "redirect_uri": "https://example.com/callback", + }) + + assert.Equal(t, 200, status) + assert.True(t, gjson.GetBytes(body, "access_token").Exists()) + assert.True(t, gjson.GetBytes(body, "refresh_token").Exists()) + jsonAssert(t, body, map[string]any{ + "token_type": "Bearer", + "expires_in": float64(3600), + "scope": "write", + }) +} + +func TestOAuthTokenAuthorizationCode_InvalidCode(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "authorization_code", + "code": "nonexistent-code", + "code_verifier": "whatever", + "client_id": clientID, + "redirect_uri": "https://example.com/callback", + }) + + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_grant"}) +} + +func TestOAuthTokenAuthorizationCode_CodeReuse(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + code, codeVerifier, _ := insertTestAuthCode(t, app, clientID, 100, "write") + + reqBody := map[string]string{ + "grant_type": "authorization_code", + "code": code, + "code_verifier": codeVerifier, + "client_id": clientID, + "redirect_uri": "https://example.com/callback", + } + + // First use succeeds + status1, _ := oauthPostJSON(t, app, "/v1/oauth/token", reqBody) + assert.Equal(t, 200, status1) + + // Second use fails — code already consumed + status2, resBody := oauthPostJSON(t, app, "/v1/oauth/token", reqBody) + assert.Equal(t, 400, status2) + jsonAssert(t, resBody, map[string]any{"error": "invalid_grant"}) +} + +func TestOAuthTokenAuthorizationCode_WrongCodeVerifier(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + code, _, _ := insertTestAuthCode(t, app, clientID, 100, "write") + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "authorization_code", + "code": code, + "code_verifier": "wrong-verifier-that-does-not-match", + "client_id": clientID, + "redirect_uri": "https://example.com/callback", + }) + + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_grant"}) + assert.Contains(t, gjson.GetBytes(body, "error_description").String(), "PKCE") +} + +func TestOAuthTokenAuthorizationCode_ClientIDMismatch(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + code, codeVerifier, _ := insertTestAuthCode(t, app, clientID, 100, "write") + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "authorization_code", + "code": code, + "code_verifier": codeVerifier, + "client_id": "0xwrongclient00000000000000000000000000000", + "redirect_uri": "https://example.com/callback", + }) + + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_grant"}) +} + +func TestOAuthTokenAuthorizationCode_RedirectURIMismatch(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + code, codeVerifier, _ := insertTestAuthCode(t, app, clientID, 100, "write") + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "authorization_code", + "code": code, + "code_verifier": codeVerifier, + "client_id": clientID, + "redirect_uri": "https://evil.com/callback", + }) + + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_grant"}) + assert.Contains(t, gjson.GetBytes(body, "error_description").String(), "redirect_uri") +} + +func TestOAuthTokenAuthorizationCode_ExpiredCode(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + codeVerifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + h := sha256.Sum256([]byte(codeVerifier)) + codeChallenge := base64.RawURLEncoding.EncodeToString(h[:]) + code := "expired-code-test" + + database.SeedTable(app.pool.Replicas[0], "oauth_authorization_codes", []map[string]any{ + { + "code": code, + "client_id": clientID, + "user_id": 100, + "code_challenge": codeChallenge, + "scope": "write", + "expires_at": time.Now().Add(-time.Hour), + }, + }) + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "authorization_code", + "code": code, + "code_verifier": codeVerifier, + "client_id": clientID, + "redirect_uri": "https://example.com/callback", + }) + + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_grant"}) +} + +func TestOAuthTokenAuthorizationCode_MissingParams(t *testing.T) { + app := emptyTestApp(t) + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "authorization_code", + }) + + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_request"}) +} + +// --- /oauth/token (refresh_token grant) --- + +func TestOAuthTokenRefresh(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + familyID := "test-family-1" + _, refreshToken := insertTestTokens(t, app, clientID, 100, "write", familyID, time.Hour, 30*24*time.Hour) + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": clientID, + }) + + assert.Equal(t, 200, status) + assert.True(t, gjson.GetBytes(body, "access_token").Exists()) + assert.True(t, gjson.GetBytes(body, "refresh_token").Exists()) + jsonAssert(t, body, map[string]any{ + "token_type": "Bearer", + "expires_in": float64(3600), + "scope": "write", + }) + assert.NotEqual(t, refreshToken, gjson.GetBytes(body, "refresh_token").String()) +} + +func TestOAuthTokenRefresh_ReuseDetection(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + familyID := "test-family-reuse" + _, refreshToken := insertTestTokens(t, app, clientID, 100, "write", familyID, time.Hour, 30*24*time.Hour) + + reqBody := map[string]string{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": clientID, + } + + // First use succeeds + status1, body1 := oauthPostJSON(t, app, "/v1/oauth/token", reqBody) + assert.Equal(t, 200, status1) + newRefreshToken := gjson.GetBytes(body1, "refresh_token").String() + + // Reuse the old refresh token — should trigger reuse detection + status2, body2 := oauthPostJSON(t, app, "/v1/oauth/token", reqBody) + assert.Equal(t, 400, status2) + jsonAssert(t, body2, map[string]any{"error": "invalid_grant"}) + assert.Contains(t, gjson.GetBytes(body2, "error_description").String(), "reuse") + + // The new refresh token from the first rotation should also be revoked (family revocation) + status3, body3 := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "refresh_token", + "refresh_token": newRefreshToken, + "client_id": clientID, + }) + assert.Equal(t, 400, status3) + jsonAssert(t, body3, map[string]any{"error": "invalid_grant"}) +} + +func TestOAuthTokenRefresh_ExpiredToken(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + familyID := "test-family-expired" + refreshToken := "expired-refresh-" + fmt.Sprintf("%d", time.Now().UnixNano()) + + database.SeedTable(app.pool.Replicas[0], "oauth_tokens", []map[string]any{ + { + "token": refreshToken, + "token_type": "refresh", + "client_id": clientID, + "user_id": 100, + "scope": "write", + "expires_at": time.Now().Add(-time.Hour), + "family_id": familyID, + "refresh_token_id": refreshToken, + }, + }) + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": clientID, + }) + + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_grant"}) + assert.Contains(t, gjson.GetBytes(body, "error_description").String(), "expired") +} + +func TestOAuthTokenRefresh_ClientIDMismatch(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + familyID := "test-family-client-mismatch" + _, refreshToken := insertTestTokens(t, app, clientID, 100, "write", familyID, time.Hour, 30*24*time.Hour) + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": "0xwrongclient00000000000000000000000000000", + }) + + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_grant"}) + assert.Contains(t, gjson.GetBytes(body, "error_description").String(), "client_id") +} + +func TestOAuthTokenRefresh_AccessTokenNotRefreshable(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + familyID := "test-family-access-not-refreshable" + accessToken, _ := insertTestTokens(t, app, clientID, 100, "write", familyID, time.Hour, 30*24*time.Hour) + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "refresh_token", + "refresh_token": accessToken, + "client_id": clientID, + }) + + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_grant"}) +} + +func TestOAuthTokenRefresh_InvalidToken(t *testing.T) { + app := emptyTestApp(t) + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "refresh_token", + "refresh_token": "nonexistent-token", + "client_id": "0xaabb000000000000000000000000000000000001", + }) + + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_grant"}) +} + +// --- /oauth/token (invalid grant_type) --- + +func TestOAuthToken_InvalidGrantType(t *testing.T) { + app := emptyTestApp(t) + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "client_credentials", + }) + + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_request"}) +} + +// --- /oauth/revoke --- + +func TestOAuthRevoke(t *testing.T) { + app := emptyTestApp(t) + ctx := context.Background() + clientID := seedOAuthTestData(t, app) + + familyID := "test-family-revoke" + accessToken, refreshToken := insertTestTokens(t, app, clientID, 100, "write", familyID, time.Hour, 30*24*time.Hour) + + status, _ := oauthPostJSON(t, app, "/v1/oauth/revoke", map[string]string{ + "token": accessToken, + "client_id": clientID, + }) + assert.Equal(t, 200, status) + + // Verify both tokens in the family are revoked + var isRevoked bool + err := app.writePool.QueryRow(ctx, ` + SELECT is_revoked FROM oauth_tokens WHERE token = $1 + `, accessToken).Scan(&isRevoked) + require.NoError(t, err) + assert.True(t, isRevoked, "access token should be revoked") + + err = app.writePool.QueryRow(ctx, ` + SELECT is_revoked FROM oauth_tokens WHERE token = $1 + `, refreshToken).Scan(&isRevoked) + require.NoError(t, err) + assert.True(t, isRevoked, "refresh token should be revoked (family revocation)") +} + +func TestOAuthRevoke_UnknownToken(t *testing.T) { + app := emptyTestApp(t) + + // Per RFC 7009, always returns 200 even for unknown tokens + status, _ := oauthPostJSON(t, app, "/v1/oauth/revoke", map[string]string{ + "token": "nonexistent-token-xyz", + "client_id": "0xwhatever", + }) + assert.Equal(t, 200, status) +} + +func TestOAuthRevoke_MissingToken(t *testing.T) { + app := emptyTestApp(t) + + status, body := oauthPostJSON(t, app, "/v1/oauth/revoke", map[string]string{ + "client_id": "0xwhatever", + }) + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_request"}) +} + +// --- /oauth/me --- + +func TestOAuthMe(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + familyID := "test-family-me" + accessToken, _ := insertTestTokens(t, app, clientID, 100, "read", familyID, time.Hour, 30*24*time.Hour) + + status, body := oauthGetWithBearer(t, app, "/v1/oauth/me", accessToken) + + assert.Equal(t, 200, status) + assert.True(t, gjson.GetBytes(body, "userId").Exists()) + jsonAssert(t, body, map[string]any{ + "handle": "oauthuser", + "name": "OAuth User", + "verified": false, + }) + assert.Equal(t, + gjson.GetBytes(body, "userId").String(), + gjson.GetBytes(body, "sub").String(), + ) +} + +func TestOAuthMe_InvalidToken(t *testing.T) { + app := emptyTestApp(t) + + status, _ := oauthGetWithBearer(t, app, "/v1/oauth/me", "invalid-token") + + // requireAuthMiddleware intercepts before the handler runs + assert.Equal(t, 401, status) +} + +func TestOAuthMe_ExpiredToken(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + expiredToken := "expired-access-" + fmt.Sprintf("%d", time.Now().UnixNano()) + + database.SeedTable(app.pool.Replicas[0], "oauth_tokens", []map[string]any{ + { + "token": expiredToken, + "token_type": "access", + "client_id": clientID, + "user_id": 100, + "scope": "read", + "expires_at": time.Now().Add(-time.Hour), + "family_id": "expired-family", + }, + }) + + status, _ := oauthGetWithBearer(t, app, "/v1/oauth/me", expiredToken) + + // requireAuthMiddleware intercepts before the handler runs + assert.Equal(t, 401, status) +} + +func TestOAuthMe_RevokedToken(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + revokedToken := "revoked-access-" + fmt.Sprintf("%d", time.Now().UnixNano()) + + database.SeedTable(app.pool.Replicas[0], "oauth_tokens", []map[string]any{ + { + "token": revokedToken, + "token_type": "access", + "client_id": clientID, + "user_id": 100, + "scope": "read", + "expires_at": time.Now().Add(time.Hour), + "family_id": "revoked-family", + "is_revoked": true, + }, + }) + + status, _ := oauthGetWithBearer(t, app, "/v1/oauth/me", revokedToken) + + // requireAuthMiddleware intercepts before the handler runs + assert.Equal(t, 401, status) +} + +func TestOAuthMe_MissingAuthHeader(t *testing.T) { + app := emptyTestApp(t) + + req := httptest.NewRequest("GET", "/v1/oauth/me", nil) + res, err := app.Test(req, -1) + require.NoError(t, err) + + // Request goes through requireAuthMiddleware which returns 401 before the handler + assert.Equal(t, 401, res.StatusCode) +} + +// --- Full flow: code exchange → refresh → me → revoke --- + +func TestOAuthFullFlow(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + // Step 1: Exchange authorization code for tokens + code, codeVerifier, _ := insertTestAuthCode(t, app, clientID, 100, "write") + + status, body := oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "authorization_code", + "code": code, + "code_verifier": codeVerifier, + "client_id": clientID, + "redirect_uri": "https://example.com/callback", + }) + require.Equal(t, 200, status) + accessToken := gjson.GetBytes(body, "access_token").String() + refreshToken := gjson.GetBytes(body, "refresh_token").String() + + // Step 2: Use access token to get user profile + status, body = oauthGetWithBearer(t, app, "/v1/oauth/me", accessToken) + assert.Equal(t, 200, status) + jsonAssert(t, body, map[string]any{"handle": "oauthuser"}) + + // Step 3: Refresh the token + status, body = oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": clientID, + }) + require.Equal(t, 200, status) + newAccessToken := gjson.GetBytes(body, "access_token").String() + newRefreshToken := gjson.GetBytes(body, "refresh_token").String() + assert.NotEqual(t, accessToken, newAccessToken) + assert.NotEqual(t, refreshToken, newRefreshToken) + + // Step 4: New access token works + status, body = oauthGetWithBearer(t, app, "/v1/oauth/me", newAccessToken) + assert.Equal(t, 200, status) + jsonAssert(t, body, map[string]any{"handle": "oauthuser"}) + + // Step 5: Revoke + status, _ = oauthPostJSON(t, app, "/v1/oauth/revoke", map[string]string{ + "token": newAccessToken, + "client_id": clientID, + }) + assert.Equal(t, 200, status) + + // Step 6: Revoked tokens no longer work + status, _ = oauthGetWithBearer(t, app, "/v1/oauth/me", newAccessToken) + assert.Equal(t, 401, status) + + // Refreshing with the new refresh token also fails (family revoked) + status, body = oauthPostJSON(t, app, "/v1/oauth/token", map[string]string{ + "grant_type": "refresh_token", + "refresh_token": newRefreshToken, + "client_id": clientID, + }) + assert.Equal(t, 400, status) + jsonAssert(t, body, map[string]any{"error": "invalid_grant"}) +} + +// --- validateJWTIat --- + +func TestValidateJWTIat(t *testing.T) { + app := emptyTestApp(t) + + makeJWT := func(iat interface{}) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"ES256"}`)) + payload, _ := json.Marshal(map[string]interface{}{"iat": iat, "userId": "abc"}) + payloadB64 := base64.RawURLEncoding.EncodeToString(payload) + sig := base64.RawURLEncoding.EncodeToString([]byte("fakesig")) + return header + "." + payloadB64 + "." + sig + } + + t.Run("valid iat (current time)", func(t *testing.T) { + token := makeJWT(float64(time.Now().Unix())) + err := app.validateJWTIat(token) + assert.NoError(t, err) + }) + + t.Run("iat 4 minutes ago (within window)", func(t *testing.T) { + token := makeJWT(float64(time.Now().Add(-4 * time.Minute).Unix())) + err := app.validateJWTIat(token) + assert.NoError(t, err) + }) + + t.Run("iat 6 minutes ago (outside window)", func(t *testing.T) { + token := makeJWT(float64(time.Now().Add(-6 * time.Minute).Unix())) + err := app.validateJWTIat(token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too far") + }) + + t.Run("iat 6 minutes in the future (outside window)", func(t *testing.T) { + token := makeJWT(float64(time.Now().Add(6 * time.Minute).Unix())) + err := app.validateJWTIat(token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too far") + }) + + t.Run("missing iat", func(t *testing.T) { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"ES256"}`)) + payload := base64.RawURLEncoding.EncodeToString([]byte(`{"userId":"abc"}`)) + sig := base64.RawURLEncoding.EncodeToString([]byte("fakesig")) + token := header + "." + payload + "." + sig + err := app.validateJWTIat(token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing iat") + }) + + t.Run("iat as string number", func(t *testing.T) { + token := makeJWT(fmt.Sprintf("%d", time.Now().Unix())) + err := app.validateJWTIat(token) + assert.NoError(t, err) + }) + + t.Run("invalid JWT format", func(t *testing.T) { + err := app.validateJWTIat("not.a.valid-jwt-but-three-parts") + assert.Error(t, err) + }) +} + +// --- generateRandomToken --- + +func TestGenerateRandomToken(t *testing.T) { + token1, err := generateRandomToken(32) + assert.NoError(t, err) + assert.NotEmpty(t, token1) + assert.Equal(t, 43, len(token1)) // 32 bytes -> 43 chars in base64url without padding + + token2, err := generateRandomToken(32) + assert.NoError(t, err) + assert.NotEqual(t, token1, token2, "tokens should be unique") +} + +// --- lookupOAuthAccessToken + cache --- + +func TestLookupOAuthAccessToken_CacheBehavior(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + testApp := fiber.New() + var lookupResult oauthTokenCacheEntry + var lookupOK bool + + testApp.Get("/test", func(c *fiber.Ctx) error { + token := c.Get("X-Token") + lookupResult, lookupOK = app.lookupOAuthAccessToken(c, token) + return c.SendStatus(200) + }) + + familyID := "test-cache-family" + accessToken, _ := insertTestTokens(t, app, clientID, 100, "write", familyID, time.Hour, 30*24*time.Hour) + + // First lookup — cache miss, DB hit + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Token", accessToken) + _, err := testApp.Test(req, -1) + require.NoError(t, err) + assert.True(t, lookupOK) + assert.Equal(t, int32(100), lookupResult.UserID) + assert.Equal(t, strings.ToLower(clientID), strings.ToLower(lookupResult.ClientID)) + assert.Equal(t, "write", lookupResult.Scope) + + // Second lookup — should hit cache + req = httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Token", accessToken) + _, err = testApp.Test(req, -1) + require.NoError(t, err) + assert.True(t, lookupOK) + assert.Equal(t, int32(100), lookupResult.UserID) + + // Non-existent token should return false + req = httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Token", "nonexistent-token") + _, err = testApp.Test(req, -1) + require.NoError(t, err) + assert.False(t, lookupOK) +} + +// --- PKCE token in auth middleware --- + +func TestAuthMiddleware_PKCEToken(t *testing.T) { + app := emptyTestApp(t) + clientID := seedOAuthTestData(t, app) + + familyID := "test-family-middleware" + accessToken, _ := insertTestTokens(t, app, clientID, 100, "write", familyID, time.Hour, 30*24*time.Hour) + + var authedWallet string + var myId int32 + + testApp := fiber.New() + testApp.Use(app.resolveMyIdMiddleware) + testApp.Use(app.authMiddleware) + testApp.Get("/test", func(c *fiber.Ctx) error { + authedWallet = c.Locals("authedWallet").(string) + myId = app.getMyId(c) + return c.SendStatus(200) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+accessToken) + res, err := testApp.Test(req, -1) + require.NoError(t, err) + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, strings.ToLower(clientID), authedWallet) + assert.Equal(t, int32(100), myId) +} + +func TestAuthMiddleware_PKCEToken_InvalidToken(t *testing.T) { + app := emptyTestApp(t) + + var authedWallet string + + testApp := fiber.New() + testApp.Use(app.resolveMyIdMiddleware) + testApp.Use(app.authMiddleware) + testApp.Get("/test", func(c *fiber.Ctx) error { + authedWallet = c.Locals("authedWallet").(string) + return c.SendStatus(200) + }) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer invalid-pkce-token") + res, err := testApp.Test(req, -1) + require.NoError(t, err) + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "", authedWallet) +} diff --git a/database/seed.go b/database/seed.go index 1c7da467..d8a965f1 100644 --- a/database/seed.go +++ b/database/seed.go @@ -634,6 +634,34 @@ var ( "remaining_uses": 1, "created_at": time.Now(), }, + "oauth_redirect_uris": { + "client_id": nil, + "redirect_uri": nil, + "created_at": time.Now(), + }, + "oauth_authorization_codes": { + "code": nil, + "client_id": nil, + "user_id": nil, + "redirect_uri": "https://example.com/callback", + "code_challenge": nil, + "code_challenge_method": "S256", + "scope": "read", + "expires_at": time.Now().Add(10 * time.Minute), + "used": false, + }, + "oauth_tokens": { + "token": nil, + "token_type": nil, + "client_id": nil, + "user_id": nil, + "scope": "read", + "expires_at": time.Now().Add(time.Hour), + "is_revoked": false, + "created_at": time.Now(), + "refresh_token_id": nil, + "family_id": nil, + }, } ) From 8b143ccad418c9eb6d5dab1b58a0c7372b1a7a25 Mon Sep 17 00:00:00 2001 From: Marcus Pasell <3690498+rickyrombo@users.noreply.github.com> Date: Wed, 4 Mar 2026 00:38:40 -0800 Subject: [PATCH 2/5] use normal user query and proper profile picture --- api/v1_oauth.go | 43 +++++++++++++++++-------------------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/api/v1_oauth.go b/api/v1_oauth.go index b049d06e..e678d013 100644 --- a/api/v1_oauth.go +++ b/api/v1_oauth.go @@ -10,7 +10,7 @@ import ( "strings" "time" - "api.audius.co/trashid" + "api.audius.co/api/dbv1" "github.com/gofiber/fiber/v2" "github.com/jackc/pgx/v5" "go.uber.org/zap" @@ -454,39 +454,30 @@ func (app *ApiServer) v1OAuthMe(c *fiber.Ctx) error { return oauthError(c, fiber.StatusUnauthorized, "invalid_token", "Invalid or expired access token") } - // Query user info - encodedUserId, _ := trashid.EncodeHashId(int(entry.UserID)) - - var handle, name string - var verified bool - var profilePicture *string - err := app.pool.QueryRow(c.Context(), ` - SELECT handle, name, is_verified, - CASE WHEN profile_picture_sizes IS NOT NULL - THEN CONCAT($2::text, '/content/', profile_picture_sizes, '/150x150.jpg') - ELSE NULL - END as profile_picture - FROM users - WHERE user_id = $1 AND is_current = true - `, entry.UserID, app.audiusAppUrl).Scan(&handle, &name, &verified, &profilePicture) + // Fetch user via the standard query helper (includes rendezvous-based image URLs) + users, err := app.queries.Users(c.Context(), dbv1.GetUsersParams{ + Ids: []int32{entry.UserID}, + }) if err != nil { - if err == pgx.ErrNoRows { - return oauthError(c, fiber.StatusNotFound, "invalid_token", "User not found") - } app.logger.Error("Failed to query user for /oauth/me", zap.Error(err)) return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to get user info") } + if len(users) == 0 { + return oauthError(c, fiber.StatusNotFound, "invalid_token", "User not found") + } + + user := users[0] response := fiber.Map{ - "userId": encodedUserId, - "name": name, - "handle": handle, - "verified": verified, - "sub": encodedUserId, + "userId": user.ID, + "name": user.Name.String, + "handle": user.Handle.String, + "verified": user.IsVerified, + "sub": user.ID, "iat": time.Now().Unix(), } - if profilePicture != nil { - response["profilePicture"] = *profilePicture + if user.ProfilePicture != nil { + response["profilePicture"] = user.ProfilePicture } // Try to get email if available From 17401255e3c180c491021c9d5dd48292fe5f1bfc Mon Sep 17 00:00:00 2001 From: Marcus Pasell <3690498+rickyrombo@users.noreply.github.com> Date: Wed, 4 Mar 2026 00:44:59 -0800 Subject: [PATCH 3/5] remove email, don't expose that --- api/v1_oauth.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/api/v1_oauth.go b/api/v1_oauth.go index e678d013..8824cb76 100644 --- a/api/v1_oauth.go +++ b/api/v1_oauth.go @@ -480,15 +480,6 @@ func (app *ApiServer) v1OAuthMe(c *fiber.Ctx) error { response["profilePicture"] = user.ProfilePicture } - // Try to get email if available - var email *string - _ = app.pool.QueryRow(c.Context(), ` - SELECT email_address FROM user_emails WHERE user_id = $1 - `, entry.UserID).Scan(&email) - if email != nil { - response["email"] = *email - } - return c.JSON(response) } From ffd97f273940554f3f76b740b9845a3ce8f813c3 Mon Sep 17 00:00:00 2001 From: Marcus Pasell <3690498+rickyrombo@users.noreply.github.com> Date: Wed, 4 Mar 2026 20:08:29 -0800 Subject: [PATCH 4/5] use read pool --- api/request_helpers.go | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/api/request_helpers.go b/api/request_helpers.go index f8561600..517f5060 100644 --- a/api/request_helpers.go +++ b/api/request_helpers.go @@ -52,17 +52,16 @@ func (app *ApiServer) getApiSigner(c *fiber.Ctx) (*Signer, error) { if token == "" { return nil, fmt.Errorf("Bearer token is empty") } - if app.writePool != nil { - if signer := app.getSignerFromApiAccessKey(c.Context(), token); signer != nil { - return signer, nil - } + + if signer := app.getSignerFromApiAccessKey(c.Context(), token); signer != nil { + return signer, nil } + // Try PKCE token → look up client_id → get api_secret from api_keys → return Signer - if app.writePool != nil { - if signer := app.getSignerFromOAuthToken(c, token); signer != nil { - return signer, nil - } + if signer := app.getSignerFromOAuthToken(c, token); signer != nil { + return signer, nil } + // If authMiddleware already validated a JWT and set authedWallet, // use AudiusApiSecret to sign on behalf of the authenticated user. if wallet, _ := c.Locals("authedWallet").(string); wallet != "" && app.config.AudiusApiSecret != "" { @@ -139,7 +138,7 @@ func (app *ApiServer) getSignerFromApiAccessKey(ctx context.Context, apiAccessKe } var parentApiKey, apiSecret string - err := app.writePool.QueryRow(ctx, ` + err := app.pool.QueryRow(ctx, ` SELECT aak.api_key, ak.api_secret FROM api_access_keys aak JOIN api_keys ak ON LOWER(ak.api_key) = LOWER(aak.api_key) @@ -176,7 +175,7 @@ func (app *ApiServer) getSignerFromOAuthToken(c *fiber.Ctx, token string) *Signe // Look up api_secret for the client_id (developer app address = api_key) var apiSecret string - err := app.writePool.QueryRow(c.Context(), ` + err := app.pool.QueryRow(c.Context(), ` SELECT api_secret FROM api_keys WHERE LOWER(api_key) = LOWER($1) `, entry.ClientID).Scan(&apiSecret) if err != nil || apiSecret == "" { From bf44bd0dba38595cf9a451c6cf8f7d556657f46b Mon Sep 17 00:00:00 2001 From: Marcus Pasell <3690498+rickyrombo@users.noreply.github.com> Date: Wed, 4 Mar 2026 20:09:32 -0800 Subject: [PATCH 5/5] allow form encoded requests, make client_id optional once code is received, normalize client_id to wallet address --- api/v1_oauth.go | 84 ++++++++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 36 deletions(-) diff --git a/api/v1_oauth.go b/api/v1_oauth.go index 8824cb76..029c5105 100644 --- a/api/v1_oauth.go +++ b/api/v1_oauth.go @@ -36,26 +36,35 @@ func generateRandomToken(nBytes int) (string, error) { // --- Request body structs --- type oauthAuthorizeBody struct { - Token string `json:"token"` - ClientID string `json:"client_id"` - RedirectURI string `json:"redirect_uri"` - CodeChallenge string `json:"code_challenge"` - CodeChallengeMethod string `json:"code_challenge_method"` - Scope string `json:"scope"` + Token string `json:"token" form:"token"` + ClientID string `json:"client_id" form:"client_id"` + RedirectURI string `json:"redirect_uri" form:"redirect_uri"` + CodeChallenge string `json:"code_challenge" form:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method" form:"code_challenge_method"` + Scope string `json:"scope" form:"scope"` } type oauthTokenBody struct { - GrantType string `json:"grant_type"` - Code string `json:"code"` - CodeVerifier string `json:"code_verifier"` - ClientID string `json:"client_id"` - RedirectURI string `json:"redirect_uri"` - RefreshToken string `json:"refresh_token"` + GrantType string `json:"grant_type" form:"grant_type"` + Code string `json:"code" form:"code"` + CodeVerifier string `json:"code_verifier" form:"code_verifier"` + ClientID string `json:"client_id" form:"client_id"` + RedirectURI string `json:"redirect_uri" form:"redirect_uri"` + RefreshToken string `json:"refresh_token" form:"refresh_token"` } type oauthRevokeBody struct { - Token string `json:"token"` - ClientID string `json:"client_id"` + Token string `json:"token" form:"token"` + ClientID string `json:"client_id" form:"client_id"` +} + +// normalizeClientID lowercases and ensures the 0x prefix on a client_id (developer app address). +func normalizeClientID(raw string) string { + id := strings.ToLower(strings.TrimSpace(raw)) + if !strings.HasPrefix(id, "0x") { + id = "0x" + id + } + return id } // --- Cache entry for PKCE tokens --- @@ -101,7 +110,7 @@ func (app *ApiServer) v1OAuthAuthorize(c *fiber.Ctx) error { } // 2. Validate client_id exists in developer_apps - clientID := strings.ToLower(body.ClientID) + clientID := normalizeClientID(body.ClientID) var appExists bool err = app.pool.QueryRow(c.Context(), ` SELECT EXISTS ( @@ -114,18 +123,19 @@ func (app *ApiServer) v1OAuthAuthorize(c *fiber.Ctx) error { } // 3. Validate redirect_uri - if !strings.EqualFold(body.RedirectURI, "postmessage") { - var uriRegistered bool - err = app.pool.QueryRow(c.Context(), ` - SELECT EXISTS ( - SELECT 1 FROM oauth_redirect_uris - WHERE LOWER(client_id) = $1 AND redirect_uri = $2 - ) - `, clientID, body.RedirectURI).Scan(&uriRegistered) - if err != nil || !uriRegistered { - return oauthError(c, fiber.StatusBadRequest, "invalid_request", "redirect_uri not registered") - } - } + // SKIP FOR NOW + // if !strings.EqualFold(body.RedirectURI, "postmessage") { + // var uriRegistered bool + // err = app.pool.QueryRow(c.Context(), ` + // SELECT EXISTS ( + // SELECT 1 FROM oauth_redirect_uris + // WHERE LOWER(client_id) = $1 AND redirect_uri = $2 + // ) + // `, clientID, body.RedirectURI).Scan(&uriRegistered) + // if err != nil || !uriRegistered { + // return oauthError(c, fiber.StatusBadRequest, "invalid_request", "redirect_uri not registered") + // } + // } // 4. If scope is write, check for existing approved grant if body.Scope == "write" { @@ -187,12 +197,10 @@ func (app *ApiServer) v1OAuthToken(c *fiber.Ctx) error { } func (app *ApiServer) oauthTokenAuthorizationCode(c *fiber.Ctx, body *oauthTokenBody) error { - if body.Code == "" || body.CodeVerifier == "" || body.ClientID == "" || body.RedirectURI == "" { + if body.Code == "" || body.CodeVerifier == "" { return oauthError(c, fiber.StatusBadRequest, "invalid_request", "Missing required parameters for authorization_code grant") } - clientID := strings.ToLower(body.ClientID) - // Atomically consume the code var storedClientID, storedRedirectURI, storedCodeChallenge, storedCodeChallengeMethod, storedScope string var storedUserID int32 @@ -210,13 +218,17 @@ func (app *ApiServer) oauthTokenAuthorizationCode(c *fiber.Ctx, body *oauthToken return oauthError(c, fiber.StatusInternalServerError, "server_error", "Failed to process authorization code") } - // Verify client_id matches - if strings.ToLower(storedClientID) != clientID { - return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "client_id mismatch") + // client_id is optional in the request — if provided, verify it matches the stored one. + // If omitted, use the client_id bound to the authorization code. + clientID := normalizeClientID(storedClientID) + if body.ClientID != "" { + if normalizeClientID(body.ClientID) != clientID { + return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "client_id mismatch") + } } - // Verify redirect_uri matches - if storedRedirectURI != body.RedirectURI { + // redirect_uri is optional — if provided, verify it matches + if body.RedirectURI != "" && storedRedirectURI != body.RedirectURI { return oauthError(c, fiber.StatusBadRequest, "invalid_grant", "redirect_uri mismatch") } @@ -280,7 +292,7 @@ func (app *ApiServer) oauthTokenRefreshToken(c *fiber.Ctx, body *oauthTokenBody) return oauthError(c, fiber.StatusBadRequest, "invalid_request", "Missing required parameters for refresh_token grant") } - clientID := strings.ToLower(body.ClientID) + clientID := normalizeClientID(body.ClientID) // First, check if the token exists and whether it triggers reuse detection. // We need a two-phase approach: check for reuse first, then atomically consume.