From 2f5820d8d8c6d382e28b5c75c5248355270d8c9d Mon Sep 17 00:00:00 2001 From: Zeno Belli Date: Wed, 25 Mar 2026 14:10:41 +0100 Subject: [PATCH] refactor(jwtinfo): ParseTokenData function to Parse methods --- cmd/jwtinfo.go | 13 ++- internal/jwtinfo/jwtinfo.go | 141 ++++++++++++------------------- internal/jwtinfo/jwtinfo_test.go | 90 ++++++++++++-------- internal/style/style.go | 4 + 4 files changed, 118 insertions(+), 130 deletions(-) diff --git a/cmd/jwtinfo.go b/cmd/jwtinfo.go index 22bba7f..db752ee 100644 --- a/cmd/jwtinfo.go +++ b/cmd/jwtinfo.go @@ -65,15 +65,14 @@ var jwtinfoCmd = &cobra.Command{ return } - // TODO: turn into method - token, err := jwtinfo.ParseTokenData(tokenData, jwksURL, keyfuncDefOverride) - if err != nil { - fmt.Printf("error while parsing token data: %s\n", err) - return + if jwksURL != "" { + err = tokenData.ParseWithJWKS(jwksURL, keyfuncDefOverride) + if err != nil { + fmt.Printf("error while parsing token data: %s\n", err) + return + } } - fmt.Printf("Token valid: %v\n", token.Valid) - err = jwtinfo.PrintTokenInfo(tokenData, os.Stdout) if err != nil { fmt.Printf("error while printing token data: %s\n", err) diff --git a/internal/jwtinfo/jwtinfo.go b/internal/jwtinfo/jwtinfo.go index 4400d36..d2bd065 100644 --- a/internal/jwtinfo/jwtinfo.go +++ b/internal/jwtinfo/jwtinfo.go @@ -29,10 +29,12 @@ var ( ) type JwtTokenData struct { - AccessToken string `json:"access_token"` //nolint:tagliatelle // OAuth token field name + AccessTokenRaw string `json:"access_token"` //nolint:tagliatelle // OAuth token field name + AccessTokenJwt *jwt.Token AccessTokenHeader []byte AccessTokenClaims []byte - RefreshToken string `json:"refresh_token"` //nolint:tagliatelle // OAuth token field name + RefreshTokenRaw string `json:"refresh_token"` //nolint:tagliatelle // OAuth token field name + RefreshTokenJwt *jwt.Token RefreshTokenHeader []byte RefreshTokenClaims []byte } @@ -97,7 +99,7 @@ func RequestToken(reqURL string, reqValues map[string]string, client *http.Clien mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")) if mediaType == "application/jwt" { - t.AccessToken = string(bodyBytes) + t.AccessTokenRaw = string(bodyBytes) } if mediaType == "application/json" { @@ -110,7 +112,7 @@ func RequestToken(reqURL string, reqValues map[string]string, client *http.Clien } _, _, err = jwt.NewParser().ParseUnverified( - t.AccessToken, + t.AccessTokenRaw, &jwt.RegisteredClaims{}, ) if err != nil { @@ -158,25 +160,25 @@ func (jtd *JwtTokenData) DecodeBase64() error { }{ { name: "AccessToken", - raw: jtd.AccessToken, + raw: jtd.AccessTokenRaw, }, { name: "RefreshToken", - raw: jtd.RefreshToken, + raw: jtd.RefreshTokenRaw, }, } for _, token := range tokens { + if token.raw == emptyString { + continue + } + var tokenHeader []byte var tokenClaims []byte var err error - if token.raw == emptyString { - continue - } - tokenB64Elements := strings.Split(token.raw, ".") if len(tokenB64Elements) != 3 { return fmt.Errorf("invalid three dotted JWT format in %s", token.name) @@ -230,24 +232,28 @@ func (jtd *JwtTokenData) DecodeBase64() error { return nil } -func ParseTokenData(jtd JwtTokenData, jwksURL string, keyfuncOverride keyfunc.Override) (*jwt.Token, error) { - // Parsing the access token without validation - if jwksURL == "" { - token, _, err := jwt.NewParser().ParseUnverified( - jtd.AccessToken, - &jwt.RegisteredClaims{}, +func (jtd *JwtTokenData) ParseUnverified() error { + token, _, err := jwt.NewParser().ParseUnverified( + jtd.AccessTokenRaw, + &jwt.RegisteredClaims{}, + ) + if err != nil { + return fmt.Errorf( + "unable to parse AccessTokenRaw: %w", + err, ) - if err != nil { - return nil, fmt.Errorf( - "unable to parse unverified access token: %w", - err, - ) - } + } + + jtd.AccessTokenJwt = token - return token, nil + return nil +} + +func (jtd *JwtTokenData) ParseWithJWKS(jwksURL string, keyfuncOverride keyfunc.Override) error { + if jwksURL == emptyString { + return errors.New("emptyString string provided as JWKS url") } - // Parsing and validating the access token ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -257,7 +263,7 @@ func ParseTokenData(jtd JwtTokenData, jwksURL string, keyfuncOverride keyfunc.Ov keyfuncOverride, ) if err != nil { - return nil, fmt.Errorf( + return fmt.Errorf( "failed to create JWK Set from resource at URL %s: %w", jwksURL, err, @@ -265,82 +271,36 @@ func ParseTokenData(jtd JwtTokenData, jwksURL string, keyfuncOverride keyfunc.Ov } token, err := jwt.Parse( - jtd.AccessToken, + jtd.AccessTokenRaw, jwks.Keyfunc, ) if err != nil { - return nil, fmt.Errorf("failed to parse the JWT: %w", err) + return fmt.Errorf( + "failed to parse the JWT AccessTokenRaw against JWKS Url %s: %w", + jwksURL, + err, + ) } - return token, nil + jtd.AccessTokenJwt = token + + return nil } -// func DisplayTokenInfo(t *jwt.Token, w io.Writer) error { -// sl := style.CertKeyP4.Render -// sv := style.CertValue.Render -// sTrue := style.BoolTrue.Render -// sFalse := style.BoolFalse.Render -// -// fmt.Fprintln(w) -// fmt.Fprintln(w, style.LgSprintf(style.Cmd, "JwtInfo")) -// fmt.Fprintln(w) -// -// validString := sFalse("false") -// if t.Valid { -// validString = sTrue("true") -// } -// -// fmt.Fprintln(w, style.LgSprintf(style.CertKeyP3, "Valid %s", validString)) -// -// tokenHeaders := getTokenHeadersMap(t) -// hTable := table.New().Border(style.LGDefBorder).Headers("Header") -// -// for hKey, hVal := range tokenHeaders { -// hTable.Row(sl(hKey), sv(hVal)) -// } -// -// fmt.Fprintln(w, hTable.Render()) -// -// hTable.ClearRows() -// -// tokenClaims, err := getTokenClaimsMap(t) -// if err != nil { -// return fmt.Errorf("unable to get token Claims: %w", err) -// } -// -// cTable := table.New().Border(style.LGDefBorder).Headers("Claims") -// -// for cKey, cVal := range tokenClaims { -// cTable.Row(sl(cKey), sv(cVal)) -// } -// -// unregisteredClaims := getUnregisteredClaimsMap(t, tokenClaims) -// for ucKey, ucVal := range unregisteredClaims { -// cTable.Row(sl(ucKey), sv(ucVal)) -// } -// -// fmt.Fprintln(w, cTable.Render()) -// -// cTable.ClearRows() -// -// return nil -// } func PrintTokenInfo(jtd JwtTokenData, w io.Writer) error { sl := style.CertKeyP4.Render sv := style.CertValue.Render - // sTrue := style.BoolTrue.Render - // sFalse := style.BoolFalse.Render + sTrue := style.BoolTrue.Render + sFalse := style.BoolFalse.Render fmt.Fprintln(w) fmt.Fprintln(w, style.LgSprintf(style.Cmd, "JwtInfo")) fmt.Fprintln(w) - // validString := sFalse("false") - // if t.Valid { - // validString = sTrue("true") - // } - - // fmt.Fprintln(w, style.LgSprintf(style.CertKeyP3, "Valid %s", validString)) + validString := sFalse("false") + if jtd.AccessTokenJwt != nil && jtd.AccessTokenJwt.Valid { + validString = sTrue("true") + } tokens := []struct { name string @@ -364,9 +324,14 @@ func PrintTokenInfo(jtd JwtTokenData, w io.Writer) error { continue } - fmt.Fprintln(w, style.LgSprintf(style.Title, "%s", token.name)) - + fmt.Fprintln(w, style.LgSprintf(style.Title2, "%s", token.name)) fmt.Fprintln(w) + + if token.name == "AccessToken" { + fmt.Fprintln(w, style.LgSprintf(style.ItemKey, "Valid %s", validString)) + fmt.Fprintln(w) + } + fmt.Fprintln(w, style.LgSprintf(style.ItemKey, "Header")) var prettyJSON bytes.Buffer diff --git a/internal/jwtinfo/jwtinfo_test.go b/internal/jwtinfo/jwtinfo_test.go index 1a25a88..500557c 100644 --- a/internal/jwtinfo/jwtinfo_test.go +++ b/internal/jwtinfo/jwtinfo_test.go @@ -231,7 +231,7 @@ func TestRequestToken(t *testing.T) { } } -func TestParseTokenData(t *testing.T) { +func TestParseWithJWKS(t *testing.T) { tests := []struct { name string user string @@ -343,13 +343,6 @@ func TestParseTokenData(t *testing.T) { Client: server.Client(), } - _, err = ParseTokenData( - td, - "", - keyfuncOverrideTesting, - ) - require.NoError(t, err) - if tt.scope == "jwksEmpty" { respJwksEmpty, errEmpty := server.Client().Get(serverJwksEmptyEndpoint) require.NoError(t, errEmpty) @@ -365,8 +358,7 @@ func TestParseTokenData(t *testing.T) { "{}", ) - _, err = ParseTokenData( - td, + err = td.ParseWithJWKS( serverJwksEmptyEndpoint, keyfuncOverrideTesting, ) @@ -394,8 +386,7 @@ func TestParseTokenData(t *testing.T) { "UniqueKeyID1", ) - _, err = ParseTokenData( - td, + err = td.ParseWithJWKS( serverJwksFaultyEndpoint, keyfuncOverrideTesting, ) @@ -408,36 +399,67 @@ func TestParseTokenData(t *testing.T) { return } - tokenVerified, err := ParseTokenData( - td, + err = td.ParseWithJWKS( serverJwksEndpoint, keyfuncOverrideTesting, ) require.NoError(t, err) require.True( t, - tokenVerified.Valid, + td.AccessTokenJwt.Valid, "JWT token must be valid", ) }) } } -func TestParseTokenData_Errors(t *testing.T) { - t.Run("ParseUnverifiedError", func(t *testing.T) { +func TestParseUnverified(t *testing.T) { + t.Run("Success", func(t *testing.T) { t.Parallel() - td := JwtTokenData{AccessToken: "notValidString"} + tokenRaw, err := createToken("demo") + require.NoError(t, err) + + td := JwtTokenData{AccessTokenRaw: tokenRaw} - _, err := ParseTokenData( - td, + err = td.ParseUnverified() + require.NoError( + t, + err, + ) + }) + + t.Run("Error", func(t *testing.T) { + t.Parallel() + + td := JwtTokenData{AccessTokenRaw: "notValidString"} + + err := td.ParseUnverified() + require.ErrorContains( + t, + err, + "token is malformed: token contains an invalid number of segments", + ) + }) +} + +func TestParseWithJWKS_Errors(t *testing.T) { + t.Run("EmpryJwksURL", func(t *testing.T) { + t.Parallel() + + token, err := createToken("demo") + require.NoError(t, err) + + td := JwtTokenData{AccessTokenRaw: token} + + err = td.ParseWithJWKS( "", keyfunc.Override{}, ) require.ErrorContains( t, err, - "token is malformed: token contains an invalid number of segments", + "emptyString string provided as JWKS url", ) }) @@ -447,10 +469,9 @@ func TestParseTokenData_Errors(t *testing.T) { token, err := createToken("demo") require.NoError(t, err) - td := JwtTokenData{AccessToken: token} + td := JwtTokenData{AccessTokenRaw: token} - _, err = ParseTokenData( - td, + err = td.ParseWithJWKS( "https://localhost:54321/jkws.wrong.json", keyfunc.Override{}, ) @@ -467,10 +488,9 @@ func TestParseTokenData_Errors(t *testing.T) { token, err := createToken("demo") require.NoError(t, err) - td := JwtTokenData{AccessToken: token} + td := JwtTokenData{AccessTokenRaw: token} - _, err = ParseTokenData( - td, + err = td.ParseWithJWKS( "https://loca#$%^/jkws.json", keyfunc.Override{}, ) @@ -553,7 +573,7 @@ func TestDecodeBase64(t *testing.T) { accessTokenRaw, err := createToken("demo") require.NoError(t, err) - td := JwtTokenData{AccessToken: accessTokenRaw} + td := JwtTokenData{AccessTokenRaw: accessTokenRaw} err = td.DecodeBase64() require.NoError(t, err) @@ -561,19 +581,19 @@ func TestDecodeBase64(t *testing.T) { require.NoError(t, err) tdAccessTokenTest := td - tdAccessTokenTest.AccessToken = tt.tokenString + tdAccessTokenTest.AccessTokenRaw = tt.tokenString err = tdAccessTokenTest.DecodeBase64() require.ErrorContains(t, err, tt.errMsg) refreshTokenRaw, err := createToken("demo") require.NoError(t, err) - tdR := JwtTokenData{RefreshToken: refreshTokenRaw} + tdR := JwtTokenData{RefreshTokenRaw: refreshTokenRaw} err = tdR.DecodeBase64() require.NoError(t, err) tdRefreshTokenTest := tdR - tdRefreshTokenTest.RefreshToken = tt.tokenString + tdRefreshTokenTest.RefreshTokenRaw = tt.tokenString err = tdRefreshTokenTest.DecodeBase64() require.ErrorContains(t, err, tt.errMsg) }) @@ -608,7 +628,7 @@ func TestUnmarshallTokenTimeClaims(t *testing.T) { CustomerInfo{"demo", "human"}, } - jtd.AccessToken, err = token.SignedString(signKey) + jtd.AccessTokenRaw, err = token.SignedString(signKey) require.NoError(t, err) err = jtd.DecodeBase64() @@ -747,15 +767,14 @@ func TestPrintTokenInfo(t *testing.T) { Client: server.Client(), } - tokenVerified, err := ParseTokenData( - td, + err = td.ParseWithJWKS( serverJwksEndpoint, keyfuncOverrideTesting, ) require.NoError(t, err) require.True( t, - tokenVerified.Valid, + td.AccessTokenJwt.Valid, "JWT token must be valid", ) @@ -770,6 +789,7 @@ func TestPrintTokenInfo(t *testing.T) { stringsToCheck := []string{ "JwtInfo", + "Valid", "Header", "Claims", "alg", diff --git a/internal/style/style.go b/internal/style/style.go index 1557508..9485929 100644 --- a/internal/style/style.go +++ b/internal/style/style.go @@ -39,6 +39,10 @@ var ( Foreground(catLavander).Bold(true). PaddingLeft(1) + Title2 = lipgloss.NewStyle(). + Foreground(catBase).Background(catPeach).Bold(true). + PaddingLeft(1).PaddingRight(1) + ItemKey = lipgloss.NewStyle(). Foreground(catBlue). PaddingLeft(1).Bold(true)