diff --git a/docs_src/static_files/tutorial002_auth_py310.py b/docs_src/static_files/tutorial002_auth_py310.py new file mode 100644 index 0000000000000..b9cd40e404f54 --- /dev/null +++ b/docs_src/static_files/tutorial002_auth_py310.py @@ -0,0 +1,19 @@ +from fastapi import FastAPI, HTTPException, Request +from fastapi.staticfiles import AuthStaticFiles + + +async def verify_token(request: Request) -> None: + """Check for a valid Bearer token in the Authorization header.""" + token = request.headers.get("Authorization") + if token != "Bearer mysecrettoken": + raise HTTPException(status_code=401, detail="Not authenticated") + + +app = FastAPI() + +# Private files - requires a valid token to access +app.mount( + "/private", + AuthStaticFiles(directory="private_files", auth=verify_token), + name="private", +) diff --git a/fastapi/staticfiles.py b/fastapi/staticfiles.py index 299015d4fef26..978f83de11d4f 100644 --- a/fastapi/staticfiles.py +++ b/fastapi/staticfiles.py @@ -1 +1,109 @@ +from collections.abc import Awaitable, Callable +from typing import Any + +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response from starlette.staticfiles import StaticFiles as StaticFiles # noqa +from starlette.types import Receive, Scope, Send + + +class AuthStaticFiles(StaticFiles): + """ + A static files handler that requires authentication before serving files. + + This solves the problem where `app.mount("/static", StaticFiles(...))` serves + files without any authentication, making it impossible to protect private files. + + `AuthStaticFiles` accepts an `auth` callable that receives a `Request` and + should either return successfully (authenticated) or raise an `HTTPException` + (not authenticated). + + ## Usage + + ```python + from fastapi import FastAPI, HTTPException, Request + from fastapi.staticfiles import AuthStaticFiles + + app = FastAPI() + + + async def verify_token(request: Request) -> None: + token = request.headers.get("Authorization") + if token != "Bearer mysecrettoken": + raise HTTPException(status_code=401, detail="Unauthorized") + + + app.mount( + "/private", + AuthStaticFiles(directory="private_files", auth=verify_token), + name="private", + ) + ``` + + ## Parameters + + * `auth`: An async callable that takes a `Request` object and performs + authentication. It should raise an `HTTPException` if authentication + fails, or return `None` if authentication succeeds. + * `on_error`: An optional callable that takes a `Request` and an + `HTTPException` and returns a `Response`. Use this to customize + error responses (e.g., redirect to login, return HTML instead of + plain text). If not provided, a plain text error response is returned. + * `directory`: The directory to serve files from. + * `packages`: A list of Python packages to serve files from. + * `html`: If `True`, serves `index.html` files for directories. + * `check_dir`: If `True`, checks that the directory exists on startup. + * `follow_symlink`: If `True`, follows symbolic links. + + ## Performance Note + + The `auth` callable runs on **every static file request** (CSS, JS, + images, etc.). Prefer lightweight checks (header presence, JWT signature + verification) over expensive operations (database lookups) to avoid + slowing down page loads. + + Ref: https://github.com/fastapi/fastapi/issues/858 + """ + + def __init__( + self, + *, + directory: str | None = None, + packages: list[str | tuple[str, str]] | None = None, + html: bool = False, + check_dir: bool = True, + follow_symlink: bool = False, + auth: Callable[[Request], Awaitable[Any]], + on_error: Callable[[Request, Any], Awaitable[Response]] | None = None, + ) -> None: + super().__init__( + directory=directory, + packages=packages, + html=html, + check_dir=check_dir, + follow_symlink=follow_symlink, + ) + self.auth = auth + self.on_error = on_error + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http": + request = Request(scope, receive) + try: + await self.auth(request) + except Exception as exc: + from fastapi.exceptions import HTTPException + + if isinstance(exc, HTTPException): + if self.on_error is not None: + response = await self.on_error(request, exc) + else: + response = PlainTextResponse( + str(exc.detail), + status_code=exc.status_code, + headers=getattr(exc, "headers", None), + ) + await response(scope, receive, send) + return + raise + await super().__call__(scope, receive, send) diff --git a/tests/test_auth_static_files.py b/tests/test_auth_static_files.py new file mode 100644 index 0000000000000..10be23affd8fe --- /dev/null +++ b/tests/test_auth_static_files.py @@ -0,0 +1,208 @@ +import pytest +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import RedirectResponse +from fastapi.staticfiles import AuthStaticFiles +from fastapi.testclient import TestClient +from starlette.responses import HTMLResponse, Response + + +@pytest.fixture(scope="module") +def static_dir(tmp_path_factory): + d = tmp_path_factory.mktemp("static") + (d / "public.txt").write_text("public content") + (d / "secret.txt").write_text("secret content") + return d + + +async def verify_token(request: Request) -> None: + """Simple token-based auth for testing.""" + token = request.headers.get("Authorization") + if token != "Bearer valid-token": + raise HTTPException(status_code=401, detail="Not authenticated") + + +async def _allow_all(request: Request) -> None: + """Auth function that allows all requests.""" + pass + + +@pytest.fixture(scope="module") +def app(static_dir): + app = FastAPI() + + # Public static files (no auth) + app.mount( + "/public", + AuthStaticFiles( + directory=str(static_dir), + auth=_allow_all, + ), + name="public", + ) + + # Private static files (requires auth) + app.mount( + "/private", + AuthStaticFiles( + directory=str(static_dir), + auth=verify_token, + ), + name="private", + ) + + return app + + +@pytest.fixture(scope="module") +def client(app): + with TestClient(app) as c: + yield c + + +def test_private_file_without_auth(client: TestClient): + """Requesting a private file without auth should return 401.""" + response = client.get("/private/secret.txt") + assert response.status_code == 401 + assert response.text == "Not authenticated" + + +def test_private_file_with_wrong_token(client: TestClient): + """Requesting a private file with wrong token should return 401.""" + response = client.get( + "/private/secret.txt", + headers={"Authorization": "Bearer wrong-token"}, + ) + assert response.status_code == 401 + assert response.text == "Not authenticated" + + +def test_private_file_with_valid_token(client: TestClient): + """Requesting a private file with valid token should return the file.""" + response = client.get( + "/private/secret.txt", + headers={"Authorization": "Bearer valid-token"}, + ) + assert response.status_code == 200 + assert response.text == "secret content" + + +def test_private_file_not_found_with_valid_token(client: TestClient): + """Requesting a non-existent private file with valid auth should return 404.""" + response = client.get( + "/private/nonexistent.txt", + headers={"Authorization": "Bearer valid-token"}, + ) + assert response.status_code == 404 + + +def test_public_files_accessible(client: TestClient): + """Public mount with allow-all auth should serve files without auth.""" + response = client.get("/public/public.txt") + assert response.status_code == 200 + assert response.text == "public content" + + +def test_auth_headers_forwarded(static_dir): + """Auth errors with custom headers should forward them in the response.""" + + async def auth_with_headers(request: Request) -> None: + raise HTTPException( + status_code=401, + detail="Login required", + headers={"WWW-Authenticate": "Bearer"}, + ) + + app = FastAPI() + app.mount( + "/protected", + AuthStaticFiles(directory=str(static_dir), auth=auth_with_headers), + name="protected", + ) + + with TestClient(app) as client: + response = client.get("/protected/public.txt") + assert response.status_code == 401 + assert response.headers["WWW-Authenticate"] == "Bearer" + assert response.text == "Login required" + + +def test_cookie_based_auth(static_dir): + """AuthStaticFiles should work with cookie-based authentication.""" + + async def verify_cookie(request: Request) -> None: + session = request.cookies.get("session_id") + if session != "valid-session": + raise HTTPException(status_code=403, detail="Forbidden") + + app = FastAPI() + app.mount( + "/dashboard", + AuthStaticFiles(directory=str(static_dir), auth=verify_cookie), + name="dashboard", + ) + + with TestClient(app) as client: + # Without cookie + response = client.get("/dashboard/public.txt") + assert response.status_code == 403 + + # With valid cookie + client.cookies.set("session_id", "valid-session") + response = client.get("/dashboard/public.txt") + assert response.status_code == 200 + assert response.text == "public content" + + +def test_custom_on_error_redirect(static_dir): + """on_error can redirect to a login page.""" + + async def deny_all(request: Request) -> None: + raise HTTPException(status_code=401, detail="Unauthorized") + + async def redirect_to_login(request: Request, exc: HTTPException) -> Response: + return RedirectResponse(url="/login", status_code=302) + + app = FastAPI() + app.mount( + "/protected", + AuthStaticFiles( + directory=str(static_dir), + auth=deny_all, + on_error=redirect_to_login, + ), + name="protected", + ) + + with TestClient(app, follow_redirects=False) as client: + response = client.get("/protected/public.txt") + assert response.status_code == 302 + assert response.headers["location"] == "/login" + + +def test_custom_on_error_html(static_dir): + """on_error can return a custom HTML error page.""" + + async def deny_all(request: Request) -> None: + raise HTTPException(status_code=403, detail="Forbidden") + + async def html_error(request: Request, exc: HTTPException) -> Response: + return HTMLResponse( + f"

{exc.status_code} {exc.detail}

", + status_code=exc.status_code, + ) + + app = FastAPI() + app.mount( + "/protected", + AuthStaticFiles( + directory=str(static_dir), + auth=deny_all, + on_error=html_error, + ), + name="protected", + ) + + with TestClient(app) as client: + response = client.get("/protected/public.txt") + assert response.status_code == 403 + assert "

403 Forbidden

" in response.text