diff --git a/src/fastapi_cli/cli.py b/src/fastapi_cli/cli.py index 77647217..b7b8736b 100644 --- a/src/fastapi_cli/cli.py +++ b/src/fastapi_cli/cli.py @@ -96,6 +96,7 @@ def _run( command: str, app: Union[str, None] = None, proxy_headers: bool = False, + forwarded_allow_ips: Union[str, None] = None, ) -> None: with get_rich_toolkit() as toolkit: server_type = "development" if command == "dev" else "production" @@ -177,6 +178,7 @@ def _run( workers=workers, root_path=root_path, proxy_headers=proxy_headers, + forwarded_allow_ips=forwarded_allow_ips, log_config=get_uvicorn_log_config(), ) @@ -226,6 +228,12 @@ def dev( help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to populate remote address info." ), ] = True, + forwarded_allow_ips: Annotated[ + Union[str, None], + typer.Option( + help="Comma separated list of IP Addresses to trust with proxy headers. The literal '*' means trust everything." + ), + ] = None, ) -> Any: """ Run a [bold]FastAPI[/bold] app in [yellow]development[/yellow] mode. ๐Ÿงช @@ -261,6 +269,7 @@ def dev( app=app, command="dev", proxy_headers=proxy_headers, + forwarded_allow_ips=forwarded_allow_ips, ) @@ -315,6 +324,12 @@ def run( help="Enable/Disable X-Forwarded-Proto, X-Forwarded-For, X-Forwarded-Port to populate remote address info." ), ] = True, + forwarded_allow_ips: Annotated[ + Union[str, None], + typer.Option( + help="Comma separated list of IP Addresses to trust with proxy headers. The literal '*' means trust everything." + ), + ] = None, ) -> Any: """ Run a [bold]FastAPI[/bold] app in [green]production[/green] mode. ๐Ÿš€ @@ -351,6 +366,7 @@ def run( app=app, command="run", proxy_headers=proxy_headers, + forwarded_allow_ips=forwarded_allow_ips, ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 8bdba1c7..b6dc9671 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -30,6 +30,7 @@ def test_dev() -> None: "workers": None, "root_path": "", "proxy_headers": True, + "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), } assert "Using import string: single_file_app:app" in result.output @@ -59,6 +60,7 @@ def test_dev_package() -> None: "workers": None, "root_path": "", "proxy_headers": True, + "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), } assert "Using import string: nested_package.package:app" in result.output @@ -107,6 +109,7 @@ def test_dev_args() -> None: "workers": None, "root_path": "/api", "proxy_headers": False, + "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), } assert "Using import string: single_file_app:api" in result.output @@ -134,6 +137,33 @@ def test_run() -> None: "workers": None, "root_path": "", "proxy_headers": True, + "forwarded_allow_ips": None, + "log_config": get_uvicorn_log_config(), + } + assert "Using import string: single_file_app:app" in result.output + assert "Starting production server ๐Ÿš€" in result.output + assert "Server started at http://0.0.0.0:8000" in result.output + assert "Documentation at http://0.0.0.0:8000/docs" in result.output + + +def test_run_trust_proxy() -> None: + with changing_dir(assets_path): + with patch.object(uvicorn, "run") as mock_run: + result = runner.invoke( + app, ["run", "single_file_app.py", "--forwarded-allow-ips", "*"] + ) + assert result.exit_code == 0, result.output + assert mock_run.called + assert mock_run.call_args + assert mock_run.call_args.kwargs == { + "app": "single_file_app:app", + "host": "0.0.0.0", + "port": 8000, + "reload": False, + "workers": None, + "root_path": "", + "proxy_headers": True, + "forwarded_allow_ips": "*", "log_config": get_uvicorn_log_config(), } assert "Using import string: single_file_app:app" in result.output @@ -179,6 +209,7 @@ def test_run_args() -> None: "workers": 2, "root_path": "/api", "proxy_headers": False, + "forwarded_allow_ips": None, "log_config": get_uvicorn_log_config(), }