From 4b0a2373e85eb0da71d28ff1e30124b8e9687780 Mon Sep 17 00:00:00 2001 From: Val Redchenko Date: Thu, 14 May 2026 10:47:38 +0100 Subject: [PATCH] feat: validate Keycloak Bearer tokens via FastAPI dependency Add `verify_token` as a global FastAPI dependency that validates incoming `Authorization: Bearer ` headers against the configured Keycloak realm's JWKS. Verification is offline - JWKS is fetched once and cached by PyJWT's PyJWKClient; the backend never calls Keycloak per-request. The dependency is gated by `KEYCLOAK_AUTH_REQUIRED` (default false) so the existing test suite, dev workflows, and any already-running deployment behave unchanged. When enabled, every request that isn't on a small exempt list (`/health`, `/status`, `/openapi.json`, `/docs`, `/redoc`) must carry a valid token; the rest receive 401 with a `WWW-Authenticate: Bearer` header. Config knobs (all env vars): - KEYCLOAK_AUTH_REQUIRED: master switch - KEYCLOAK_URL: base URL of the Keycloak server - KEYCLOAK_REALM: realm name (default `dls`) - KEYCLOAK_CLIENT_ID: client ID (reserved for future aud check) - KEYCLOAK_VERIFY_ISS: enforce iss claim (default true) The verify-iss escape hatch exists for the dev mock, where tokens are minted with the browser-facing URL (localhost:30090) while the pod fetches JWKS via in-cluster DNS (keycloak-service:8080) - signature is sound but issuer strings don't match. Staging and production point at the real DLS realm and leave iss verification on. Adds `pyjwt[crypto]` to the `backend` extra and a focused test module covering the gating, exempt-path, malformed-token, and config-helper paths. The remaining 191 backend tests are unaffected (auth defaults to off). --- pyproject.toml | 1 + src/smartem_backend/api_server.py | 2 + src/smartem_backend/auth.py | 110 +++++++++++++++++++++++++++++ tests/smartem_backend/test_auth.py | 91 ++++++++++++++++++++++++ 4 files changed, 204 insertions(+) create mode 100644 src/smartem_backend/auth.py create mode 100644 tests/smartem_backend/test_auth.py diff --git a/pyproject.toml b/pyproject.toml index d2747029..7ea7720a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ backend = [ "sqlmodel>=0.0.24,<1.0.0", "alembic>=1.13.0,<2.0.0", "sse-starlette>=2.1.0,<4.0.0", + "pyjwt[crypto]>=2.9.0,<3.0.0", ] # Packages for cryoEM file I/O; used for image displays from API diff --git a/src/smartem_backend/api_server.py b/src/smartem_backend/api_server.py index e2488958..15c52078 100644 --- a/src/smartem_backend/api_server.py +++ b/src/smartem_backend/api_server.py @@ -24,6 +24,7 @@ from smartem_backend import mq_publisher as mq_publisher_module from smartem_backend.agent_connection_manager import get_connection_manager +from smartem_backend.auth import verify_token from smartem_backend.frontend_stream import ( query_acquisition_progress, query_agent_logs, @@ -228,6 +229,7 @@ async def lifespan(app: FastAPI): version=__version__, redoc_url=None, lifespan=lifespan, + dependencies=[Depends(verify_token)], ) # Resolve runtime config (env var overrides appconfig.yml, which overrides hard default) diff --git a/src/smartem_backend/auth.py b/src/smartem_backend/auth.py new file mode 100644 index 00000000..904c036d --- /dev/null +++ b/src/smartem_backend/auth.py @@ -0,0 +1,110 @@ +"""Keycloak JWT validation as a FastAPI dependency. + +Validates `Authorization: Bearer ` against the configured Keycloak realm's +JWKS. Verification is offline - JWKS is fetched once and cached by PyJWKClient; +no per-request call to Keycloak. + +Gated by `KEYCLOAK_AUTH_REQUIRED` (default false) so tests and existing dev +workflows that don't pass tokens are unaffected. When true, every request that +isn't on the exempt path list must carry a valid token. +""" + +import logging +import os + +import jwt +from fastapi import HTTPException, Request, status +from jwt import PyJWKClient + +logger = logging.getLogger("smartem_backend.auth") + +EXEMPT_PATHS: frozenset[str] = frozenset({"/health", "/status", "/openapi.json", "/docs", "/redoc"}) + +_jwks_client: PyJWKClient | None = None + + +def _is_auth_required() -> bool: + return os.getenv("KEYCLOAK_AUTH_REQUIRED", "false").lower() == "true" + + +def _verify_iss() -> bool: + return os.getenv("KEYCLOAK_VERIFY_ISS", "true").lower() == "true" + + +def _keycloak_url() -> str: + return os.getenv("KEYCLOAK_URL", "").rstrip("/") + + +def _realm() -> str: + return os.getenv("KEYCLOAK_REALM", "dls") + + +def _issuer() -> str: + return f"{_keycloak_url()}/realms/{_realm()}" + + +def _jwks_url() -> str: + return f"{_issuer()}/protocol/openid-connect/certs" + + +def _get_jwks_client() -> PyJWKClient: + global _jwks_client + if _jwks_client is None: + if not _keycloak_url(): + raise RuntimeError("KEYCLOAK_URL must be set when KEYCLOAK_AUTH_REQUIRED=true") + _jwks_client = PyJWKClient(_jwks_url(), cache_keys=True, lifespan=600) + logger.info("Initialised Keycloak JWKS client at %s", _jwks_url()) + return _jwks_client + + +def _unauthorized(detail: str) -> HTTPException: + return HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=detail, + headers={"WWW-Authenticate": "Bearer"}, + ) + + +def _extract_bearer(request: Request) -> str: + auth = request.headers.get("authorization", "") + if not auth.lower().startswith("bearer "): + raise _unauthorized("Missing or malformed Authorization header") + token = auth.split(" ", 1)[1].strip() + if not token: + raise _unauthorized("Empty bearer token") + return token + + +def _verify(token: str) -> dict: + try: + signing_key = _get_jwks_client().get_signing_key_from_jwt(token).key + except jwt.PyJWKClientError as e: + logger.warning("JWKS lookup failed: %s", e) + raise _unauthorized("Cannot resolve signing key") from e + except jwt.PyJWTError as e: + logger.warning("Could not parse token header: %s", e) + raise _unauthorized("Invalid token") from e + + options = {"verify_iss": _verify_iss(), "verify_aud": False} + decode_kwargs: dict = {"algorithms": ["RS256"], "options": options} + if _verify_iss(): + decode_kwargs["issuer"] = _issuer() + + try: + return jwt.decode(token, signing_key, **decode_kwargs) + except jwt.ExpiredSignatureError as e: + raise _unauthorized("Token expired") from e + except jwt.PyJWTError as e: + logger.warning("Token validation failed: %s", e) + raise _unauthorized("Invalid token") from e + + +async def verify_token(request: Request) -> dict | None: + """Route dependency. Returns claims when validation succeeded, otherwise None + (auth disabled or exempt path). Raises 401 on validation failure. + """ + if not _is_auth_required(): + return None + if request.url.path in EXEMPT_PATHS: + return None + return _verify(_extract_bearer(request)) diff --git a/tests/smartem_backend/test_auth.py b/tests/smartem_backend/test_auth.py new file mode 100644 index 00000000..cc3eda71 --- /dev/null +++ b/tests/smartem_backend/test_auth.py @@ -0,0 +1,91 @@ +"""Tests for the Keycloak JWT verification dependency. + +These cover the gating and routing behaviour. End-to-end signature verification +against a live Keycloak is exercised in staging - constructing a fake JWKS here +adds machinery without adding confidence in the small validate-and-decode call. +""" + +import os + +import pytest + +from .conftest import app, client # noqa: F401 - re-export fixture + + +@pytest.fixture(autouse=True) +def reset_auth_env(monkeypatch): + """Ensure each test starts with KEYCLOAK_AUTH_REQUIRED unset (effectively false).""" + monkeypatch.delenv("KEYCLOAK_AUTH_REQUIRED", raising=False) + monkeypatch.delenv("KEYCLOAK_URL", raising=False) + + +class TestAuthDisabledByDefault: + def test_unauthenticated_status_call_succeeds(self, client): + # Disabled-by-default behaviour is implicitly covered by the rest of the + # backend test suite, which all runs without an Authorization header and + # expects normal responses. This case just pins the contract explicitly. + response = client.get("/status") + assert response.status_code == 200 + + +class TestAuthRequired: + @pytest.fixture(autouse=True) + def enable_auth(self, monkeypatch): + monkeypatch.setenv("KEYCLOAK_AUTH_REQUIRED", "true") + monkeypatch.setenv("KEYCLOAK_URL", "http://keycloak-service:8080") + # Reset cached JWKS client so it picks up env on next request. + import smartem_backend.auth as auth_module + monkeypatch.setattr(auth_module, "_jwks_client", None) + + def test_exempt_path_no_token_succeeds(self, client): + assert client.get("/status").status_code == 200 + assert client.get("/health").status_code in (200, 503) + assert client.get("/openapi.json").status_code == 200 + + def test_protected_endpoint_without_token_returns_401(self, client): + response = client.get("/acquisitions") + assert response.status_code == 401 + assert response.headers.get("www-authenticate") == "Bearer" + assert "Authorization" in response.json()["detail"] + + def test_protected_endpoint_with_malformed_token_returns_401(self, client): + response = client.get( + "/acquisitions", + headers={"Authorization": "Bearer not-a-real-jwt"}, + ) + assert response.status_code == 401 + assert response.headers.get("www-authenticate") == "Bearer" + + def test_non_bearer_scheme_returns_401(self, client): + response = client.get( + "/acquisitions", + headers={"Authorization": "Basic dXNlcjpwYXNz"}, + ) + assert response.status_code == 401 + + def test_empty_bearer_returns_401(self, client): + response = client.get("/acquisitions", headers={"Authorization": "Bearer "}) + assert response.status_code == 401 + + +class TestConfigHelpers: + def test_issuer_from_env(self, monkeypatch): + monkeypatch.setenv("KEYCLOAK_URL", "https://identity-test.diamond.ac.uk/") + monkeypatch.setenv("KEYCLOAK_REALM", "dls") + from smartem_backend import auth + + assert auth._issuer() == "https://identity-test.diamond.ac.uk/realms/dls" + assert auth._jwks_url() == ( + "https://identity-test.diamond.ac.uk/realms/dls/protocol/openid-connect/certs" + ) + + def test_jwks_client_requires_url(self, monkeypatch): + monkeypatch.delenv("KEYCLOAK_URL", raising=False) + from smartem_backend import auth + + monkeypatch.setattr(auth, "_jwks_client", None) + with pytest.raises(RuntimeError, match="KEYCLOAK_URL"): + auth._get_jwks_client() + + +_ = os # silence unused-import lint when env helpers aren't touched