Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from user_management_module.user_container import UserContainer
from user_management_module.services.user_service import UserService
from user_management_module.utils.password_utils import verify_password
from user_management_module.utils.user_utils import get_session_cache_key
from user_management_module.utils.user_utils import get_session_cache_key, validate_redirect_url

from authenticator import AuthenticatorType
from authenticator.helper import validate_email
Expand Down Expand Up @@ -355,13 +355,13 @@ async def _handle_oauth_callback(
auth_uuid, authenticator_repository
)

# Helper to get failure URL from config
# Helper to get failure URL from config with validation
def get_failure_redirect(error_msg: str) -> RedirectResponse:
if config_data:
failure_url = config_data.get('config', {}).get(
'client_redirect_failure_url'
)
if failure_url:
if failure_url and validate_redirect_url(failure_url):
provider = config_data.get('auth_type')
params = urlencode({'provider': provider, 'error': error_msg})
return RedirectResponse(url=f'{failure_url}?{params}')
Expand All @@ -375,10 +375,14 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse:
if authenticator is None:
return get_failure_redirect(f'Authenticator {auth_id} is not enabled')

# Extract redirect URLs
# Extract and validate redirect URLs
provider = config_data.get('auth_type')
success_url = config_data.get('config', {}).get('client_redirect_success_url')
failure_url = config_data.get('config', {}).get('client_redirect_failure_url')
success_url_raw = config_data.get('config', {}).get('client_redirect_success_url')
failure_url_raw = config_data.get('config', {}).get('client_redirect_failure_url')

# Validate URLs to prevent open redirect vulnerabilities
success_url = success_url_raw if validate_redirect_url(success_url_raw) else None
failure_url = failure_url_raw if validate_redirect_url(failure_url_raw) else None

# Handle OAuth error from provider
if callback_data.get('error'):
Expand Down Expand Up @@ -484,7 +488,8 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse:
failure_url = config_data.get('config', {}).get(
'client_redirect_failure_url'
)
if failure_url:
# Validate URL to prevent open redirect vulnerabilities
if failure_url and validate_redirect_url(failure_url):
provider = config_data.get('auth_type')
params = urlencode(
{
Expand All @@ -493,7 +498,7 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse:
}
)
return RedirectResponse(url=f'{failure_url}?{params}')
except Exception as e:
except Exception:
pass

return RedirectResponse(url='about:blank')
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime
from typing import Optional
from typing import Optional, List
from urllib.parse import urlparse

from common_module.response_formatter import ResponseFormatter
import uuid
Expand Down Expand Up @@ -80,3 +81,62 @@ def create_account_lockout_response(

def get_session_cache_key(session_id: Union[str, uuid.UUID]) -> str:
return f'session_{str(session_id)}'


def validate_redirect_url(
url: Optional[str], allowed_domains: Optional[List[str]] = None
) -> bool:
"""
Validate a redirect URL to prevent open redirect vulnerabilities.

Args:
url: The URL to validate
allowed_domains: Optional list of allowed domains. If provided, the URL's
domain must match one of these domains.

Returns:
True if the URL is safe to redirect to, False otherwise
"""
if not url:
return False

try:
parsed = urlparse(url)

# Must have a valid scheme (http or https only)
if parsed.scheme not in ('http', 'https'):
return False

# Must have a valid netloc (host)
if not parsed.netloc:
return False

# Prevent protocol-relative URL bypass (e.g., //evil.com)
if url.startswith('//'):
return False

# Prevent backslash-based bypasses (e.g., /\evil.com)
if '\\' in url:
return False

# Prevent URL with credentials (e.g., https://attacker.com@legitimate.com)
if parsed.username or parsed.password:
return False

# If allowed_domains is specified, validate against it
if allowed_domains:
# Extract the domain (without port)
domain = parsed.netloc.split(':')[0].lower()
allowed_lower = [d.lower() for d in allowed_domains]

# Check exact match or subdomain match
if not any(
domain == allowed or domain.endswith('.' + allowed)
for allowed in allowed_lower
):
return False

return True

except Exception:
return False
Loading