diff --git a/wavefront/server/modules/user_management_module/user_management_module/controllers/auth_plugin_controller.py b/wavefront/server/modules/user_management_module/user_management_module/controllers/auth_plugin_controller.py index f634ff00..3de9a6f3 100644 --- a/wavefront/server/modules/user_management_module/user_management_module/controllers/auth_plugin_controller.py +++ b/wavefront/server/modules/user_management_module/user_management_module/controllers/auth_plugin_controller.py @@ -28,7 +28,7 @@ from user_management_module.user_container import UserContainer from user_management_module.services.user_service import UserService from user_management_module.utils.password_utils import verify_password -from user_management_module.utils.user_utils import get_session_cache_key +from user_management_module.utils.user_utils import get_session_cache_key, validate_redirect_url from authenticator import AuthenticatorType from authenticator.helper import validate_email @@ -355,13 +355,13 @@ async def _handle_oauth_callback( auth_uuid, authenticator_repository ) - # Helper to get failure URL from config + # Helper to get failure URL from config with validation def get_failure_redirect(error_msg: str) -> RedirectResponse: if config_data: failure_url = config_data.get('config', {}).get( 'client_redirect_failure_url' ) - if failure_url: + if failure_url and validate_redirect_url(failure_url): provider = config_data.get('auth_type') params = urlencode({'provider': provider, 'error': error_msg}) return RedirectResponse(url=f'{failure_url}?{params}') @@ -375,10 +375,14 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: if authenticator is None: return get_failure_redirect(f'Authenticator {auth_id} is not enabled') - # Extract redirect URLs + # Extract and validate redirect URLs provider = config_data.get('auth_type') - success_url = config_data.get('config', {}).get('client_redirect_success_url') - failure_url = config_data.get('config', {}).get('client_redirect_failure_url') + success_url_raw = config_data.get('config', {}).get('client_redirect_success_url') + failure_url_raw = config_data.get('config', {}).get('client_redirect_failure_url') + + # Validate URLs to prevent open redirect vulnerabilities + success_url = success_url_raw if validate_redirect_url(success_url_raw) else None + failure_url = failure_url_raw if validate_redirect_url(failure_url_raw) else None # Handle OAuth error from provider if callback_data.get('error'): @@ -484,7 +488,8 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: failure_url = config_data.get('config', {}).get( 'client_redirect_failure_url' ) - if failure_url: + # Validate URL to prevent open redirect vulnerabilities + if failure_url and validate_redirect_url(failure_url): provider = config_data.get('auth_type') params = urlencode( { @@ -493,7 +498,7 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: } ) return RedirectResponse(url=f'{failure_url}?{params}') - except Exception as e: + except Exception: pass return RedirectResponse(url='about:blank') diff --git a/wavefront/server/modules/user_management_module/user_management_module/utils/user_utils.py b/wavefront/server/modules/user_management_module/user_management_module/utils/user_utils.py index cd93562d..8ed771ba 100644 --- a/wavefront/server/modules/user_management_module/user_management_module/utils/user_utils.py +++ b/wavefront/server/modules/user_management_module/user_management_module/utils/user_utils.py @@ -1,5 +1,6 @@ from datetime import datetime -from typing import Optional +from typing import Optional, List +from urllib.parse import urlparse from common_module.response_formatter import ResponseFormatter import uuid @@ -80,3 +81,62 @@ def create_account_lockout_response( def get_session_cache_key(session_id: Union[str, uuid.UUID]) -> str: return f'session_{str(session_id)}' + + +def validate_redirect_url( + url: Optional[str], allowed_domains: Optional[List[str]] = None +) -> bool: + """ + Validate a redirect URL to prevent open redirect vulnerabilities. + + Args: + url: The URL to validate + allowed_domains: Optional list of allowed domains. If provided, the URL's + domain must match one of these domains. + + Returns: + True if the URL is safe to redirect to, False otherwise + """ + if not url: + return False + + try: + parsed = urlparse(url) + + # Must have a valid scheme (http or https only) + if parsed.scheme not in ('http', 'https'): + return False + + # Must have a valid netloc (host) + if not parsed.netloc: + return False + + # Prevent protocol-relative URL bypass (e.g., //evil.com) + if url.startswith('//'): + return False + + # Prevent backslash-based bypasses (e.g., /\evil.com) + if '\\' in url: + return False + + # Prevent URL with credentials (e.g., https://attacker.com@legitimate.com) + if parsed.username or parsed.password: + return False + + # If allowed_domains is specified, validate against it + if allowed_domains: + # Extract the domain (without port) + domain = parsed.netloc.split(':')[0].lower() + allowed_lower = [d.lower() for d in allowed_domains] + + # Check exact match or subdomain match + if not any( + domain == allowed or domain.endswith('.' + allowed) + for allowed in allowed_lower + ): + return False + + return True + + except Exception: + return False