diff --git a/proxy/auth/redirect.go b/proxy/auth/redirect.go index df6bee799..4c765ded7 100644 --- a/proxy/auth/redirect.go +++ b/proxy/auth/redirect.go @@ -13,10 +13,18 @@ import ( const oauthRedirectURICookiePrefix = "oauth_redirect_uri_" -// oauthCallbackFromOrigin builds the OAuth redirect_uri (…/callback) for the given UI origin -// (scheme + host, no path). The path prefix comes from config.BaseUiUrl so deployments served -// from a subpath (e.g. https://host/ui) resolve to …/ui/callback instead of …/callback. -func oauthCallbackFromOrigin(origin *url.URL) (string, error) { +// oauthCallbackURL builds the OAuth redirect_uri (…/callback) for the given UI origin scheme and host. +// The path prefix comes from config.BaseUiUrl so deployments served from a subpath (e.g. https://host/ui) +// resolve to …/ui/callback instead of …/callback. +func oauthCallbackURL(scheme, host string) (string, error) { + scheme = strings.ToLower(strings.TrimSpace(scheme)) + if scheme != "http" && scheme != "https" { + return "", fmt.Errorf("invalid origin scheme") + } + host = strings.TrimSpace(host) + if host == "" || strings.ContainsAny(host, "@\r\n\t") { + return "", fmt.Errorf("invalid origin host") + } baseUI, err := url.Parse(config.BaseUiUrl) if err != nil { return "", fmt.Errorf("invalid BASE_UI_URL configuration: %w", err) @@ -26,12 +34,9 @@ func oauthCallbackFromOrigin(origin *url.URL) (string, error) { if basePath != "" { callbackPath = basePath + "/callback" } - out := &url.URL{ - Scheme: origin.Scheme, - Host: origin.Host, - Path: callbackPath, - } - return out.String(), nil + + pathURL := &url.URL{Path: callbackPath} + return scheme + "://" + host + pathURL.EscapedPath(), nil } // ResolveOAuthRedirectURI returns the OAuth redirect_uri (callback URL) for this login attempt. @@ -47,27 +52,22 @@ func ResolveOAuthRedirectURI(r *http.Request, redirectBase string) (string, erro if err != nil { return "", fmt.Errorf("invalid BASE_UI_URL configuration: %w", err) } - origin := &url.URL{Scheme: baseUI.Scheme, Host: baseUI.Host} - return oauthCallbackFromOrigin(origin) + return oauthCallbackURL(baseUI.Scheme, baseUI.Host) } - u, err := url.Parse(strings.TrimSpace(redirectBase)) + ok, err := isSameSchemeAndHost(redirectBase, r) if err != nil { - return "", fmt.Errorf("invalid redirect_base") + return "", err } - if u.Scheme != "http" && u.Scheme != "https" { - return "", fmt.Errorf("invalid redirect_base: only http and https are allowed") + if !ok { + return "", fmt.Errorf("redirect_base does not match this UI origin") } - if u.Hostname() == "" { - return "", fmt.Errorf("invalid redirect_base: host is required") - } - if u.RawQuery != "" || u.Fragment != "" { - return "", fmt.Errorf("invalid redirect_base: query and fragment are not allowed") - } - origin := &url.URL{Scheme: u.Scheme, Host: u.Host} - if err := redirectBaseMatchesRequest(r, origin); err != nil { + rs, rh := requestSchemeAndHost(r) + canon := normalizeOrigin(rs, rh) + scheme, host, err := splitCanonicalOriginURL(canon) + if err != nil { return "", err } - return oauthCallbackFromOrigin(origin) + return oauthCallbackURL(scheme, host) } // ResolveLogoutRedirectBase returns the UI base URL for OIDC post_logout_redirect_uri (no trailing slash). @@ -82,6 +82,24 @@ func ResolveLogoutRedirectBase(r *http.Request, redirectBase string) (string, er return strings.TrimSuffix(callbackURI, "/callback"), nil } +// splitCanonicalOriginURL splits a string produced only by normalizeOrigin ("scheme://host[:port]"). +func splitCanonicalOriginURL(canon string) (scheme, host string, err error) { + const sep = "://" + i := strings.Index(canon, sep) + if i < 0 { + return "", "", fmt.Errorf("invalid canonical origin") + } + scheme = canon[:i] + host = canon[i+len(sep):] + if scheme != "http" && scheme != "https" { + return "", "", fmt.Errorf("invalid canonical origin scheme") + } + if host == "" { + return "", "", fmt.Errorf("invalid canonical origin host") + } + return scheme, host, nil +} + func requestSchemeAndHost(r *http.Request) (scheme, host string) { scheme = "http" if r.TLS != nil { @@ -100,6 +118,37 @@ func requestSchemeAndHost(r *http.Request) (scheme, host string) { return scheme, host } +// isSameSchemeAndHost parses redirectBase as an http(s) URL, validates it, and reports whether +// its origin matches the effective request (see requestSchemeAndHost). Empty redirectBase +// returns (false, nil). Invalid redirect_base returns (_, err). +func isSameSchemeAndHost(redirectBase string, r *http.Request) (bool, error) { + s := strings.TrimSpace(redirectBase) + if s == "" { + return false, nil + } + u, err := url.Parse(s) + if err != nil { + return false, fmt.Errorf("invalid redirect_base") + } + if u.Scheme != "http" && u.Scheme != "https" { + return false, fmt.Errorf("invalid redirect_base: only http and https are allowed") + } + if u.Hostname() == "" { + return false, fmt.Errorf("invalid redirect_base: host is required") + } + if u.RawQuery != "" || u.Fragment != "" { + return false, fmt.Errorf("invalid redirect_base: query and fragment are not allowed") + } + if u.User != nil { + return false, fmt.Errorf("invalid redirect_base: user info is not allowed") + } + rs, rh := requestSchemeAndHost(r) + if normalizeOrigin(u.Scheme, u.Host) != normalizeOrigin(rs, rh) { + return false, nil + } + return true, nil +} + // cookieSecureForRequest is true when the Set-Cookie Secure attribute should be set: TLS is // configured on this proxy, or the effective request scheme is HTTPS (including when TLS // terminates at a reverse proxy and X-Forwarded-Proto is trusted). @@ -114,16 +163,6 @@ func cookieSecureForRequest(r *http.Request) bool { return strings.EqualFold(strings.TrimSpace(scheme), "https") } -func redirectBaseMatchesRequest(r *http.Request, u *url.URL) error { - rs, rh := requestSchemeAndHost(r) - candidate := normalizeOrigin(u.Scheme, u.Host) - actual := normalizeOrigin(rs, rh) - if candidate != actual { - return fmt.Errorf("redirect_base does not match this UI origin") - } - return nil -} - func normalizeOrigin(scheme, host string) string { scheme = strings.ToLower(strings.TrimSpace(scheme)) host = strings.ToLower(strings.TrimSpace(host))