Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 74 additions & 35 deletions proxy/auth/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@ import (

const oauthRedirectURICookiePrefix = "oauth_redirect_uri_"

// oauthCallbackFromOrigin builds the OAuth redirect_uri (…/callback) for the given UI origin
// (scheme + host, no path). The path prefix comes from config.BaseUiUrl so deployments served
// from a subpath (e.g. https://host/ui) resolve to …/ui/callback instead of …/callback.
func oauthCallbackFromOrigin(origin *url.URL) (string, error) {
// oauthCallbackURL builds the OAuth redirect_uri (…/callback) for the given UI origin scheme and host.
// The path prefix comes from config.BaseUiUrl so deployments served from a subpath (e.g. https://host/ui)
// resolve to …/ui/callback instead of …/callback.
func oauthCallbackURL(scheme, host string) (string, error) {
scheme = strings.ToLower(strings.TrimSpace(scheme))
if scheme != "http" && scheme != "https" {
return "", fmt.Errorf("invalid origin scheme")
}
host = strings.TrimSpace(host)
if host == "" || strings.ContainsAny(host, "@\r\n\t") {
return "", fmt.Errorf("invalid origin host")
}
baseUI, err := url.Parse(config.BaseUiUrl)
if err != nil {
return "", fmt.Errorf("invalid BASE_UI_URL configuration: %w", err)
Expand All @@ -26,12 +34,9 @@ func oauthCallbackFromOrigin(origin *url.URL) (string, error) {
if basePath != "" {
callbackPath = basePath + "/callback"
}
out := &url.URL{
Scheme: origin.Scheme,
Host: origin.Host,
Path: callbackPath,
}
return out.String(), nil

pathURL := &url.URL{Path: callbackPath}
return scheme + "://" + host + pathURL.EscapedPath(), nil
}

// ResolveOAuthRedirectURI returns the OAuth redirect_uri (callback URL) for this login attempt.
Expand All @@ -47,27 +52,22 @@ func ResolveOAuthRedirectURI(r *http.Request, redirectBase string) (string, erro
if err != nil {
return "", fmt.Errorf("invalid BASE_UI_URL configuration: %w", err)
}
origin := &url.URL{Scheme: baseUI.Scheme, Host: baseUI.Host}
return oauthCallbackFromOrigin(origin)
return oauthCallbackURL(baseUI.Scheme, baseUI.Host)
}
u, err := url.Parse(strings.TrimSpace(redirectBase))
ok, err := isSameSchemeAndHost(redirectBase, r)
if err != nil {
return "", fmt.Errorf("invalid redirect_base")
return "", err
}
if u.Scheme != "http" && u.Scheme != "https" {
return "", fmt.Errorf("invalid redirect_base: only http and https are allowed")
if !ok {
return "", fmt.Errorf("redirect_base does not match this UI origin")
}
if u.Hostname() == "" {
return "", fmt.Errorf("invalid redirect_base: host is required")
}
if u.RawQuery != "" || u.Fragment != "" {
return "", fmt.Errorf("invalid redirect_base: query and fragment are not allowed")
}
origin := &url.URL{Scheme: u.Scheme, Host: u.Host}
if err := redirectBaseMatchesRequest(r, origin); err != nil {
rs, rh := requestSchemeAndHost(r)
canon := normalizeOrigin(rs, rh)
scheme, host, err := splitCanonicalOriginURL(canon)
if err != nil {
return "", err
}
return oauthCallbackFromOrigin(origin)
return oauthCallbackURL(scheme, host)
}

// ResolveLogoutRedirectBase returns the UI base URL for OIDC post_logout_redirect_uri (no trailing slash).
Expand All @@ -82,6 +82,24 @@ func ResolveLogoutRedirectBase(r *http.Request, redirectBase string) (string, er
return strings.TrimSuffix(callbackURI, "/callback"), nil
}

// splitCanonicalOriginURL splits a string produced only by normalizeOrigin ("scheme://host[:port]").
func splitCanonicalOriginURL(canon string) (scheme, host string, err error) {
const sep = "://"
i := strings.Index(canon, sep)
if i < 0 {
return "", "", fmt.Errorf("invalid canonical origin")
}
scheme = canon[:i]
host = canon[i+len(sep):]
if scheme != "http" && scheme != "https" {
return "", "", fmt.Errorf("invalid canonical origin scheme")
}
if host == "" {
return "", "", fmt.Errorf("invalid canonical origin host")
}
return scheme, host, nil
}

func requestSchemeAndHost(r *http.Request) (scheme, host string) {
scheme = "http"
if r.TLS != nil {
Expand All @@ -100,6 +118,37 @@ func requestSchemeAndHost(r *http.Request) (scheme, host string) {
return scheme, host
}

// isSameSchemeAndHost parses redirectBase as an http(s) URL, validates it, and reports whether
// its origin matches the effective request (see requestSchemeAndHost). Empty redirectBase
// returns (false, nil). Invalid redirect_base returns (_, err).
func isSameSchemeAndHost(redirectBase string, r *http.Request) (bool, error) {
s := strings.TrimSpace(redirectBase)
if s == "" {
return false, nil
}
u, err := url.Parse(s)
if err != nil {
return false, fmt.Errorf("invalid redirect_base")
}
if u.Scheme != "http" && u.Scheme != "https" {
return false, fmt.Errorf("invalid redirect_base: only http and https are allowed")
}
if u.Hostname() == "" {
return false, fmt.Errorf("invalid redirect_base: host is required")
}
if u.RawQuery != "" || u.Fragment != "" {
return false, fmt.Errorf("invalid redirect_base: query and fragment are not allowed")
}
if u.User != nil {
return false, fmt.Errorf("invalid redirect_base: user info is not allowed")
}
rs, rh := requestSchemeAndHost(r)
if normalizeOrigin(u.Scheme, u.Host) != normalizeOrigin(rs, rh) {
return false, nil
}
return true, nil
}

// cookieSecureForRequest is true when the Set-Cookie Secure attribute should be set: TLS is
// configured on this proxy, or the effective request scheme is HTTPS (including when TLS
// terminates at a reverse proxy and X-Forwarded-Proto is trusted).
Expand All @@ -114,16 +163,6 @@ func cookieSecureForRequest(r *http.Request) bool {
return strings.EqualFold(strings.TrimSpace(scheme), "https")
}

func redirectBaseMatchesRequest(r *http.Request, u *url.URL) error {
rs, rh := requestSchemeAndHost(r)
candidate := normalizeOrigin(u.Scheme, u.Host)
actual := normalizeOrigin(rs, rh)
if candidate != actual {
return fmt.Errorf("redirect_base does not match this UI origin")
}
return nil
}

func normalizeOrigin(scheme, host string) string {
scheme = strings.ToLower(strings.TrimSpace(scheme))
host = strings.ToLower(strings.TrimSpace(host))
Expand Down
Loading