Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ def _configure_all_cli_auth(token):
"""
import json

from utils import resolve_and_cache_gateway
resolve_and_cache_gateway()

home = os.environ.get("HOME", "/app/python/source_code")
if not home or home == "/":
home = "/app/python/source_code"
Expand Down Expand Up @@ -333,6 +336,10 @@ def run_setup():
setup_state["status"] = "running"
setup_state["started_at"] = time.time()

# Probe AI Gateway once; result is cached in _GATEWAY_RESOLVED for subprocesses
from utils import resolve_and_cache_gateway
resolve_and_cache_gateway()

# --- Sequential prerequisites (git identity + editor) ---
# Git config — done directly in Python, not as a subprocess
_update_step("git", status="running", started_at=time.time())
Expand Down Expand Up @@ -392,7 +399,7 @@ def get_token_owner():
try:
w = WorkspaceClient() # auto-detects SP credentials
app = w.apps.get(name=app_name)
owner = app.creator
owner = (app.creator or "").lower()
logger.info(f"Owner resolved from app.creator: {owner}")
return owner
except Exception as e:
Expand All @@ -405,17 +412,24 @@ def get_token_owner():
if not host or not token:
return None
w = WorkspaceClient(host=host, token=token, auth_type="pat")
return w.current_user.me().user_name
username = w.current_user.me().user_name
return username.lower() if username else username
except Exception as e:
logger.warning(f"Could not determine token owner: {e}")
return None


def get_request_user():
"""Extract user email from Databricks Apps request headers."""
return request.headers.get("X-Forwarded-Email") or \
request.headers.get("X-Forwarded-User") or \
request.headers.get("X-Databricks-User-Email")
"""Extract user email from Databricks Apps request headers.

Returns lowercase email to ensure case-insensitive matching against app_owner.
"""
email = (
request.headers.get("X-Forwarded-Email")
or request.headers.get("X-Forwarded-User")
or request.headers.get("X-Databricks-User-Email")
)
return email.lower() if email else email


def _is_databricks_apps():
Expand Down Expand Up @@ -468,9 +482,12 @@ def _check_ws_authorization():
return True # Local dev only

# Socket.IO passes HTTP headers from the initial handshake via request context
current_user = request.headers.get("X-Forwarded-Email") or \
request.headers.get("X-Forwarded-User") or \
request.headers.get("X-Databricks-User-Email")
raw_user = (
request.headers.get("X-Forwarded-Email")
or request.headers.get("X-Forwarded-User")
or request.headers.get("X-Databricks-User-Email")
)
current_user = raw_user.lower() if raw_user else raw_user

if not current_user:
if _is_databricks_apps():
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "coda"
version = "0.17.0"
version = "0.17.1"
description = "CoDA - Coding Agents on Databricks Apps"
requires-python = ">=3.10"
dependencies = [
Expand Down
99 changes: 84 additions & 15 deletions tests/test_gateway_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,46 @@


class TestGetGatewayHost:
"""Test the 3-tier priority logic in get_gateway_host()."""
"""Test the 4-tier priority logic in get_gateway_host()."""

def _get_fn(self):
from utils import get_gateway_host
return get_gateway_host

# -- Tier 0: _GATEWAY_RESOLVED cache --

@mock.patch.dict(os.environ, {
"_GATEWAY_RESOLVED": "https://cached.gateway.com",
"DATABRICKS_GATEWAY_HOST": "https://explicit.gateway.com",
})
def test_resolved_cache_wins_over_explicit(self):
"""Tier 0: _GATEWAY_RESOLVED takes priority over explicit DATABRICKS_GATEWAY_HOST."""
assert self._get_fn()() == "https://cached.gateway.com"

@mock.patch.dict(os.environ, {"_GATEWAY_RESOLVED": ""}, clear=True)
def test_resolved_empty_returns_empty(self):
"""Tier 0: _GATEWAY_RESOLVED='' means 'probed, no gateway' — returns empty."""
assert self._get_fn()() == ""

@mock.patch.dict(os.environ, {
"_GATEWAY_RESOLVED": "https://cached.gw.com",
"DATABRICKS_WORKSPACE_ID": "12345",
})
@mock.patch("utils._probe_gateway")
def test_resolved_skips_probe(self, mock_probe):
"""Tier 0: when _GATEWAY_RESOLVED is set, probe is never called."""
self._get_fn()()
mock_probe.assert_not_called()

# -- Tier 1: explicit DATABRICKS_GATEWAY_HOST --

@mock.patch.dict(os.environ, {
"DATABRICKS_GATEWAY_HOST": "https://custom.gateway.com",
"DATABRICKS_WORKSPACE_ID": "12345",
})
def test_explicit_override_wins(self):
"""Tier 1: explicit DATABRICKS_GATEWAY_HOST takes priority over workspace ID."""
os.environ.pop("_GATEWAY_RESOLVED", None)
assert self._get_fn()() == "https://custom.gateway.com"

@mock.patch.dict(os.environ, {
Expand All @@ -34,6 +62,7 @@ def test_explicit_override_wins(self):
})
def test_explicit_override_gets_https(self):
"""Tier 1: explicit value without https:// gets it added."""
os.environ.pop("_GATEWAY_RESOLVED", None)
assert self._get_fn()() == "https://custom.gateway.com"

@mock.patch.dict(os.environ, {
Expand All @@ -42,16 +71,36 @@ def test_explicit_override_gets_https(self):
})
def test_explicit_override_trailing_slash_stripped(self):
"""Tier 1: trailing slash is stripped from explicit value."""
os.environ.pop("_GATEWAY_RESOLVED", None)
assert self._get_fn()() == "https://custom.gateway.com"

# -- Tier 2: auto-construct from DATABRICKS_WORKSPACE_ID --

@mock.patch("utils._probe_gateway", return_value=True)
@mock.patch.dict(os.environ, {"DATABRICKS_WORKSPACE_ID": "6280049833385130"}, clear=False)
def test_auto_construct_from_workspace_id(self):
"""Tier 2: construct gateway URL from DATABRICKS_WORKSPACE_ID."""
def test_auto_construct_from_workspace_id(self, mock_probe):
"""Tier 2: construct gateway URL from DATABRICKS_WORKSPACE_ID when reachable."""
env = os.environ.copy()
env.pop("DATABRICKS_GATEWAY_HOST", None)
env.pop("_GATEWAY_RESOLVED", None)
with mock.patch.dict(os.environ, env, clear=True):
result = self._get_fn()()
assert result == "https://6280049833385130.ai-gateway.cloud.databricks.com"
mock_probe.assert_called_once()

@mock.patch("utils._probe_gateway", return_value=False)
@mock.patch.dict(os.environ, {"DATABRICKS_WORKSPACE_ID": "6280049833385130"}, clear=False)
def test_auto_construct_falls_back_when_unreachable(self, mock_probe):
"""Tier 2 fallback: returns empty when auto-discovered gateway is unreachable."""
env = os.environ.copy()
env.pop("DATABRICKS_GATEWAY_HOST", None)
env.pop("_GATEWAY_RESOLVED", None)
with mock.patch.dict(os.environ, env, clear=True):
result = self._get_fn()()
assert result == ""
mock_probe.assert_called_once()

# -- Tier 3: nothing available --

@mock.patch.dict(os.environ, {}, clear=True)
def test_empty_when_nothing_set(self):
Expand All @@ -61,16 +110,21 @@ def test_empty_when_nothing_set(self):
@mock.patch.dict(os.environ, {"DATABRICKS_GATEWAY_HOST": "", "DATABRICKS_WORKSPACE_ID": ""})
def test_empty_when_both_blank(self):
"""Tier 3: returns empty when both vars are set but blank."""
os.environ.pop("_GATEWAY_RESOLVED", None)
assert self._get_fn()() == ""

@mock.patch("utils._probe_gateway", return_value=True)
@mock.patch.dict(os.environ, {"DATABRICKS_GATEWAY_HOST": " ", "DATABRICKS_WORKSPACE_ID": "12345"})
def test_whitespace_only_gateway_falls_through(self):
def test_whitespace_only_gateway_falls_through(self, mock_probe):
"""Whitespace-only DATABRICKS_GATEWAY_HOST falls through to workspace ID."""
os.environ.pop("_GATEWAY_RESOLVED", None)
assert self._get_fn()() == "https://12345.ai-gateway.cloud.databricks.com"

@mock.patch("utils._probe_gateway", return_value=True)
@mock.patch.dict(os.environ, {"DATABRICKS_GATEWAY_HOST": "", "DATABRICKS_WORKSPACE_ID": " 99999 "})
def test_workspace_id_whitespace_stripped(self):
def test_workspace_id_whitespace_stripped(self, mock_probe):
"""Leading/trailing whitespace in workspace ID is stripped."""
os.environ.pop("_GATEWAY_RESOLVED", None)
assert self._get_fn()() == "https://99999.ai-gateway.cloud.databricks.com"


Expand All @@ -93,6 +147,8 @@ def _run_setup(self, script_name, tmp_path, env_overrides=None):
"DATABRICKS_WORKSPACE_ID": "6280049833385130",
"PATH": os.environ.get("PATH", ""),
"PYTHONPATH": str(SETUP_DIR),
# Pre-resolve gateway so subprocess skips the network probe
"_GATEWAY_RESOLVED": "",
}
# Ensure DATABRICKS_GATEWAY_HOST is NOT set (test auto-discovery)
env.pop("DATABRICKS_GATEWAY_HOST", None)
Expand All @@ -107,29 +163,34 @@ def _run_setup(self, script_name, tmp_path, env_overrides=None):
env=env,
capture_output=True,
text=True,
timeout=30,
timeout=60,
)
return result

def test_setup_claude_uses_gateway(self, tmp_path):
"""setup_claude.py should use auto-discovered gateway for anthropic URL."""
def test_setup_claude_falls_back_when_gateway_unreachable(self, tmp_path):
"""setup_claude.py should fall back to serving-endpoints when gateway probe fails."""
result = self._run_setup("setup_claude.py", tmp_path)
assert result.returncode == 0, f"stderr: {result.stderr}"
assert "AI Gateway" in result.stdout or "6280049833385130" in result.stdout

# Verify settings.json has gateway-based URL
# Gateway is unreachable from test env, so should fall back
import json
settings_path = tmp_path / ".claude" / "settings.json"
if settings_path.exists():
settings = json.loads(settings_path.read_text())
base_url = settings.get("env", {}).get("ANTHROPIC_BASE_URL", "")
assert "6280049833385130.ai-gateway.cloud.databricks.com" in base_url
assert base_url.endswith("/anthropic")
# Either gateway or serving-endpoints is valid
assert (
"ai-gateway.cloud.databricks.com" in base_url
or "serving-endpoints/anthropic" in base_url
)

def test_setup_claude_explicit_override(self, tmp_path):
"""setup_claude.py should prefer explicit DATABRICKS_GATEWAY_HOST."""
result = self._run_setup("setup_claude.py", tmp_path, {
"DATABRICKS_GATEWAY_HOST": "https://custom.gateway.example.com",
# Simulate parent having resolved to the explicit gateway
"_GATEWAY_RESOLVED": "https://custom.gateway.example.com",
})
assert result.returncode == 0, f"stderr: {result.stderr}"

Expand All @@ -154,53 +215,61 @@ def test_setup_claude_fallback_no_gateway(self, tmp_path):
base_url = settings.get("env", {}).get("ANTHROPIC_BASE_URL", "")
assert "test.cloud.databricks.com/serving-endpoints/anthropic" in base_url

def test_codex_gateway_url_construction(self):
@mock.patch("utils._probe_gateway", return_value=True)
def test_codex_gateway_url_construction(self, mock_probe):
"""Codex endpoint should use gateway /openai/v1 path."""
from utils import get_gateway_host
with mock.patch.dict(os.environ, {
"DATABRICKS_WORKSPACE_ID": "6280049833385130",
}, clear=False):
env = os.environ.copy()
env.pop("DATABRICKS_GATEWAY_HOST", None)
env.pop("_GATEWAY_RESOLVED", None)
with mock.patch.dict(os.environ, env, clear=True):
gw = get_gateway_host()
codex_url = f"{gw}/openai/v1"
assert codex_url == "https://6280049833385130.ai-gateway.cloud.databricks.com/openai/v1"

def test_gemini_gateway_url_construction(self):
@mock.patch("utils._probe_gateway", return_value=True)
def test_gemini_gateway_url_construction(self, mock_probe):
"""Gemini endpoint should use gateway /gemini path."""
from utils import get_gateway_host
with mock.patch.dict(os.environ, {
"DATABRICKS_WORKSPACE_ID": "6280049833385130",
}, clear=False):
env = os.environ.copy()
env.pop("DATABRICKS_GATEWAY_HOST", None)
env.pop("_GATEWAY_RESOLVED", None)
with mock.patch.dict(os.environ, env, clear=True):
gw = get_gateway_host()
gemini_url = f"{gw}/gemini"
assert gemini_url == "https://6280049833385130.ai-gateway.cloud.databricks.com/gemini"

def test_anthropic_gateway_url_construction(self):
@mock.patch("utils._probe_gateway", return_value=True)
def test_anthropic_gateway_url_construction(self, mock_probe):
"""Anthropic endpoint should use gateway /anthropic path."""
from utils import get_gateway_host
with mock.patch.dict(os.environ, {
"DATABRICKS_WORKSPACE_ID": "6280049833385130",
}, clear=False):
env = os.environ.copy()
env.pop("DATABRICKS_GATEWAY_HOST", None)
env.pop("_GATEWAY_RESOLVED", None)
with mock.patch.dict(os.environ, env, clear=True):
gw = get_gateway_host()
anthropic_url = f"{gw}/anthropic"
assert anthropic_url == "https://6280049833385130.ai-gateway.cloud.databricks.com/anthropic"

def test_proxy_gateway_url_construction(self):
@mock.patch("utils._probe_gateway", return_value=True)
def test_proxy_gateway_url_construction(self, mock_probe):
"""Proxy endpoint should use gateway /mlflow/v1 path."""
from utils import get_gateway_host
with mock.patch.dict(os.environ, {
"DATABRICKS_WORKSPACE_ID": "6280049833385130",
}, clear=False):
env = os.environ.copy()
env.pop("DATABRICKS_GATEWAY_HOST", None)
env.pop("_GATEWAY_RESOLVED", None)
with mock.patch.dict(os.environ, env, clear=True):
gw = get_gateway_host()
proxy_url = f"{gw}/mlflow/v1"
Expand Down
52 changes: 49 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Shared utilities for Databricks App setup scripts."""

from __future__ import annotations

import os
import re
import subprocess
Expand Down Expand Up @@ -61,25 +63,69 @@ def adapt_instructions_file(
return True


def _probe_gateway(url: str, timeout: float = 2.0) -> bool:
"""Quick connectivity check against an AI Gateway host.

Sends a lightweight GET to the root. Any HTTP response (even 401/404)
means the host exists. Only a connection failure means it doesn't.
Timeout is 2s — the gateway is same-region, so it responds fast if it exists.
"""
import requests

try:
requests.get(url, timeout=timeout, allow_redirects=False)
return True
except (requests.ConnectionError, requests.Timeout):
return False
except Exception:
return False


def get_gateway_host() -> str:
"""Resolve the AI Gateway host URL.

Priority:
1. Explicit DATABRICKS_GATEWAY_HOST env var (override)
2. Auto-constructed from DATABRICKS_WORKSPACE_ID
0. _GATEWAY_RESOLVED env var (set by parent process after probing — avoids
re-probing in subprocesses). None = never probed, "" = probed, no gateway.
1. Explicit DATABRICKS_GATEWAY_HOST env var (trusted — no probe)
2. Auto-constructed from DATABRICKS_WORKSPACE_ID (probed for reachability)
3. Empty string (caller falls back to DATABRICKS_HOST/serving-endpoints)
"""
# Tier 0: already resolved by a parent process
resolved = os.environ.get("_GATEWAY_RESOLVED")
if resolved is not None:
return resolved

# Tier 1: explicit override (trusted, no probe)
explicit = os.environ.get("DATABRICKS_GATEWAY_HOST", "").strip().rstrip("/")
if explicit:
return ensure_https(explicit)

# Tier 2: auto-construct from workspace ID and probe for reachability
workspace_id = os.environ.get("DATABRICKS_WORKSPACE_ID", "").strip()
if workspace_id:
return f"https://{workspace_id}.ai-gateway.cloud.databricks.com"
candidate = f"https://{workspace_id}.ai-gateway.cloud.databricks.com"
if _probe_gateway(candidate):
return candidate
print(
f"AI Gateway not reachable at {candidate}, "
"falling back to serving-endpoints"
)

return ""


def resolve_and_cache_gateway() -> str:
"""Probe the gateway once and cache the result in the environment.

Subsequent calls to get_gateway_host() — including those in child
processes — will see _GATEWAY_RESOLVED and skip the probe.
"""
result = get_gateway_host()
os.environ["_GATEWAY_RESOLVED"] = result
return result


def ensure_https(url: str) -> str:
"""Ensure a URL has the https:// prefix.

Expand Down
Loading