Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ backend = [
"sqlmodel>=0.0.24,<1.0.0",
"alembic>=1.13.0,<2.0.0",
"sse-starlette>=2.1.0,<4.0.0",
"pyjwt[crypto]>=2.9.0,<3.0.0",
]

# Packages for cryoEM file I/O; used for image displays from API
Expand Down
2 changes: 2 additions & 0 deletions src/smartem_backend/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from smartem_backend import mq_publisher as mq_publisher_module
from smartem_backend.agent_connection_manager import get_connection_manager
from smartem_backend.auth import verify_token
from smartem_backend.frontend_stream import (
query_acquisition_progress,
query_agent_logs,
Expand Down Expand Up @@ -228,6 +229,7 @@ async def lifespan(app: FastAPI):
version=__version__,
redoc_url=None,
lifespan=lifespan,
dependencies=[Depends(verify_token)],
)

# Resolve runtime config (env var overrides appconfig.yml, which overrides hard default)
Expand Down
110 changes: 110 additions & 0 deletions src/smartem_backend/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Keycloak JWT validation as a FastAPI dependency.

Validates `Authorization: Bearer <jwt>` against the configured Keycloak realm's
JWKS. Verification is offline - JWKS is fetched once and cached by PyJWKClient;
no per-request call to Keycloak.

Gated by `KEYCLOAK_AUTH_REQUIRED` (default false) so tests and existing dev
workflows that don't pass tokens are unaffected. When true, every request that
isn't on the exempt path list must carry a valid token.
"""

import logging
import os

import jwt
from fastapi import HTTPException, Request, status
from jwt import PyJWKClient

logger = logging.getLogger("smartem_backend.auth")

EXEMPT_PATHS: frozenset[str] = frozenset({"/health", "/status", "/openapi.json", "/docs", "/redoc"})

_jwks_client: PyJWKClient | None = None


def _is_auth_required() -> bool:
return os.getenv("KEYCLOAK_AUTH_REQUIRED", "false").lower() == "true"


def _verify_iss() -> bool:
return os.getenv("KEYCLOAK_VERIFY_ISS", "true").lower() == "true"


def _keycloak_url() -> str:
return os.getenv("KEYCLOAK_URL", "").rstrip("/")


def _realm() -> str:
return os.getenv("KEYCLOAK_REALM", "dls")


def _issuer() -> str:
return f"{_keycloak_url()}/realms/{_realm()}"


def _jwks_url() -> str:
return f"{_issuer()}/protocol/openid-connect/certs"


def _get_jwks_client() -> PyJWKClient:
global _jwks_client
if _jwks_client is None:
if not _keycloak_url():
raise RuntimeError("KEYCLOAK_URL must be set when KEYCLOAK_AUTH_REQUIRED=true")
_jwks_client = PyJWKClient(_jwks_url(), cache_keys=True, lifespan=600)
logger.info("Initialised Keycloak JWKS client at %s", _jwks_url())
return _jwks_client


def _unauthorized(detail: str) -> HTTPException:
return HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
headers={"WWW-Authenticate": "Bearer"},
)


def _extract_bearer(request: Request) -> str:
auth = request.headers.get("authorization", "")
if not auth.lower().startswith("bearer "):
raise _unauthorized("Missing or malformed Authorization header")
token = auth.split(" ", 1)[1].strip()
if not token:
raise _unauthorized("Empty bearer token")
return token


def _verify(token: str) -> dict:
try:
signing_key = _get_jwks_client().get_signing_key_from_jwt(token).key
except jwt.PyJWKClientError as e:
logger.warning("JWKS lookup failed: %s", e)
raise _unauthorized("Cannot resolve signing key") from e
except jwt.PyJWTError as e:
logger.warning("Could not parse token header: %s", e)
raise _unauthorized("Invalid token") from e

options = {"verify_iss": _verify_iss(), "verify_aud": False}
decode_kwargs: dict = {"algorithms": ["RS256"], "options": options}
if _verify_iss():
decode_kwargs["issuer"] = _issuer()

try:
return jwt.decode(token, signing_key, **decode_kwargs)
except jwt.ExpiredSignatureError as e:
raise _unauthorized("Token expired") from e
except jwt.PyJWTError as e:
logger.warning("Token validation failed: %s", e)
raise _unauthorized("Invalid token") from e


async def verify_token(request: Request) -> dict | None:
"""Route dependency. Returns claims when validation succeeded, otherwise None
(auth disabled or exempt path). Raises 401 on validation failure.
"""
if not _is_auth_required():
return None
if request.url.path in EXEMPT_PATHS:
return None
return _verify(_extract_bearer(request))
91 changes: 91 additions & 0 deletions tests/smartem_backend/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Tests for the Keycloak JWT verification dependency.

These cover the gating and routing behaviour. End-to-end signature verification
against a live Keycloak is exercised in staging - constructing a fake JWKS here
adds machinery without adding confidence in the small validate-and-decode call.
"""

import os

import pytest

from .conftest import app, client # noqa: F401 - re-export fixture


@pytest.fixture(autouse=True)
def reset_auth_env(monkeypatch):
"""Ensure each test starts with KEYCLOAK_AUTH_REQUIRED unset (effectively false)."""
monkeypatch.delenv("KEYCLOAK_AUTH_REQUIRED", raising=False)
monkeypatch.delenv("KEYCLOAK_URL", raising=False)


class TestAuthDisabledByDefault:
def test_unauthenticated_status_call_succeeds(self, client):
# Disabled-by-default behaviour is implicitly covered by the rest of the
# backend test suite, which all runs without an Authorization header and
# expects normal responses. This case just pins the contract explicitly.
response = client.get("/status")
assert response.status_code == 200


class TestAuthRequired:
@pytest.fixture(autouse=True)
def enable_auth(self, monkeypatch):
monkeypatch.setenv("KEYCLOAK_AUTH_REQUIRED", "true")
monkeypatch.setenv("KEYCLOAK_URL", "http://keycloak-service:8080")
# Reset cached JWKS client so it picks up env on next request.
import smartem_backend.auth as auth_module
monkeypatch.setattr(auth_module, "_jwks_client", None)

def test_exempt_path_no_token_succeeds(self, client):
assert client.get("/status").status_code == 200
assert client.get("/health").status_code in (200, 503)
assert client.get("/openapi.json").status_code == 200

def test_protected_endpoint_without_token_returns_401(self, client):
response = client.get("/acquisitions")
assert response.status_code == 401
assert response.headers.get("www-authenticate") == "Bearer"
assert "Authorization" in response.json()["detail"]

def test_protected_endpoint_with_malformed_token_returns_401(self, client):
response = client.get(
"/acquisitions",
headers={"Authorization": "Bearer not-a-real-jwt"},
)
assert response.status_code == 401
assert response.headers.get("www-authenticate") == "Bearer"

def test_non_bearer_scheme_returns_401(self, client):
response = client.get(
"/acquisitions",
headers={"Authorization": "Basic dXNlcjpwYXNz"},
)
assert response.status_code == 401

def test_empty_bearer_returns_401(self, client):
response = client.get("/acquisitions", headers={"Authorization": "Bearer "})
assert response.status_code == 401


class TestConfigHelpers:
def test_issuer_from_env(self, monkeypatch):
monkeypatch.setenv("KEYCLOAK_URL", "https://identity-test.diamond.ac.uk/")
monkeypatch.setenv("KEYCLOAK_REALM", "dls")
from smartem_backend import auth

assert auth._issuer() == "https://identity-test.diamond.ac.uk/realms/dls"
assert auth._jwks_url() == (
"https://identity-test.diamond.ac.uk/realms/dls/protocol/openid-connect/certs"
)

def test_jwks_client_requires_url(self, monkeypatch):
monkeypatch.delenv("KEYCLOAK_URL", raising=False)
from smartem_backend import auth

monkeypatch.setattr(auth, "_jwks_client", None)
with pytest.raises(RuntimeError, match="KEYCLOAK_URL"):
auth._get_jwks_client()


_ = os # silence unused-import lint when env helpers aren't touched
Loading