Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 73 additions & 25 deletions pkg/commonhttp/client_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ package commonhttp
import (
"errors"
"fmt"
"io"
"mime"
"net/http"
"net/url"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -324,48 +328,92 @@ type cachedJWT struct {
// - *http.Response: the response from the underlying transport.
// - error: if credential injection fails or JWT generation fails.
func (t *clientOAuth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
urlCopy := *req.URL
newReq := req.Clone(req.Context())

q := urlCopy.Query()
var q url.Values

// Always inject client_id
q.Set("client_id", t.ClientID)
var ct string

if contentType := newReq.Header.Get("Content-Type"); contentType != "" {
var err error

ct, _, err = mime.ParseMediaType(contentType)
if err != nil {
return nil, fmt.Errorf("parsing mime type: %w", err)
}
}

isFormBody := newReq.Method == http.MethodPost && ct == "application/x-www-form-urlencoded"

if isFormBody && newReq.Body != nil {
defer func() {
if newReq.Body != nil {
_ = newReq.Body.Close()
}
}()

bodyBytes, err := io.ReadAll(newReq.Body)
if err != nil {
return nil, fmt.Errorf("reading request body: %w", err)
}

q, err = url.ParseQuery(string(bodyBytes))
if err != nil {
return nil, fmt.Errorf("parsing form body: %w", err)
}
} else {
q = newReq.URL.Query()
Comment thread
nnicora marked this conversation as resolved.
}

// Always inject client_id unless using basic auth
if t.ClientSecretBasic == nil || *t.ClientSecretBasic == "" {
q.Set("client_id", t.ClientID)
}

switch {
case t.ClientSecretPost != nil && *t.ClientSecretPost != "":
// client_secret_post → inject into query (or body)
q.Set("client_secret", *t.ClientSecretPost)
// client_secret_post mandates injection into the form body ONLY.
// Exposing secrets in URL query parameters is a severe security violation (RFC 6749).
if isFormBody {
q.Set("client_secret", *t.ClientSecretPost)
}

case t.ClientSecretBasic != nil && *t.ClientSecretBasic != "":
// client_secret_basic → set Authorization header
// client_secret_basic → set Authorization header safely
newReq.SetBasicAuth(t.ClientID, *t.ClientSecretBasic)

case t.ClientAssertion != nil && t.ClientAssertionType != nil:
// private_key_jwt → inject JWT assertion
jwtToken, err := t.requestJWT("private_key_jwt", *t.ClientAssertion)
if err != nil {
return nil, err
// private_key_jwt → inject raw JWT assertion ONLY into form body
if isFormBody {
q.Set("client_assertion_type", *t.ClientAssertionType)
q.Set("client_assertion", *t.ClientAssertion)
}

q.Set("client_assertion_type", *t.ClientAssertionType)
q.Set("client_assertion", jwtToken)

case t.ClientSecretJWT != nil:
// client_secret_jwt → generate JWT signed with shared secret
jwtToken, err := t.requestJWT("client_secret_jwt", *t.ClientSecretJWT)
if err != nil {
return nil, err
// client_secret_jwt → generate JWT and inject ONLY into form body
if isFormBody {
jwtToken, err := t.requestJWT("client_secret_jwt", *t.ClientSecretJWT)
if err != nil {
return nil, err
}

q.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
q.Set("client_assertion", jwtToken)
}

q.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
q.Set("client_assertion", jwtToken)
}

urlCopy.RawQuery = q.Encode()
newReq.URL = &urlCopy
if isFormBody {
newBodyStr := q.Encode()
newReq.Body = io.NopCloser(strings.NewReader(newBodyStr))
newReq.ContentLength = int64(len(newBodyStr))
newReq.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(strings.NewReader(newBodyStr)), nil
}
} else {
newReq.URL.RawQuery = q.Encode()
}

return t.Next.RoundTrip(&newReq)
return t.Next.RoundTrip(newReq)
}

// requestJWT generates or retrieves a cached JWT for the specified key and secret.
Expand Down
156 changes: 148 additions & 8 deletions pkg/commonhttp/client_oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"io"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

Expand All @@ -19,6 +23,16 @@ import (
"github.com/openkcm/common-sdk/pkg/pointers"
)

type errorReader struct{}

func (errorReader) Read(p []byte) (int, error) {
return 0, errors.New("mock read error")
}

func (errorReader) Close() error {
return nil
}

const tokenURL = "https://example.com/token"

// helper to create SourceRef from a literal value for testing
Expand Down Expand Up @@ -155,6 +169,51 @@ func TestNewClientFromOAuth2(t *testing.T) {
assert.Nil(t, rt.ClientAssertionType)
},
},
{
name: "load ClientID error",
config: &commoncfg.OAuth2{
Credentials: commoncfg.OAuth2Credentials{
ClientID: commoncfg.SourceRef{Source: "invalid"},
},
},
wantErr: true,
errMessage: "OAuth2 credentials missing client ID",
},
{
name: "load URL error",
config: &commoncfg.OAuth2{
Credentials: commoncfg.OAuth2Credentials{
ClientID: *strRef("id"),
},
URL: &commoncfg.SourceRef{Source: "invalid"},
},
wantErr: true,
errMessage: "OAuth2 credentials missing URL",
},
{
name: "load MTLS error",
config: &commoncfg.OAuth2{
Credentials: commoncfg.OAuth2Credentials{
ClientID: *strRef("id"),
AuthMethod: "none",
},
MTLS: &commoncfg.MTLS{
Cert: commoncfg.SourceRef{Source: "invalid"},
},
},
wantErr: true,
},
{
name: "no AuthMethod provided",
config: &commoncfg.OAuth2{
Credentials: commoncfg.OAuth2Credentials{
ClientID: *strRef("id"),
AuthMethod: "post",
},
},
wantErr: true,
errMessage: "no client authentication method provided",
},
}

for _, tt := range tests {
Expand All @@ -180,13 +239,25 @@ func TestNewClientFromOAuth2(t *testing.T) {
}
}

func TestValidateCombinationError(t *testing.T) {
creds := &commoncfg.OAuth2{}
rt := &clientOAuth2RoundTripper{
ClientSecretPost: pointers.String("secret"),
ClientAssertion: pointers.String("assertion"),
}
err := validate(creds, rt)
assert.Error(t, err)
assert.Contains(t, err.Error(), "cannot combine clientSecret with clientAssertion")
}

func TestClientOAuth2RoundTripper_RoundTrip(t *testing.T) {
clientID := "test-client"

tests := []struct {
name string
rt *clientOAuth2RoundTripper
check func(r *http.Request)
reqMod func(r *http.Request)
wantErr bool
errMessage string
}{
Expand All @@ -198,8 +269,10 @@ func TestClientOAuth2RoundTripper_RoundTrip(t *testing.T) {
Next: http.DefaultTransport,
},
check: func(r *http.Request) {
assert.Equal(t, clientID, r.URL.Query().Get("client_id"))
assert.Equal(t, "secret", r.URL.Query().Get("client_secret"))
bodyBytes, _ := io.ReadAll(r.Body)
q, _ := url.ParseQuery(string(bodyBytes))
assert.Equal(t, clientID, q.Get("client_id"))
assert.Equal(t, "secret", q.Get("client_secret"))
},
},
{
Expand All @@ -226,8 +299,10 @@ func TestClientOAuth2RoundTripper_RoundTrip(t *testing.T) {
jwtCache: make(map[string]cachedJWT),
},
check: func(r *http.Request) {
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", r.URL.Query().Get("client_assertion_type"))
assert.NotEmpty(t, r.URL.Query().Get("client_assertion"))
bodyBytes, _ := io.ReadAll(r.Body)
q, _ := url.ParseQuery(string(bodyBytes))
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", q.Get("client_assertion_type"))
assert.NotEmpty(t, q.Get("client_assertion"))
},
},
{
Expand All @@ -241,8 +316,65 @@ func TestClientOAuth2RoundTripper_RoundTrip(t *testing.T) {
jwtCache: make(map[string]cachedJWT),
},
check: func(r *http.Request) {
assert.Equal(t, "urn:custom:type", r.URL.Query().Get("client_assertion_type"))
assert.NotEmpty(t, r.URL.Query().Get("client_assertion"))
bodyBytes, _ := io.ReadAll(r.Body)
q, _ := url.ParseQuery(string(bodyBytes))
assert.Equal(t, "urn:custom:type", q.Get("client_assertion_type"))
assert.NotEmpty(t, q.Get("client_assertion"))
},
},
{
name: "mime type error",
rt: &clientOAuth2RoundTripper{
ClientID: clientID,
Next: http.DefaultTransport,
},
reqMod: func(r *http.Request) {
r.Header.Set("Content-Type", "invalid/mime;type=unquoted\"")
},
wantErr: true,
errMessage: "parsing mime type",
},
{
name: "body read error",
rt: &clientOAuth2RoundTripper{
ClientID: clientID,
Next: http.DefaultTransport,
},
reqMod: func(r *http.Request) {
r.Body = errorReader{}
},
wantErr: true,
errMessage: "reading request body",
},
{
name: "form parsing error",
rt: &clientOAuth2RoundTripper{
ClientID: clientID,
Next: http.DefaultTransport,
},
reqMod: func(r *http.Request) {
r.Body = io.NopCloser(strings.NewReader("%ZZ=1"))
},
wantErr: true,
errMessage: "parsing form body",
},
{
name: "GET request with client_secret_post",
rt: &clientOAuth2RoundTripper{
ClientID: clientID,
ClientSecretPost: pointers.String("secret"),
Next: http.DefaultTransport,
},
reqMod: func(r *http.Request) {
r.Method = http.MethodGet
r.Header.Del("Content-Type")
r.Body = nil
r.ContentLength = 0
},
check: func(r *http.Request) {
q := r.URL.Query()
assert.Equal(t, clientID, q.Get("client_id"))
assert.Empty(t, q.Get("client_secret"))
},
},
}
Expand All @@ -258,9 +390,14 @@ func TestClientOAuth2RoundTripper_RoundTrip(t *testing.T) {
}))
defer server.Close()

req, err := http.NewRequest(http.MethodGet, server.URL, nil)
req, err := http.NewRequest(http.MethodPost, server.URL, strings.NewReader("dummy=data"))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
assert.NoError(t, err)

if tt.reqMod != nil {
tt.reqMod(req)
}

rep, err := tt.rt.RoundTrip(req)
if tt.wantErr {
assert.Error(t, err)
Expand All @@ -271,7 +408,10 @@ func TestClientOAuth2RoundTripper_RoundTrip(t *testing.T) {
} else {
assert.NoError(t, err)
}
defer rep.Body.Close()

if rep != nil {
defer rep.Body.Close()
}
})
}
}
Expand Down
Loading