Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@ dependencies = [

[tool.pyright]
exclude = [
# TODO(lucasagomes): This module was copied from road-core
# service/ols/src/auth/k8s.py and currently has 58 Pyright issues. It
# might need to be rewritten down the line.
"src/authentication/k8s.py",
# Agent API v1 endpoints - deprecated API but still supported
# Type errors due to llama-stack-client not exposing Agent API types
"src/app/endpoints/conversations.py",
Expand Down
77 changes: 54 additions & 23 deletions src/authentication/k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
from pathlib import Path
from typing import Optional, Self
from typing import Optional, Self, cast

import kubernetes.client
from fastapi import HTTPException, Request
Expand Down Expand Up @@ -80,19 +80,19 @@ def __new__(cls: type[Self]) -> Self:
ce,
)

k8s_config.host = (
configuration.authentication_configuration.k8s_cluster_api
or k8s_config.host
)
k8s_api_url = configuration.authentication_configuration.k8s_cluster_api
if k8s_api_url:
k8s_config.host = str(k8s_api_url)
k8s_config.verify_ssl = (
not configuration.authentication_configuration.skip_tls_verification
)
k8s_config.ssl_ca_cert = (
ca_cert_path = (
configuration.authentication_configuration.k8s_ca_cert_path
if configuration.authentication_configuration.k8s_ca_cert_path
not in {None, Path()}
else k8s_config.ssl_ca_cert
)
if ca_cert_path and ca_cert_path != Path():
# Kubernetes client library has incomplete type stubs for ssl_ca_cert
k8s_config.ssl_ca_cert = str(ca_cert_path) # type: ignore[assignment]
# else keep the default k8s_config.ssl_ca_cert
api_client = kubernetes.client.ApiClient(k8s_config)
cls._api_client = api_client
cls._custom_objects_api = kubernetes.client.CustomObjectsApi(api_client)
Expand All @@ -101,7 +101,8 @@ def __new__(cls: type[Self]) -> Self:
except Exception as e:
logger.info("Failed to initialize Kubernetes client: %s", e)
raise
return cls._instance
# At this point _instance is guaranteed to be initialized
return cast(Self, cls._instance)

@classmethod
def get_authn_api(cls) -> kubernetes.client.AuthenticationV1Api:
Expand Down Expand Up @@ -159,12 +160,28 @@ def _get_cluster_id(cls) -> str:
ClusterIDUnavailableError: If the cluster ID cannot be obtained due
to missing keys, an API error, or any unexpected error.
"""
version_data = None
try:
custom_objects_api = cls.get_custom_objects_api()
version_data = custom_objects_api.get_cluster_custom_object(
"config.openshift.io", "v1", "clusterversions", "version"
)
cluster_id = version_data["spec"]["clusterID"]
# Type validation: ensure we got a dict-like object
if not isinstance(version_data, dict):
raise TypeError(
f"Expected dict for version_data, got {type(version_data)}"
)
# Kubernetes client library returns untyped dict-like objects
spec = version_data.get("spec")
if not isinstance(spec, dict):
raise ClusterIDUnavailableError(
"Missing or invalid 'spec' in ClusterVersion"
)
cluster_id = spec.get("clusterID")
if not isinstance(cluster_id, str) or not cluster_id.strip():
raise ClusterIDUnavailableError(
"Missing or invalid 'clusterID' in ClusterVersion"
)
cls._cluster_id = cluster_id
return cluster_id
except KeyError as e:
Expand Down Expand Up @@ -212,14 +229,14 @@ def get_cluster_id(cls) -> str:
return cls._cluster_id


def get_user_info(token: str) -> Optional[kubernetes.client.V1TokenReview]:
def get_user_info(token: str) -> Optional[kubernetes.client.V1TokenReviewStatus]:
"""Perform a Kubernetes TokenReview to validate a given token.

Parameters:
token: The bearer token to be validated.

Returns:
The user information if the token is valid, None otherwise.
The V1TokenReviewStatus if the token is valid, None otherwise.

Raises:
HTTPException: If unable to connect to Kubernetes API or unexpected error occurs.
Expand All @@ -239,8 +256,10 @@ def get_user_info(token: str) -> Optional[kubernetes.client.V1TokenReview]:
)
try:
response = auth_api.create_token_review(token_review)
if response.status.authenticated:
return response.status
# Kubernetes client library has incomplete type stubs
status = response.status # type: ignore[union-attr]
if status is not None and status.authenticated:
return status
return None
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("API exception during TokenReview: %s", e)
Expand Down Expand Up @@ -307,9 +326,10 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]:
response = UnauthorizedResponse(cause="Invalid or expired Kubernetes token")
raise HTTPException(**response.model_dump())

if user_info.user.username == "kube:admin":
# Kubernetes client library has incomplete type stubs
if user_info.user.username == "kube:admin": # type: ignore[union-attr]
try:
user_info.user.uid = K8sClientSingleton.get_cluster_id()
user_info.user.uid = K8sClientSingleton.get_cluster_id() # type: ignore[union-attr]
except ClusterIDUnavailableError as e:
logger.error("Failed to get cluster ID: %s", e)
response = InternalServerErrorResponse(
Expand All @@ -318,12 +338,22 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]:
)
raise HTTPException(**response.model_dump()) from e

# Validate that uid is present and is a string
user_uid = user_info.user.uid # type: ignore[union-attr]
if not isinstance(user_uid, str) or not user_uid:
logger.error("Authenticated Kubernetes user is missing a UID")
response = InternalServerErrorResponse(
response="Internal server error",
cause="Authenticated Kubernetes user is missing a UID",
)
raise HTTPException(**response.model_dump())

try:
authorization_api = K8sClientSingleton.get_authz_api()
sar = kubernetes.client.V1SubjectAccessReview(
spec=kubernetes.client.V1SubjectAccessReviewSpec(
user=user_info.user.username,
groups=user_info.user.groups,
user=user_info.user.username, # type: ignore[union-attr]
groups=user_info.user.groups, # type: ignore[union-attr]
non_resource_attributes=kubernetes.client.V1NonResourceAttributes(
path=self.virtual_path, verb="get"
),
Expand All @@ -339,13 +369,14 @@ async def __call__(self, request: Request) -> tuple[str, str, bool, str]:
)
raise HTTPException(**response.model_dump()) from e

if not response.status.allowed:
response = ForbiddenResponse.endpoint(user_id=user_info.user.uid)
# Kubernetes client library has incomplete type stubs
if not response.status.allowed: # type: ignore[union-attr]
response = ForbiddenResponse.endpoint(user_id=user_uid)
raise HTTPException(**response.model_dump())

return (
user_info.user.uid,
user_info.user.username,
user_uid,
user_info.user.username, # type: ignore[union-attr]
self.skip_userid_check,
token,
)
Loading