-
Notifications
You must be signed in to change notification settings - Fork 0
Add AuthStaticFiles for auth on static file serving #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| from fastapi import FastAPI, HTTPException, Request | ||
| from fastapi.staticfiles import AuthStaticFiles | ||
|
|
||
|
|
||
| async def verify_token(request: Request) -> None: | ||
| """Check for a valid Bearer token in the Authorization header.""" | ||
| token = request.headers.get("Authorization") | ||
| if token != "Bearer mysecrettoken": | ||
| raise HTTPException(status_code=401, detail="Not authenticated") | ||
|
|
||
|
|
||
| app = FastAPI() | ||
|
|
||
| # Private files - requires a valid token to access | ||
| app.mount( | ||
| "/private", | ||
| AuthStaticFiles(directory="private_files", auth=verify_token), | ||
| name="private", | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,109 @@ | ||
| from collections.abc import Awaitable, Callable | ||
| from typing import Any | ||
|
|
||
| from starlette.requests import Request | ||
| from starlette.responses import PlainTextResponse, Response | ||
| from starlette.staticfiles import StaticFiles as StaticFiles # noqa | ||
| from starlette.types import Receive, Scope, Send | ||
|
|
||
|
|
||
| class AuthStaticFiles(StaticFiles): | ||
| """ | ||
| A static files handler that requires authentication before serving files. | ||
|
|
||
| This solves the problem where `app.mount("/static", StaticFiles(...))` serves | ||
| files without any authentication, making it impossible to protect private files. | ||
|
|
||
| `AuthStaticFiles` accepts an `auth` callable that receives a `Request` and | ||
| should either return successfully (authenticated) or raise an `HTTPException` | ||
| (not authenticated). | ||
|
|
||
| ## Usage | ||
|
|
||
| ```python | ||
| from fastapi import FastAPI, HTTPException, Request | ||
| from fastapi.staticfiles import AuthStaticFiles | ||
|
|
||
| app = FastAPI() | ||
|
|
||
|
|
||
| async def verify_token(request: Request) -> None: | ||
| token = request.headers.get("Authorization") | ||
| if token != "Bearer mysecrettoken": | ||
| raise HTTPException(status_code=401, detail="Unauthorized") | ||
|
|
||
|
|
||
| app.mount( | ||
| "/private", | ||
| AuthStaticFiles(directory="private_files", auth=verify_token), | ||
| name="private", | ||
| ) | ||
| ``` | ||
|
|
||
| ## Parameters | ||
|
|
||
| * `auth`: An async callable that takes a `Request` object and performs | ||
| authentication. It should raise an `HTTPException` if authentication | ||
| fails, or return `None` if authentication succeeds. | ||
| * `on_error`: An optional callable that takes a `Request` and an | ||
| `HTTPException` and returns a `Response`. Use this to customize | ||
| error responses (e.g., redirect to login, return HTML instead of | ||
| plain text). If not provided, a plain text error response is returned. | ||
| * `directory`: The directory to serve files from. | ||
| * `packages`: A list of Python packages to serve files from. | ||
| * `html`: If `True`, serves `index.html` files for directories. | ||
| * `check_dir`: If `True`, checks that the directory exists on startup. | ||
| * `follow_symlink`: If `True`, follows symbolic links. | ||
|
|
||
| ## Performance Note | ||
|
|
||
| The `auth` callable runs on **every static file request** (CSS, JS, | ||
| images, etc.). Prefer lightweight checks (header presence, JWT signature | ||
| verification) over expensive operations (database lookups) to avoid | ||
| slowing down page loads. | ||
|
|
||
| Ref: https://github.com/fastapi/fastapi/issues/858 | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| directory: str | None = None, | ||
| packages: list[str | tuple[str, str]] | None = None, | ||
| html: bool = False, | ||
| check_dir: bool = True, | ||
| follow_symlink: bool = False, | ||
| auth: Callable[[Request], Awaitable[Any]], | ||
| on_error: Callable[[Request, Any], Awaitable[Response]] | None = None, | ||
| ) -> None: | ||
| super().__init__( | ||
| directory=directory, | ||
| packages=packages, | ||
| html=html, | ||
| check_dir=check_dir, | ||
| follow_symlink=follow_symlink, | ||
| ) | ||
| self.auth = auth | ||
| self.on_error = on_error | ||
|
|
||
| async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
| if scope["type"] == "http": | ||
| request = Request(scope, receive) | ||
| try: | ||
| await self.auth(request) | ||
| except Exception as exc: | ||
| from fastapi.exceptions import HTTPException | ||
|
|
||
| if isinstance(exc, HTTPException): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isinstance only catches fastapi.HTTPException, misses starlette.HTTPException The except block imports Import from starlette instead: |
||
| if self.on_error is not None: | ||
| response = await self.on_error(request, exc) | ||
| else: | ||
| response = PlainTextResponse( | ||
| str(exc.detail), | ||
| status_code=exc.status_code, | ||
| headers=getattr(exc, "headers", None), | ||
| ) | ||
| await response(scope, receive, send) | ||
| return | ||
| raise | ||
| await super().__call__(scope, receive, send) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,208 @@ | ||
| import pytest | ||
| from fastapi import FastAPI, HTTPException, Request | ||
| from fastapi.responses import RedirectResponse | ||
| from fastapi.staticfiles import AuthStaticFiles | ||
| from fastapi.testclient import TestClient | ||
| from starlette.responses import HTMLResponse, Response | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def static_dir(tmp_path_factory): | ||
| d = tmp_path_factory.mktemp("static") | ||
| (d / "public.txt").write_text("public content") | ||
| (d / "secret.txt").write_text("secret content") | ||
| return d | ||
|
|
||
|
|
||
| async def verify_token(request: Request) -> None: | ||
| """Simple token-based auth for testing.""" | ||
| token = request.headers.get("Authorization") | ||
| if token != "Bearer valid-token": | ||
| raise HTTPException(status_code=401, detail="Not authenticated") | ||
|
|
||
|
|
||
| async def _allow_all(request: Request) -> None: | ||
| """Auth function that allows all requests.""" | ||
| pass | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def app(static_dir): | ||
| app = FastAPI() | ||
|
|
||
| # Public static files (no auth) | ||
| app.mount( | ||
| "/public", | ||
| AuthStaticFiles( | ||
| directory=str(static_dir), | ||
| auth=_allow_all, | ||
| ), | ||
| name="public", | ||
| ) | ||
|
|
||
| # Private static files (requires auth) | ||
| app.mount( | ||
| "/private", | ||
| AuthStaticFiles( | ||
| directory=str(static_dir), | ||
| auth=verify_token, | ||
| ), | ||
| name="private", | ||
| ) | ||
|
|
||
| return app | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def client(app): | ||
| with TestClient(app) as c: | ||
| yield c | ||
|
|
||
|
|
||
| def test_private_file_without_auth(client: TestClient): | ||
| """Requesting a private file without auth should return 401.""" | ||
| response = client.get("/private/secret.txt") | ||
| assert response.status_code == 401 | ||
| assert response.text == "Not authenticated" | ||
|
|
||
|
|
||
| def test_private_file_with_wrong_token(client: TestClient): | ||
| """Requesting a private file with wrong token should return 401.""" | ||
| response = client.get( | ||
| "/private/secret.txt", | ||
| headers={"Authorization": "Bearer wrong-token"}, | ||
| ) | ||
| assert response.status_code == 401 | ||
| assert response.text == "Not authenticated" | ||
|
|
||
|
|
||
| def test_private_file_with_valid_token(client: TestClient): | ||
| """Requesting a private file with valid token should return the file.""" | ||
| response = client.get( | ||
| "/private/secret.txt", | ||
| headers={"Authorization": "Bearer valid-token"}, | ||
| ) | ||
| assert response.status_code == 200 | ||
| assert response.text == "secret content" | ||
|
|
||
|
|
||
| def test_private_file_not_found_with_valid_token(client: TestClient): | ||
| """Requesting a non-existent private file with valid auth should return 404.""" | ||
| response = client.get( | ||
| "/private/nonexistent.txt", | ||
| headers={"Authorization": "Bearer valid-token"}, | ||
| ) | ||
| assert response.status_code == 404 | ||
|
|
||
|
|
||
| def test_public_files_accessible(client: TestClient): | ||
| """Public mount with allow-all auth should serve files without auth.""" | ||
| response = client.get("/public/public.txt") | ||
| assert response.status_code == 200 | ||
| assert response.text == "public content" | ||
|
|
||
|
|
||
| def test_auth_headers_forwarded(static_dir): | ||
| """Auth errors with custom headers should forward them in the response.""" | ||
|
|
||
| async def auth_with_headers(request: Request) -> None: | ||
| raise HTTPException( | ||
| status_code=401, | ||
| detail="Login required", | ||
| headers={"WWW-Authenticate": "Bearer"}, | ||
| ) | ||
|
|
||
| app = FastAPI() | ||
| app.mount( | ||
| "/protected", | ||
| AuthStaticFiles(directory=str(static_dir), auth=auth_with_headers), | ||
| name="protected", | ||
| ) | ||
|
|
||
| with TestClient(app) as client: | ||
| response = client.get("/protected/public.txt") | ||
| assert response.status_code == 401 | ||
| assert response.headers["WWW-Authenticate"] == "Bearer" | ||
| assert response.text == "Login required" | ||
|
|
||
|
|
||
| def test_cookie_based_auth(static_dir): | ||
| """AuthStaticFiles should work with cookie-based authentication.""" | ||
|
|
||
| async def verify_cookie(request: Request) -> None: | ||
| session = request.cookies.get("session_id") | ||
| if session != "valid-session": | ||
| raise HTTPException(status_code=403, detail="Forbidden") | ||
|
|
||
| app = FastAPI() | ||
| app.mount( | ||
| "/dashboard", | ||
| AuthStaticFiles(directory=str(static_dir), auth=verify_cookie), | ||
| name="dashboard", | ||
| ) | ||
|
|
||
| with TestClient(app) as client: | ||
| # Without cookie | ||
| response = client.get("/dashboard/public.txt") | ||
| assert response.status_code == 403 | ||
|
|
||
| # With valid cookie | ||
| client.cookies.set("session_id", "valid-session") | ||
| response = client.get("/dashboard/public.txt") | ||
| assert response.status_code == 200 | ||
| assert response.text == "public content" | ||
|
|
||
|
|
||
| def test_custom_on_error_redirect(static_dir): | ||
| """on_error can redirect to a login page.""" | ||
|
|
||
| async def deny_all(request: Request) -> None: | ||
| raise HTTPException(status_code=401, detail="Unauthorized") | ||
|
|
||
| async def redirect_to_login(request: Request, exc: HTTPException) -> Response: | ||
| return RedirectResponse(url="/login", status_code=302) | ||
|
|
||
| app = FastAPI() | ||
| app.mount( | ||
| "/protected", | ||
| AuthStaticFiles( | ||
| directory=str(static_dir), | ||
| auth=deny_all, | ||
| on_error=redirect_to_login, | ||
| ), | ||
| name="protected", | ||
| ) | ||
|
|
||
| with TestClient(app, follow_redirects=False) as client: | ||
| response = client.get("/protected/public.txt") | ||
| assert response.status_code == 302 | ||
| assert response.headers["location"] == "/login" | ||
|
|
||
|
|
||
| def test_custom_on_error_html(static_dir): | ||
| """on_error can return a custom HTML error page.""" | ||
|
|
||
| async def deny_all(request: Request) -> None: | ||
| raise HTTPException(status_code=403, detail="Forbidden") | ||
|
|
||
| async def html_error(request: Request, exc: HTTPException) -> Response: | ||
| return HTMLResponse( | ||
| f"<h1>{exc.status_code} {exc.detail}</h1>", | ||
| status_code=exc.status_code, | ||
| ) | ||
|
|
||
| app = FastAPI() | ||
| app.mount( | ||
| "/protected", | ||
| AuthStaticFiles( | ||
| directory=str(static_dir), | ||
| auth=deny_all, | ||
| on_error=html_error, | ||
| ), | ||
| name="protected", | ||
| ) | ||
|
|
||
| with TestClient(app) as client: | ||
| response = client.get("/protected/public.txt") | ||
| assert response.status_code == 403 | ||
| assert "<h1>403 Forbidden</h1>" in response.text |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sync auth callables will crash at runtime despite no enforcement
The
authparameter is typed asCallable[[Request], Awaitable[Any]]but Python doesn't enforce this at runtime. If a user passes a sync function (returningNoneor a value),await self.auth(request)at line 93 will raiseTypeError: object ... can't be used in 'await' expression. Consider wrapping withasyncio.iscoroutinefunction()check or usingrun_in_threadpoolfor sync callables, similar to how FastAPI handles sync dependencies.Either detect sync callables and wrap them with
run_in_threadpool, or add a runtime check in__init__that raises a clear error if the callable isn't async.