def get_request_user():
    """Return the calling user's identity from Databricks Apps request headers.

    The headers are consulted in order: ``X-Forwarded-Email``,
    ``X-Forwarded-User``, ``X-Databricks-User-Email``. The first non-empty
    value is returned lowercased so comparisons against the app owner are
    case-insensitive. When none of the headers carries a value, the (falsy)
    value of the last header looked up is returned unchanged.
    """
    candidate_headers = (
        "X-Forwarded-Email",
        "X-Forwarded-User",
        "X-Databricks-User-Email",
    )
    identity = None
    for header_name in candidate_headers:
        identity = request.headers.get(header_name)
        if identity:
            # First populated header wins; normalise case for matching.
            return identity.lower()
    # Nothing usable: hand back whatever the final lookup produced.
    return identity
@mock.patch("utils._probe_gateway", return_value=True)
@mock.patch.dict(os.environ, {"DATABRICKS_WORKSPACE_ID": "6280049833385130"}, clear=False)
def test_auto_construct_from_workspace_id(self, mock_probe):
    """Tier 2: a reachable workspace-derived gateway URL is returned.

    The probe is patched to report success, and both higher-priority
    variables are removed from the environment so resolution has to reach
    the workspace-ID tier.
    """
    higher_tiers = ("DATABRICKS_GATEWAY_HOST", "_GATEWAY_RESOLVED")
    scrubbed_env = {
        key: value for key, value in os.environ.items() if key not in higher_tiers
    }
    with mock.patch.dict(os.environ, scrubbed_env, clear=True):
        gateway = self._get_fn()()

    expected = "https://6280049833385130.ai-gateway.cloud.databricks.com"
    assert gateway == expected
    # Exactly one reachability probe is expected for the candidate URL.
    mock_probe.assert_called_once()
@mock.patch("utils._probe_gateway", return_value=False)
@mock.patch.dict(os.environ, {"DATABRICKS_WORKSPACE_ID": "6280049833385130"}, clear=False)
def test_auto_construct_falls_back_when_unreachable(self, mock_probe):
    """Tier 2 fallback: an unreachable auto-discovered gateway yields ''.

    The probe is patched to report failure; with the higher-priority
    variables removed, the resolver must return the empty string.
    """
    higher_tiers = ("DATABRICKS_GATEWAY_HOST", "_GATEWAY_RESOLVED")
    scrubbed_env = {
        key: value for key, value in os.environ.items() if key not in higher_tiers
    }
    with mock.patch.dict(os.environ, scrubbed_env, clear=True):
        gateway = self._get_fn()()

    assert gateway == ""
    # The candidate must still have been probed exactly once.
    mock_probe.assert_called_once()
def test_setup_claude_falls_back_when_gateway_unreachable(self, tmp_path):
    """setup_claude.py should use serving-endpoints when no gateway is resolved.

    ``_run_setup`` exports ``_GATEWAY_RESOLVED=""`` to the subprocess, which
    ``get_gateway_host()`` treats as "already probed, no gateway". No network
    probe runs, so the only valid outcome is the serving-endpoints fallback.
    A gateway URL here would mean the cached result was ignored.
    """
    import json

    result = self._run_setup("setup_claude.py", tmp_path)
    assert result.returncode == 0, f"stderr: {result.stderr}"

    settings_path = tmp_path / ".claude" / "settings.json"
    # NOTE(review): the original guards on existence rather than requiring
    # the file; kept as-is because setup_claude.py is not visible here.
    # Confirm whether the script always writes settings.json and, if so,
    # assert on it.
    if settings_path.exists():
        settings = json.loads(settings_path.read_text())
        base_url = settings.get("env", {}).get("ANTHROPIC_BASE_URL", "")
        assert base_url.endswith("/anthropic")
        # The gateway was pre-resolved to "" — presumably setup_claude.py
        # resolves through utils.get_gateway_host(); verify against the script.
        assert "serving-endpoints/anthropic" in base_url
        assert "ai-gateway.cloud.databricks.com" not in base_url
{result.stderr}" @@ -154,7 +215,8 @@ def test_setup_claude_fallback_no_gateway(self, tmp_path): base_url = settings.get("env", {}).get("ANTHROPIC_BASE_URL", "") assert "test.cloud.databricks.com/serving-endpoints/anthropic" in base_url - def test_codex_gateway_url_construction(self): + @mock.patch("utils._probe_gateway", return_value=True) + def test_codex_gateway_url_construction(self, mock_probe): """Codex endpoint should use gateway /openai/v1 path.""" from utils import get_gateway_host with mock.patch.dict(os.environ, { @@ -162,12 +224,14 @@ def test_codex_gateway_url_construction(self): }, clear=False): env = os.environ.copy() env.pop("DATABRICKS_GATEWAY_HOST", None) + env.pop("_GATEWAY_RESOLVED", None) with mock.patch.dict(os.environ, env, clear=True): gw = get_gateway_host() codex_url = f"{gw}/openai/v1" assert codex_url == "https://6280049833385130.ai-gateway.cloud.databricks.com/openai/v1" - def test_gemini_gateway_url_construction(self): + @mock.patch("utils._probe_gateway", return_value=True) + def test_gemini_gateway_url_construction(self, mock_probe): """Gemini endpoint should use gateway /gemini path.""" from utils import get_gateway_host with mock.patch.dict(os.environ, { @@ -175,12 +239,14 @@ def test_gemini_gateway_url_construction(self): }, clear=False): env = os.environ.copy() env.pop("DATABRICKS_GATEWAY_HOST", None) + env.pop("_GATEWAY_RESOLVED", None) with mock.patch.dict(os.environ, env, clear=True): gw = get_gateway_host() gemini_url = f"{gw}/gemini" assert gemini_url == "https://6280049833385130.ai-gateway.cloud.databricks.com/gemini" - def test_anthropic_gateway_url_construction(self): + @mock.patch("utils._probe_gateway", return_value=True) + def test_anthropic_gateway_url_construction(self, mock_probe): """Anthropic endpoint should use gateway /anthropic path.""" from utils import get_gateway_host with mock.patch.dict(os.environ, { @@ -188,12 +254,14 @@ def test_anthropic_gateway_url_construction(self): }, clear=False): env = 
def _probe_gateway(url: str, timeout: float = 2.0) -> bool:
    """Quick connectivity check against an AI Gateway host.

    Sends a GET to ``url`` without following redirects. Any HTTP response
    (even 401/404) means the host exists; any failure to obtain a response
    means it does not.

    Args:
        url: Full gateway URL to probe, including scheme.
        timeout: Seconds passed to ``requests`` as the request timeout.

    Returns:
        True if an HTTP response was received, False otherwise.
    """
    import requests

    try:
        # stream=True returns as soon as the status line and headers arrive,
        # so the probe does not download a body it never inspects.
        response = requests.get(
            url, timeout=timeout, allow_redirects=False, stream=True
        )
    except Exception:
        # Deliberately best-effort: connection errors, timeouts, malformed
        # URLs and anything unexpected all mean "do not use this gateway".
        return False
    # Release the connection explicitly because the body was not consumed.
    response.close()
    return True


def get_gateway_host() -> str:
    """Resolve the AI Gateway host URL.

    Priority:
    0. ``_GATEWAY_RESOLVED`` env var (set by a parent process after probing,
       so subprocesses do not re-probe). Unset = never probed; "" = probed,
       no gateway. The value is returned verbatim.
    1. Explicit ``DATABRICKS_GATEWAY_HOST`` env var (trusted — no probe).
    2. Auto-constructed from ``DATABRICKS_WORKSPACE_ID`` (probed for
       reachability).
    3. Empty string (caller falls back to DATABRICKS_HOST/serving-endpoints).

    Returns:
        The gateway base URL, or "" when no gateway should be used.
    """
    # Tier 0: already resolved by a parent process. Compare against None so
    # that a cached empty string ("no gateway") is honoured too.
    resolved = os.environ.get("_GATEWAY_RESOLVED")
    if resolved is not None:
        return resolved

    # Tier 1: explicit override (trusted, no probe).
    explicit = os.environ.get("DATABRICKS_GATEWAY_HOST", "").strip().rstrip("/")
    if explicit:
        return ensure_https(explicit)

    # Tier 2: auto-construct from workspace ID and probe for reachability.
    workspace_id = os.environ.get("DATABRICKS_WORKSPACE_ID", "").strip()
    if not workspace_id:
        # Tier 3: nothing configured.
        return ""

    candidate = f"https://{workspace_id}.ai-gateway.cloud.databricks.com"
    if _probe_gateway(candidate):
        return candidate

    print(
        f"AI Gateway not reachable at {candidate}, "
        "falling back to serving-endpoints"
    )
    return ""


def resolve_and_cache_gateway() -> str:
    """Resolve the gateway once and cache the result in the environment.

    Stores the result of ``get_gateway_host()`` in ``_GATEWAY_RESOLVED``.
    Later calls to ``get_gateway_host()`` in this process, and in child
    processes that inherit the environment, then return the cached value
    without probing. An empty result is cached as well.

    Returns:
        The resolved gateway URL, or "" when no gateway is available.
    """
    result = get_gateway_host()
    os.environ["_GATEWAY_RESOLVED"] = result
    return result