diff --git a/pyproject.toml b/pyproject.toml index c3025c3fe..86ef9cdcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,10 +74,6 @@ dependencies = [ [tool.pyright] exclude = [ - # TODO(lucasagomes): This module was copied from road-core - # service/ols/src/auth/k8s.py and currently has 58 Pyright issues. It - # might need to be rewritten down the line. - "src/authentication/k8s.py", # Agent API v1 endpoints - deprecated API but still supported # Type errors due to llama-stack-client not exposing Agent API types "src/app/endpoints/conversations.py", diff --git a/src/authentication/k8s.py b/src/authentication/k8s.py index 399a3d40c..8792c8a26 100644 --- a/src/authentication/k8s.py +++ b/src/authentication/k8s.py @@ -2,7 +2,7 @@ import os from pathlib import Path -from typing import Optional, Self +from typing import Optional, Self, cast import kubernetes.client from fastapi import HTTPException, Request @@ -80,19 +80,19 @@ def __new__(cls: type[Self]) -> Self: ce, ) - k8s_config.host = ( - configuration.authentication_configuration.k8s_cluster_api - or k8s_config.host - ) + k8s_api_url = configuration.authentication_configuration.k8s_cluster_api + if k8s_api_url: + k8s_config.host = str(k8s_api_url) k8s_config.verify_ssl = ( not configuration.authentication_configuration.skip_tls_verification ) - k8s_config.ssl_ca_cert = ( + ca_cert_path = ( configuration.authentication_configuration.k8s_ca_cert_path - if configuration.authentication_configuration.k8s_ca_cert_path - not in {None, Path()} - else k8s_config.ssl_ca_cert ) + if ca_cert_path and ca_cert_path != Path(): + # Kubernetes client library has incomplete type stubs for ssl_ca_cert + k8s_config.ssl_ca_cert = str(ca_cert_path) # type: ignore[assignment] + # else keep the default k8s_config.ssl_ca_cert api_client = kubernetes.client.ApiClient(k8s_config) cls._api_client = api_client cls._custom_objects_api = kubernetes.client.CustomObjectsApi(api_client) @@ -101,7 +101,8 @@ def __new__(cls: type[Self]) -> Self: except Exception as e: logger.info("Failed to initialize Kubernetes client: %s", e) raise - return cls._instance + # At this point _instance is guaranteed to be initialized + return cast(Self, cls._instance) @classmethod def get_authn_api(cls) -> kubernetes.client.AuthenticationV1Api: @@ -159,12 +160,28 @@ def _get_cluster_id(cls) -> str: ClusterIDUnavailableError: If the cluster ID cannot be obtained due to missing keys, an API error, or any unexpected error. """ + version_data = None try: custom_objects_api = cls.get_custom_objects_api() version_data = custom_objects_api.get_cluster_custom_object( "config.openshift.io", "v1", "clusterversions", "version" ) - cluster_id = version_data["spec"]["clusterID"] + # Type validation: ensure we got a dict-like object + if not isinstance(version_data, dict): + raise TypeError( + f"Expected dict for version_data, got {type(version_data)}" + ) + # Kubernetes client library returns untyped dict-like objects + spec = version_data.get("spec") + if not isinstance(spec, dict): + raise ClusterIDUnavailableError( + "Missing or invalid 'spec' in ClusterVersion" + ) + cluster_id = spec.get("clusterID") + if not isinstance(cluster_id, str) or not cluster_id.strip(): + raise ClusterIDUnavailableError( + "Missing or invalid 'clusterID' in ClusterVersion" + ) cls._cluster_id = cluster_id return cluster_id except KeyError as e: @@ -212,14 +229,14 @@ def get_cluster_id(cls) -> str: return cls._cluster_id -def get_user_info(token: str) -> Optional[kubernetes.client.V1TokenReview]: +def get_user_info(token: str) -> Optional[kubernetes.client.V1TokenReviewStatus]: """Perform a Kubernetes TokenReview to validate a given token. Parameters: token: The bearer token to be validated. Returns: - The user information if the token is valid, None otherwise. + The V1TokenReviewStatus if the token is valid, None otherwise. Raises: HTTPException: If unable to connect to Kubernetes API or unexpected error occurs. @@ -239,8 +256,10 @@ def get_user_info(token: str) -> Optional[kubernetes.client.V1TokenReview]: ) try: response = auth_api.create_token_review(token_review) - if response.status.authenticated: - return response.status + # Kubernetes client library has incomplete type stubs + status = response.status # type: ignore[union-attr] + if status is not None and status.authenticated: + return status return None except Exception as e: # pylint: disable=broad-exception-caught logger.error("API exception during TokenReview: %s", e) @@ -307,9 +326,10 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]: response = UnauthorizedResponse(cause="Invalid or expired Kubernetes token") raise HTTPException(**response.model_dump()) - if user_info.user.username == "kube:admin": + # Kubernetes client library has incomplete type stubs + if user_info.user.username == "kube:admin": # type: ignore[union-attr] try: - user_info.user.uid = K8sClientSingleton.get_cluster_id() + user_info.user.uid = K8sClientSingleton.get_cluster_id() # type: ignore[union-attr] except ClusterIDUnavailableError as e: logger.error("Failed to get cluster ID: %s", e) response = InternalServerErrorResponse( @@ -318,12 +338,22 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]: ) raise HTTPException(**response.model_dump()) from e + # Validate that uid is present and is a string + user_uid = user_info.user.uid # type: ignore[union-attr] + if not isinstance(user_uid, str) or not user_uid: + logger.error("Authenticated Kubernetes user is missing a UID") + response = InternalServerErrorResponse( + response="Internal server error", + cause="Authenticated Kubernetes user is missing a UID", + ) + raise HTTPException(**response.model_dump()) + try: authorization_api = K8sClientSingleton.get_authz_api() sar = kubernetes.client.V1SubjectAccessReview( spec=kubernetes.client.V1SubjectAccessReviewSpec( - user=user_info.user.username, - groups=user_info.user.groups, + user=user_info.user.username, # type: ignore[union-attr] + groups=user_info.user.groups, # type: ignore[union-attr] non_resource_attributes=kubernetes.client.V1NonResourceAttributes( path=self.virtual_path, verb="get" ), @@ -339,13 +369,14 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]: ) raise HTTPException(**response.model_dump()) from e - if not response.status.allowed: - response = ForbiddenResponse.endpoint(user_id=user_info.user.uid) + # Kubernetes client library has incomplete type stubs + if not response.status.allowed: # type: ignore[union-attr] + response = ForbiddenResponse.endpoint(user_id=user_uid) raise HTTPException(**response.model_dump()) return ( - user_info.user.uid, - user_info.user.username, + user_uid, + user_info.user.username, # type: ignore[union-attr] self.skip_userid_check, token, )