From ba19fbf9a7c67458e3ed16f41e726ba0eadb4eb1 Mon Sep 17 00:00:00 2001 From: Rohit Sharma Date: Thu, 30 Apr 2026 16:34:21 +0100 Subject: [PATCH] feat(spark): unify Databricks connector auth with PAT, OAuth M2M, and OIDC federation strategies Databricks has marked Personal Access Tokens (PATs) as a legacy auth method (https://docs.databricks.com/aws/en/dev-tools/auth/pat) and is steering customers toward OAuth machine-to-machine (M2M) and OIDC workload-identity federation. This change brings both modern auth modes to the Flyte Databricks connector and refactors the existing PAT support into a shared strategy module so all four modes resolve identically. What is added: * OAuth M2M (client credentials) using a per-namespace 'databricks-oauth' K8s secret with operator-level fallbacks via env vars. * OIDC federation, Model 1: the connector pod's own projected JWT (e.g. EKS IRSA) is exchanged for a Databricks bearer token. * OIDC federation, Model 2: per-workflow-namespace ServiceAccount discovery driven by labels and annotations on the SA. The connector mints a JWT via the Kubernetes TokenRequest API and exchanges it for a Databricks token. This preserves the existing per-namespace tenancy model that PAT customers rely on for Unity Catalog access. * A unified DatabricksAuth strategy abstraction in 'databricks_auth.py' with auto-detection, per-strategy token caching, and token refresh on 401 responses for long-running jobs. What changes for existing PAT users: * This PR refactors the PAT support that was added in flyteorg/flytekit#3394 from a direct function call into a 'PATAuth' strategy that lives alongside the new modes. The behaviour, env vars, and per-namespace 'databricks-token' lookup are preserved end-to-end. Reviewers may want to read 'connector.py' and the new tests with this refactor in mind: PAT now flows through the same 'select_auth' resolver as the new modes so we have one extension point instead of two. * Workflow code is unchanged. 
'DatabricksV2' gains optional override fields for power users, but existing tasks keep working without edits. Validation: * 'pytest plugins/flytekit-spark/tests/test_databricks_auth.py plugins/flytekit-spark/tests/test_databricks_token.py plugins/flytekit-spark/tests/test_connector.py' passes (100 tests). * End-to-end tested on an EKS test cluster against a real Databricks workspace for PAT, OAuth M2M, and OIDC Model 2. * Pre-commit (ruff, ruff-format, codespell, pydoclint) clean on the changed plugin files. Tracking: flyteorg/flyte#7319 Related: * flyteorg/flytekit#3394 (PAT multi-tenancy, refactored here) * flyteorg/flytekit#3392 (Databricks Serverless compute) * flyteorg/flyte#6911 (original PAT multi-tenancy issue) Signed-off-by: Rohit Sharma --- plugins/flytekit-spark/README.md | 223 +++++ .../flytekitplugins/spark/connector.py | 238 ++++- .../flytekitplugins/spark/databricks_auth.py | 881 ++++++++++++++++++ .../flytekitplugins/spark/task.py | 93 ++ .../flytekit-spark/tests/test_connector.py | 12 + .../tests/test_databricks_auth.py | 858 +++++++++++++++++ .../tests/test_databricks_token.py | 25 +- 7 files changed, 2297 insertions(+), 33 deletions(-) create mode 100644 plugins/flytekit-spark/flytekitplugins/spark/databricks_auth.py create mode 100644 plugins/flytekit-spark/tests/test_databricks_auth.py diff --git a/plugins/flytekit-spark/README.md b/plugins/flytekit-spark/README.md index 9cc7c7cf9d..df658008bc 100644 --- a/plugins/flytekit-spark/README.md +++ b/plugins/flytekit-spark/README.md @@ -11,3 +11,226 @@ pip install flytekitplugins-spark To configure Spark in the Flyte deployment's backend, follow [Step 1](https://docs.flyte.org/en/latest/deployment/plugins/k8s/index.html#deployment-plugin-setup-k8s), [2](https://docs.flyte.org/en/latest/flytesnacks/examples/k8s_spark_plugin/index.html). 
All three produce an `Authorization: Bearer <access-token>` header for the Jobs API.
+ +For OIDC federation specifically, the choice between Model 2 (per-workflow-namespace SA) and Model 1 (connector pod identity) is made automatically at submit time based on whether an annotated ServiceAccount is found in the workflow's namespace. See the OIDC Federation section below. + +### Environment variables (connector defaults) + +| Variable | Purpose | +|-------------------------------------------------|---------------------------------------------------------| +| `FLYTE_DATABRICKS_INSTANCE` | Default Databricks workspace hostname | +| `FLYTE_DATABRICKS_AUTH_TYPE` | `pat` / `oauth_m2m` / `oidc_federation` (optional) | +| `FLYTE_DATABRICKS_ACCESS_TOKEN` | Fallback PAT if no namespace secret is found | +| `FLYTE_DATABRICKS_TOKEN_SECRET_NAME` | k8s secret name holding the PAT (default `databricks-token`) | +| `FLYTE_DATABRICKS_OAUTH_SECRET_NAME` | k8s secret with `client_id` / `client_secret` keys (default `databricks-oauth`) | +| `DATABRICKS_CLIENT_ID` | Service Principal client ID for M2M / OIDC | +| `DATABRICKS_CLIENT_SECRET` | Service Principal client secret for M2M | +| `AWS_WEB_IDENTITY_TOKEN_FILE` | IRSA-injected subject JWT (for OIDC Model 1) | +| `FLYTE_DATABRICKS_OIDC_TOKEN_FILE` | Override path to subject JWT (for OIDC Model 1) | +| `FLYTE_DATABRICKS_OIDC_AUDIENCE` | OIDC audience (default `databricks`); per-namespace override via SA annotation | +| `FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER` | Provider name for serverless S3 credentials | + +### PAT (Personal Access Token) + +Default and simplest mode. The connector looks up a k8s secret named `databricks-token` in the workflow namespace, and falls back to `FLYTE_DATABRICKS_ACCESS_TOKEN` on the connector pod. 
Exchanges a `client_id` + `client_secret` for a short-lived Databricks OAuth token via the `client_credentials` grant on `https://<databricks-instance>/oidc/v1/token`.
apiVersion: v1
kind: ServiceAccount
metadata:
  name: dbx-runner   # any name - connector discovers it
  namespace: <workflow-namespace>
  labels:
    flyte.org/databricks-enabled: "true"
  annotations:
    flyte.org/databricks-client-id: "<service-principal-client-id>"
    flyte.org/databricks-audience: "databricks"  # optional; defaults to "databricks"
issuer: <your-cluster-oidc-issuer-url>
subject: system:serviceaccount:<workflow-namespace>:<serviceaccount-name>
audience: databricks
2. Configure a Databricks Federation Policy for that Service Principal pointing at your OIDC issuer (e.g. the EKS cluster OIDC URL) with `subject` claim matching the connector pod's ServiceAccount: `system:serviceaccount:<connector-namespace>:<connector-serviceaccount>`.
2. For each SP, configure a Federation Policy whose `subject` is `system:serviceaccount:<workflow-namespace>:<serviceaccount-name>` and `audience` is `databricks`.
The ``auth_type`` and + companion fields let ``get``/``delete`` rebuild a :class:`DatabricksAuth` + and refresh short-lived OAuth tokens for long-running jobs. + """ + databricks_instance: str run_id: str - auth_token: Optional[str] = None # Store auth token for get/delete operations + auth_token: Optional[str] = None + auth_type: Optional[str] = None + client_id: Optional[str] = None + oauth_secret_name: Optional[str] = None + token_secret_name: Optional[str] = None + oidc_token_file: Optional[str] = None + oidc_service_account: Optional[str] = None + oidc_audience: Optional[str] = None + namespace: Optional[str] = None def _configure_serverless(databricks_job: dict, envs: dict) -> str: @@ -254,6 +270,14 @@ class DatabricksConnector(AsyncConnectorBase): def __init__(self): super().__init__(task_type_name="spark", metadata_type=DatabricksJobMetadata) + # Lazy import to avoid a top-level cycle with databricks_auth -> connector. + from .databricks_auth import validate_connector_config + + try: + validate_connector_config() + except Exception as e: # pragma: no cover - validation re-raises cleanly + logger.error("Databricks connector startup validation failed: %s", e) + raise async def create( self, @@ -262,6 +286,8 @@ async def create( task_execution_metadata: Optional[TaskExecutionMetadata] = None, **kwargs, ) -> DatabricksJobMetadata: + from .databricks_auth import select_auth + data = json.dumps(_get_databricks_job_spec(task_template)) databricks_instance = task_template.custom.get( "databricksInstance", os.getenv(DEFAULT_DATABRICKS_INSTANCE_ENV_KEY) @@ -272,30 +298,49 @@ async def create( f"Missing databricks instance. Please set the value through the task config or set the {DEFAULT_DATABRICKS_INSTANCE_ENV_KEY} environment variable in the connector." 
) - # Get workflow-specific token or fall back to default namespace = task_execution_metadata.namespace if task_execution_metadata else None - # Extract custom secret name from task template (if provided) - custom_secret_name = task_template.custom.get("databricksTokenSecret") - - logger.info(f"Creating Databricks job for namespace: {namespace or 'unknown'}") - if custom_secret_name: - logger.info(f"Using custom secret name: {custom_secret_name}") + auth = await select_auth(task_template=task_template, workspace_url=databricks_instance, namespace=namespace) + logger.info("Databricks auth resolved: %s", auth.describe()) - auth_token = get_databricks_token( - namespace=namespace, task_template=task_template, secret_name=custom_secret_name - ) databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/submit" async with aiohttp.ClientSession() as session: + auth_token = await auth.get_bearer_token(session) async with session.post(databricks_url, headers=get_header(auth_token=auth_token), data=data) as resp: response = await resp.json() if resp.status != http.HTTPStatus.OK: raise RuntimeError(f"Failed to create databricks job with error: {response}") logger.info(f"Successfully created Databricks job with run_id: {response['run_id']}") + # Stash enough context on the metadata to rebuild auth during get/delete. + # For PAT we also keep the resolved token as a fast-path and as a safety net for + # old connectors that might still be polling jobs created by a newer one. + stored_token: Optional[str] = auth_token if auth.auth_type == "pat" else None + + # OIDC Model 2 captures the discovered SA + per-namespace client_id; other + # strategies persist whatever was on settings. 
+ client_id_to_persist = auth.settings.client_id + oidc_audience_to_persist = auth.settings.oidc_audience + oidc_sa_to_persist: Optional[str] = None + discovered = getattr(auth, "discovered", None) + if discovered is not None: + client_id_to_persist = discovered.client_id + oidc_audience_to_persist = discovered.audience + oidc_sa_to_persist = discovered.service_account + return DatabricksJobMetadata( - databricks_instance=databricks_instance, run_id=str(response["run_id"]), auth_token=auth_token + databricks_instance=databricks_instance, + run_id=str(response["run_id"]), + auth_token=stored_token, + auth_type=auth.auth_type, + client_id=client_id_to_persist, + oauth_secret_name=auth.settings.oauth_secret_name, + token_secret_name=auth.settings.token_secret_name, + oidc_token_file=auth.settings.oidc_token_file, + oidc_service_account=oidc_sa_to_persist, + oidc_audience=oidc_audience_to_persist, + namespace=namespace, ) async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource: @@ -304,14 +349,14 @@ async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource: f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}" ) - # Use the stored auth token if available, otherwise fall back to default - headers = get_header(auth_token=resource_meta.auth_token) - async with aiohttp.ClientSession() as session: - async with session.get(databricks_url, headers=headers) as resp: - if resp.status != http.HTTPStatus.OK: - raise RuntimeError(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}") - response = await resp.json() + response = await self._request_with_auth( + session=session, + method="GET", + url=databricks_url, + resource_meta=resource_meta, + action_label=f"get databricks job {resource_meta.run_id}", + ) cur_phase = TaskExecution.UNDEFINED message = "" @@ -339,16 +384,86 @@ async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs): databricks_url = 
f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel" data = json.dumps({"run_id": resource_meta.run_id}) - # Use the stored auth token if available, otherwise fall back to default - headers = get_header(auth_token=resource_meta.auth_token) - async with aiohttp.ClientSession() as session: - async with session.post(databricks_url, headers=headers, data=data) as resp: - if resp.status != http.HTTPStatus.OK: - raise RuntimeError( - f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}" - ) - await resp.json() + await self._request_with_auth( + session=session, + method="POST", + url=databricks_url, + resource_meta=resource_meta, + data=data, + action_label=f"cancel databricks job {resource_meta.run_id}", + ) + + async def _request_with_auth( + self, + session: "aiohttp.ClientSession", # type: ignore[name-defined] + method: str, + url: str, + resource_meta: DatabricksJobMetadata, + action_label: str, + data: Optional[str] = None, + ) -> dict: + """Issue a single Databricks Jobs API call, refreshing auth on 401 when possible. + + Jobs created by an older connector will have ``auth_type=None`` in their metadata; + in that case we fall back to the stored ``auth_token`` with no refresh attempt. + For OAuth/OIDC-backed jobs we rebuild the :class:`DatabricksAuth` strategy from the + metadata, invalidate the cached token on 401, and retry exactly once. 
+ """ + from .databricks_auth import DatabricksAuthError, build_auth + + auth = None + if resource_meta.auth_type is not None: + try: + auth = build_auth( + workspace_url=resource_meta.databricks_instance, + auth_type=resource_meta.auth_type, + namespace=resource_meta.namespace, + client_id=resource_meta.client_id, + oauth_secret_name=resource_meta.oauth_secret_name, + token_secret_name=resource_meta.token_secret_name, + oidc_token_file=resource_meta.oidc_token_file, + oidc_service_account=resource_meta.oidc_service_account, + oidc_audience=resource_meta.oidc_audience, + ) + except DatabricksAuthError as e: + logger.warning("Failed to rebuild auth from metadata (%s); falling back to stored token", e) + + token: Optional[str] = None + if auth is not None: + try: + token = await auth.get_bearer_token(session) + except DatabricksAuthError as e: + if resource_meta.auth_token is None: + raise RuntimeError(f"Failed to {action_label}: could not refresh Databricks auth: {e}") + logger.warning("Token refresh failed (%s); falling back to stored auth_token", e) + if token is None: + token = resource_meta.auth_token + + def _request(bearer: Optional[str]): + headers = get_header(auth_token=bearer) + if method.upper() == "GET": + return session.get(url, headers=headers) + return session.post(url, headers=headers, data=data) + + async with _request(token) as resp: + if resp.status == http.HTTPStatus.UNAUTHORIZED and auth is not None: + logger.info( + "Databricks API returned 401 for %s; invalidating token cache and retrying once", + action_label, + ) + await auth.invalidate_cache() + try: + token = await auth.get_bearer_token(session) + except DatabricksAuthError as e: + raise RuntimeError(f"Failed to {action_label}: auth refresh failed after 401: {e}") + async with _request(token) as resp2: + if resp2.status != http.HTTPStatus.OK: + raise RuntimeError(f"Failed to {action_label} with error: {resp2.reason}") + return await resp2.json() + if resp.status != http.HTTPStatus.OK: + 
raise RuntimeError(f"Failed to {action_label} with error: {resp.reason}") + return await resp.json() class DatabricksConnectorV2(DatabricksConnector): @@ -364,6 +479,73 @@ def __init__(self): super(DatabricksConnector, self).__init__(task_type_name="databricks", metadata_type=DatabricksJobMetadata) +def list_serviceaccounts_in_k8s(namespace: str, label_selector: Optional[str] = None) -> list: + """List ServiceAccounts in a Kubernetes namespace, optionally filtered by label. + + Used by the OIDC federation auto-discovery path to find a workflow-namespace SA + that's annotated as the Databricks federation identity. + + The connector's ServiceAccount needs ``list`` and ``get`` on ``serviceaccounts`` + cluster-wide (or scoped per-namespace via RoleBindings) for this to succeed. RBAC + is documented in the plugin README's "Authentication" section. + + Args: + namespace (str): The workflow's Kubernetes namespace. + label_selector (Optional[str]): Optional label selector, e.g. + ``"flyte.org/databricks-enabled=true"``. + + Returns: + list: List of ``V1ServiceAccount`` objects (each carries ``.metadata.name`` and + ``.metadata.annotations``). Empty list if none match or the K8s API is + unreachable / forbidden (call sites treat that as "not discovered" and fall + back to OIDC Model 1). + """ + try: + from kubernetes import client, config + + try: + config.load_incluster_config() + except config.ConfigException: + try: + config.load_kube_config() + except Exception as e: + logger.warning(f"Failed to load Kubernetes config: {e}") + return [] + + v1 = client.CoreV1Api() + try: + kwargs = {"namespace": namespace} + if label_selector: + kwargs["label_selector"] = label_selector + resp = v1.list_namespaced_service_account(**kwargs) + return list(resp.items or []) + except client.exceptions.ApiException as e: + if e.status == 403: + logger.warning( + "Forbidden listing ServiceAccounts in namespace '%s' (label_selector=%r). 
" + "The connector's ServiceAccount needs 'list' on 'serviceaccounts' for OIDC " + "federation auto-discovery. See the plugin README's Authentication section.", + namespace, + label_selector, + ) + elif e.status == 404: + logger.debug("Namespace '%s' not found while listing ServiceAccounts", namespace) + else: + logger.warning( + "Error listing ServiceAccounts in namespace '%s' (label_selector=%r): %s", + namespace, + label_selector, + e, + ) + return [] + except ImportError: + logger.warning("kubernetes Python package not installed - cannot list ServiceAccounts") + return [] + except Exception as e: + logger.warning(f"Unexpected error listing ServiceAccounts: {e}") + return [] + + def get_secret_from_k8s(secret_name: str, secret_key: str, namespace: str) -> Optional[str]: """Read a secret from Kubernetes using the Kubernetes Python client. diff --git a/plugins/flytekit-spark/flytekitplugins/spark/databricks_auth.py b/plugins/flytekit-spark/flytekitplugins/spark/databricks_auth.py new file mode 100644 index 0000000000..366895922a --- /dev/null +++ b/plugins/flytekit-spark/flytekitplugins/spark/databricks_auth.py @@ -0,0 +1,881 @@ +"""Databricks authentication strategies for the flytekit-spark connector. + +This module centralises Databricks auth behind a small strategy abstraction. Three +user-visible auth types are supported, all of which end up producing an +``Authorization: Bearer `` header for the Databricks Jobs API: + +- ``pat`` - Personal Access Token (legacy, long-lived). Unchanged from the + upstream multi-tenant PAT flow: the token is read from a cross-namespace k8s + secret and/or the connector's ``FLYTE_DATABRICKS_ACCESS_TOKEN`` env var. +- ``oauth_m2m`` - OAuth Service Principal, ``client_credentials`` grant. Reads + ``client_id`` / ``client_secret`` from a k8s secret in the workflow namespace + (default name ``databricks-oauth``) or from connector env vars. +- ``oidc_federation`` - OAuth Workload Identity Federation, ``token-exchange`` + grant. 
Two flavours, dispatched automatically: + + - **Model 2** (per-workflow-namespace identity): if the workflow namespace + contains a ``ServiceAccount`` labelled ``flyte.org/databricks-enabled=true`` + and annotated with ``flyte.org/databricks-client-id``, the connector mints + a fresh JWT for that SA via Kubernetes ``TokenRequest`` and exchanges it + for a token issued to the SP named in the annotation. Each namespace can + federate to a different Databricks Service Principal. + - **Model 1** (single connector identity): fallback when no annotated SA is + discovered. Uses the connector pod's own projected OIDC JWT + (``AWS_WEB_IDENTITY_TOKEN_FILE`` from EKS IRSA, or an equivalent path) + and the connector-level ``DATABRICKS_CLIENT_ID``. + +Every setting follows a consistent resolution order:: + + task config field -> FLYTE_DATABRICKS_* env var -> well-known default +""" + +import asyncio +import json as _json +import logging +import os +import random +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +from flytekit import lazy_module +from flytekit.models.task import TaskTemplate + +aiohttp = lazy_module("aiohttp") + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- # +# Env-var names (public so operators can grep for them) # +# --------------------------------------------------------------------------- # +FLYTE_DATABRICKS_AUTH_TYPE_ENV = "FLYTE_DATABRICKS_AUTH_TYPE" +FLYTE_DATABRICKS_OAUTH_SECRET_NAME_ENV = "FLYTE_DATABRICKS_OAUTH_SECRET_NAME" +FLYTE_DATABRICKS_OIDC_TOKEN_FILE_ENV = "FLYTE_DATABRICKS_OIDC_TOKEN_FILE" +FLYTE_DATABRICKS_OIDC_AUDIENCE_ENV = "FLYTE_DATABRICKS_OIDC_AUDIENCE" +FLYTE_DATABRICKS_TOKEN_SECRET_NAME_ENV = "FLYTE_DATABRICKS_TOKEN_SECRET_NAME" +DATABRICKS_CLIENT_ID_ENV = "DATABRICKS_CLIENT_ID" +DATABRICKS_CLIENT_SECRET_ENV = "DATABRICKS_CLIENT_SECRET" +AWS_WEB_IDENTITY_TOKEN_FILE_ENV = 
"AWS_WEB_IDENTITY_TOKEN_FILE" + +# --------------------------------------------------------------------------- # +# OIDC Model 2 discovery: SA annotation/label keys (community convention) # +# --------------------------------------------------------------------------- # +# A ServiceAccount in the workflow namespace becomes a Model 2 federation +# identity by carrying: +# labels: flyte.org/databricks-enabled: "true" +# annotations: flyte.org/databricks-client-id: "" +# flyte.org/databricks-audience: "databricks" # optional +LABEL_DATABRICKS_ENABLED = "flyte.org/databricks-enabled" +ANNOTATION_DATABRICKS_CLIENT_ID = "flyte.org/databricks-client-id" +ANNOTATION_DATABRICKS_AUDIENCE = "flyte.org/databricks-audience" +DATABRICKS_ENABLED_LABEL_SELECTOR = f"{LABEL_DATABRICKS_ENABLED}=true" + +# --------------------------------------------------------------------------- # +# Defaults # +# --------------------------------------------------------------------------- # +DEFAULT_OAUTH_SECRET_NAME = "databricks-oauth" +DEFAULT_TOKEN_SECRET_NAME = "databricks-token" +DEFAULT_OIDC_AUDIENCE = "databricks" +DEFAULT_PROJECTED_SA_TOKEN_PATH = "/var/run/secrets/databricks/token" + +TOKEN_REFRESH_BUFFER_SECONDS = 60 +TOKEN_ENDPOINT_MAX_RETRIES = 3 +TOKEN_ENDPOINT_BACKOFF_BASE_SECONDS = 0.2 + +# Per-namespace SA discovery cache: avoid listing SAs on every create() call. +NS_DISCOVERY_CACHE_TTL_SECONDS = 300 + +VALID_AUTH_TYPES = {"pat", "oauth_m2m", "oidc_federation"} + + +class DatabricksAuthError(Exception): + """Raised when Databricks authentication cannot be obtained.""" + + +# =========================================================================== # +# Settings # +# =========================================================================== # + + +@dataclass +class _Settings: + """Resolved auth settings for a single task. + + Resolution for each field is ``task cfg`` -> ``connector env`` -> ``default``. 
+ + For OIDC Model 2 the workflow-namespace SA name and per-namespace ``client_id`` + do not appear here - they are auto-discovered at submit time from labelled SA + annotations in the workflow namespace (see :class:`_DiscoveredOIDCConfig`). + """ + + auth_type: Optional[str] + client_id: Optional[str] + oauth_secret_name: str + token_secret_name: str + oidc_token_file: Optional[str] + oidc_audience: str + namespace: Optional[str] + + @staticmethod + def from_task(task_template: Optional[TaskTemplate], namespace: Optional[str]) -> "_Settings": + custom: Dict[str, Any] = task_template.custom if task_template is not None else {} + + def _pick(task_key: str, env_key: Optional[str], default: Optional[str] = None) -> Optional[str]: + v = custom.get(task_key) + if v: + return v + if env_key: + env_v = os.getenv(env_key) + if env_v: + return env_v + return default + + return _Settings( + auth_type=_pick("databricksAuthType", FLYTE_DATABRICKS_AUTH_TYPE_ENV), + client_id=_pick("databricksClientId", DATABRICKS_CLIENT_ID_ENV), + oauth_secret_name=_pick( + "databricksOauthSecret", FLYTE_DATABRICKS_OAUTH_SECRET_NAME_ENV, DEFAULT_OAUTH_SECRET_NAME + ) + or DEFAULT_OAUTH_SECRET_NAME, + token_secret_name=_pick( + "databricksTokenSecret", FLYTE_DATABRICKS_TOKEN_SECRET_NAME_ENV, DEFAULT_TOKEN_SECRET_NAME + ) + or DEFAULT_TOKEN_SECRET_NAME, + oidc_token_file=_pick("databricksOidcTokenFile", FLYTE_DATABRICKS_OIDC_TOKEN_FILE_ENV), + oidc_audience=_pick("databricksOidcAudience", FLYTE_DATABRICKS_OIDC_AUDIENCE_ENV, DEFAULT_OIDC_AUDIENCE) + or DEFAULT_OIDC_AUDIENCE, + namespace=namespace, + ) + + +@dataclass +class _DiscoveredOIDCConfig: + """A federation identity discovered from an annotated workflow-namespace SA.""" + + service_account: str + client_id: str + audience: str + + +def _resolve_oidc_token_file(settings: _Settings) -> Optional[str]: + """Return the first subject-JWT file path that exists, or ``None``. 
+ + Order: explicit cfg/env override -> IRSA-injected path -> well-known projected SA path. + """ + if settings.oidc_token_file and os.path.exists(settings.oidc_token_file): + return settings.oidc_token_file + irsa = os.getenv(AWS_WEB_IDENTITY_TOKEN_FILE_ENV) + if irsa and os.path.exists(irsa): + return irsa + if os.path.exists(DEFAULT_PROJECTED_SA_TOKEN_PATH): + return DEFAULT_PROJECTED_SA_TOKEN_PATH + return None + + +# =========================================================================== # +# Token cache # +# =========================================================================== # + + +@dataclass +class _CachedToken: + access_token: str + expires_at_unix: float + + +class _TokenCache: + """Async-safe in-memory cache of Databricks bearer tokens. + + Keys are ``(workspace_url, client_id, subject_identity)``. Entries expire + ``TOKEN_REFRESH_BUFFER_SECONDS`` early so callers never receive a token + that is about to expire. + """ + + def __init__(self) -> None: + self._store: Dict[Tuple[str, str, str], _CachedToken] = {} + self._lock = asyncio.Lock() + + @staticmethod + def _now() -> float: + return time.time() + + async def get(self, key: Tuple[str, str, str]) -> Optional[str]: + async with self._lock: + entry = self._store.get(key) + if entry is None: + return None + if entry.expires_at_unix - self._now() < TOKEN_REFRESH_BUFFER_SECONDS: + self._store.pop(key, None) + return None + return entry.access_token + + async def put(self, key: Tuple[str, str, str], access_token: str, expires_in: int) -> None: + async with self._lock: + self._store[key] = _CachedToken( + access_token=access_token, + expires_at_unix=self._now() + max(int(expires_in) - TOKEN_REFRESH_BUFFER_SECONDS, 30), + ) + + async def invalidate(self, key: Tuple[str, str, str]) -> None: + async with self._lock: + self._store.pop(key, None) + + +# One cache per connector pod. 
+_TOKEN_CACHE = _TokenCache() + + +# =========================================================================== # +# Per-namespace OIDC discovery cache # +# =========================================================================== # + + +@dataclass +class _CachedDiscovery: + config: Optional[_DiscoveredOIDCConfig] # ``None`` is a valid cached "nothing here" answer + expires_at_unix: float + + +class _NSDiscoveryCache: + """TTL-bounded cache of per-namespace OIDC SA discovery results. + + Caches both hits and misses (a namespace with no annotated SA stays a miss for the + TTL duration). Operators who change SA annotations and want immediate effect can + restart the connector pod. + """ + + def __init__(self) -> None: + self._store: Dict[str, _CachedDiscovery] = {} + self._lock = asyncio.Lock() + + @staticmethod + def _now() -> float: + return time.time() + + async def get(self, namespace: str) -> Tuple[bool, Optional[_DiscoveredOIDCConfig]]: + """Return ``(is_cached, config_or_None)``.""" + async with self._lock: + entry = self._store.get(namespace) + if entry is None or entry.expires_at_unix <= self._now(): + self._store.pop(namespace, None) + return False, None + return True, entry.config + + async def put(self, namespace: str, config: Optional[_DiscoveredOIDCConfig]) -> None: + async with self._lock: + self._store[namespace] = _CachedDiscovery( + config=config, + expires_at_unix=self._now() + NS_DISCOVERY_CACHE_TTL_SECONDS, + ) + + async def invalidate(self, namespace: str) -> None: + async with self._lock: + self._store.pop(namespace, None) + + +_NS_DISCOVERY_CACHE = _NSDiscoveryCache() + + +# =========================================================================== # +# Low-level helpers: Databricks token endpoint + Kubernetes TokenRequest # +# =========================================================================== # + + +async def _post_oidc_token( + session: "aiohttp.ClientSession", # type: ignore[name-defined] + workspace_url: str, + form: Dict[str, 
str], +) -> Dict[str, Any]: + """POST to ``https:///oidc/v1/token`` with retries. + + Retries 429/500/502/503/504 and transient network errors with exponential + backoff plus jitter. Fails fast on 400/401/403/404 and surfaces the + server-provided error body (minus any secrets, which the body never + contains for this endpoint). + """ + url = f"https://{workspace_url.rstrip('/')}/oidc/v1/token" + last_err: Optional[str] = None + for attempt in range(TOKEN_ENDPOINT_MAX_RETRIES): + try: + async with session.post( + url, + data=form, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) as resp: + body_text = await resp.text() + if resp.status == 200: + try: + return _json.loads(body_text) + except Exception as e: + raise DatabricksAuthError(f"Databricks token endpoint returned 200 but body was not JSON: {e}") + if resp.status in (429, 500, 502, 503, 504): + last_err = f"HTTP {resp.status}: {body_text[:500]}" + else: + raise DatabricksAuthError( + f"Databricks token endpoint returned HTTP {resp.status}: {body_text[:500]}" + ) + except aiohttp.ClientError as e: # type: ignore[attr-defined] + last_err = f"network error: {e}" + except asyncio.TimeoutError: + last_err = "timeout" + if attempt < TOKEN_ENDPOINT_MAX_RETRIES - 1: + await asyncio.sleep(TOKEN_ENDPOINT_BACKOFF_BASE_SECONDS * (2**attempt) + random.uniform(0, 0.1)) + raise DatabricksAuthError( + f"Databricks token endpoint failed after {TOKEN_ENDPOINT_MAX_RETRIES} attempts: {last_err}" + ) + + +def _discover_namespace_oidc_sa_sync(namespace: str, default_audience: str) -> Optional[_DiscoveredOIDCConfig]: + """Look for a labelled+annotated ServiceAccount in ``namespace`` (synchronous). + + Returns the discovered config when exactly one SA in the namespace carries + ``flyte.org/databricks-enabled=true`` and a non-empty + ``flyte.org/databricks-client-id`` annotation. Returns ``None`` when zero + matches are found (caller decides whether to fall back to Model 1). 
Raises + :class:`DatabricksAuthError` when the namespace contains more than one + candidate SA - silent "first wins" would mask configuration mistakes. + """ + from .connector import list_serviceaccounts_in_k8s # lazy: avoid import cycle + + sas = list_serviceaccounts_in_k8s(namespace=namespace, label_selector=DATABRICKS_ENABLED_LABEL_SELECTOR) + + matches: list = [] + for sa in sas: + annotations = (sa.metadata.annotations or {}) if sa.metadata is not None else {} + client_id = (annotations.get(ANNOTATION_DATABRICKS_CLIENT_ID) or "").strip() + if not client_id: + continue + audience = (annotations.get(ANNOTATION_DATABRICKS_AUDIENCE) or "").strip() or default_audience + matches.append( + _DiscoveredOIDCConfig( + service_account=sa.metadata.name, + client_id=client_id, + audience=audience, + ) + ) + + if not matches: + return None + if len(matches) > 1: + names = ", ".join(sorted(m.service_account for m in matches)) + raise DatabricksAuthError( + f"Ambiguous OIDC federation configuration in namespace '{namespace}': " + f"multiple ServiceAccounts carry both the '{LABEL_DATABRICKS_ENABLED}=true' label " + f"and a '{ANNOTATION_DATABRICKS_CLIENT_ID}' annotation: [{names}]. " + "Annotate exactly one SA per namespace." + ) + return matches[0] + + +async def _discover_namespace_oidc_sa(namespace: str, default_audience: str) -> Optional[_DiscoveredOIDCConfig]: + """TTL-cached async wrapper around :func:`_discover_namespace_oidc_sa_sync`.""" + cached, value = await _NS_DISCOVERY_CACHE.get(namespace) + if cached: + return value + loop = asyncio.get_event_loop() + discovered = await loop.run_in_executor(None, _discover_namespace_oidc_sa_sync, namespace, default_audience) + await _NS_DISCOVERY_CACHE.put(namespace, discovered) + return discovered + + +def _request_sa_token(namespace: str, service_account: str, audience: str, expiration_seconds: int = 3600) -> str: + """Mint a JWT for ``namespace/service_account`` via the Kubernetes TokenRequest API. 
+ + Required RBAC on the connector's ServiceAccount (sample in the plugin README):: + + apiGroups: [""] + resources: ["serviceaccounts/token"] + verbs: ["create"] + """ + try: + from kubernetes import client, config + from kubernetes.client.exceptions import ApiException + except ImportError as e: + raise DatabricksAuthError( + f"kubernetes python client required for OIDC Model 2 (annotated workflow-namespace SA) but not installed: {e}" + ) + try: + config.load_incluster_config() + except Exception: + try: + config.load_kube_config() + except Exception as e: + raise DatabricksAuthError(f"unable to load Kubernetes config for TokenRequest: {e}") + v1 = client.CoreV1Api() + body = client.AuthenticationV1TokenRequest( + spec=client.V1TokenRequestSpec(audiences=[audience], expiration_seconds=expiration_seconds) + ) + try: + resp = v1.create_namespaced_service_account_token(name=service_account, namespace=namespace, body=body) + return resp.status.token + except ApiException as e: + if e.status == 404: + raise DatabricksAuthError( + f"ServiceAccount '{namespace}/{service_account}' not found. " + "The discovery cache is stale - the SA was annotated but has since been deleted. " + "Re-create the SA or remove the 'flyte.org/databricks-enabled' label." + ) + if e.status == 403: + raise DatabricksAuthError( + f"Forbidden creating a token for '{namespace}/{service_account}'. " + "The connector's ServiceAccount needs RBAC: verbs=['create'] on resources=['serviceaccounts/token']. " + "See the 'Authentication' section of the flytekit-spark README for a sample ClusterRole." 
+ ) + raise DatabricksAuthError(f"Kubernetes TokenRequest for '{namespace}/{service_account}' failed: {e}") + + +# =========================================================================== # +# Strategy interface + concrete strategies # +# =========================================================================== # + + +class DatabricksAuth(ABC): + """Abstract base class for Databricks auth strategies.""" + + auth_type: str = "unknown" + strategy_name: str = "DatabricksAuth" + + def __init__(self, workspace_url: str, settings: _Settings): + self.workspace_url = workspace_url + self.settings = settings + + @abstractmethod + async def get_bearer_token(self, session: "aiohttp.ClientSession") -> str: # type: ignore[name-defined] + ... + + @property + @abstractmethod + def cache_key(self) -> Tuple[str, str, str]: ... + + async def invalidate_cache(self) -> None: + await _TOKEN_CACHE.invalidate(self.cache_key) + + def describe(self) -> str: + """One-line human-readable description, safe to log (no secrets).""" + cid = self.settings.client_id or "N/A" + masked = f"{cid[:4]}...{cid[-4:]}" if len(cid) > 12 else cid + return ( + f"strategy={self.strategy_name} auth_type={self.auth_type} " + f"client_id={masked} audience={self.settings.oidc_audience} " + f"namespace={self.settings.namespace or 'N/A'}" + ) + + +class PATAuth(DatabricksAuth): + """Personal Access Token: delegates to the existing multi-tenant PAT flow.""" + + auth_type = "pat" + strategy_name = "PATAuth" + + @property + def cache_key(self) -> Tuple[str, str, str]: + return (self.workspace_url, "pat", self.settings.namespace or "_") + + async def get_bearer_token(self, session: "aiohttp.ClientSession") -> str: # type: ignore[name-defined] + # Lazy import to avoid a circular dependency with connector.py. 
+ from .connector import get_databricks_token + + return get_databricks_token( + namespace=self.settings.namespace, + secret_name=self.settings.token_secret_name, + ) + + async def invalidate_cache(self) -> None: + # PATs are long-lived; a 401 is a misconfig, not a stale-cache issue. + return + + +class OAuthM2MAuth(DatabricksAuth): + """OAuth M2M (``grant_type=client_credentials``).""" + + auth_type = "oauth_m2m" + strategy_name = "OAuthM2MAuth" + + @property + def cache_key(self) -> Tuple[str, str, str]: + return ( + self.workspace_url, + self.settings.client_id or "", + f"m2m:{self.settings.namespace or '_'}", + ) + + def _resolve_creds(self) -> Tuple[str, str]: + from .connector import get_secret_from_k8s # lazy import for the same reason as above + + client_id = self.settings.client_id + client_secret: Optional[str] = None + + if self.settings.namespace: + cid_from_secret = get_secret_from_k8s( + secret_name=self.settings.oauth_secret_name, + secret_key="client_id", + namespace=self.settings.namespace, + ) + csecret_from_secret = get_secret_from_k8s( + secret_name=self.settings.oauth_secret_name, + secret_key="client_secret", + namespace=self.settings.namespace, + ) + if cid_from_secret: + client_id = client_id or cid_from_secret + if csecret_from_secret: + client_secret = csecret_from_secret + + if client_id is None: + client_id = os.getenv(DATABRICKS_CLIENT_ID_ENV) + if client_secret is None: + client_secret = os.getenv(DATABRICKS_CLIENT_SECRET_ENV) + + if not client_id: + raise DatabricksAuthError( + "OAuth M2M selected but no client_id found. Set databricks_client_id on the task, " + f"{DATABRICKS_CLIENT_ID_ENV} on the connector, or add 'client_id' to the " + f"'{self.settings.oauth_secret_name}' k8s secret in the workflow namespace." + ) + if not client_secret: + raise DatabricksAuthError( + "OAuth M2M selected but no client_secret found. 
Set " + f"{DATABRICKS_CLIENT_SECRET_ENV} on the connector, or add 'client_secret' to the " + f"'{self.settings.oauth_secret_name}' k8s secret in the workflow namespace " + f"('{self.settings.namespace or 'N/A'}')." + ) + return client_id, client_secret + + async def get_bearer_token(self, session: "aiohttp.ClientSession") -> str: # type: ignore[name-defined] + cached = await _TOKEN_CACHE.get(self.cache_key) + if cached is not None: + return cached + client_id, client_secret = self._resolve_creds() + payload = await _post_oidc_token( + session, + self.workspace_url, + form={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + "scope": "all-apis", + }, + ) + access_token = payload.get("access_token") + expires_in = int(payload.get("expires_in", 3600)) + if not access_token: + raise DatabricksAuthError(f"Databricks M2M token response missing 'access_token': {payload}") + await _TOKEN_CACHE.put(self.cache_key, access_token, expires_in) + return access_token + + +class OIDCConnectorIRSAAuth(DatabricksAuth): + """OIDC Federation Model 1: exchanges the connector pod's own projected JWT.""" + + auth_type = "oidc_federation" + strategy_name = "OIDCConnectorIRSAAuth" + + @property + def cache_key(self) -> Tuple[str, str, str]: + return (self.workspace_url, self.settings.client_id or "", "irsa:connector") + + def _read_subject_jwt(self) -> str: + path = _resolve_oidc_token_file(self.settings) + if not path: + raise DatabricksAuthError( + "OIDC federation selected but no subject token file found. Looked in " + "databricks_oidc_token_file task cfg, FLYTE_DATABRICKS_OIDC_TOKEN_FILE env, " + "AWS_WEB_IDENTITY_TOKEN_FILE env, and " + f"{DEFAULT_PROJECTED_SA_TOKEN_PATH}. Configure IRSA on the connector pod or mount a " + "projected ServiceAccount token." 
+ ) + try: + with open(path, "r") as f: + return f.read().strip() + except OSError as e: + raise DatabricksAuthError(f"Failed to read OIDC subject token from {path}: {e}") + + async def get_bearer_token(self, session: "aiohttp.ClientSession") -> str: # type: ignore[name-defined] + cached = await _TOKEN_CACHE.get(self.cache_key) + if cached is not None: + return cached + if not self.settings.client_id: + raise DatabricksAuthError( + "OIDC federation selected but no client_id. Set databricks_client_id on the task " + f"or {DATABRICKS_CLIENT_ID_ENV} on the connector." + ) + # Re-read on every refresh: projected/IRSA tokens rotate. + subject_jwt = self._read_subject_jwt() + payload = await _post_oidc_token( + session, + self.workspace_url, + form={ + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "subject_token": subject_jwt, + "client_id": self.settings.client_id, + "scope": "all-apis", + }, + ) + access_token = payload.get("access_token") + expires_in = int(payload.get("expires_in", 3600)) + if not access_token: + raise DatabricksAuthError(f"Databricks OIDC token response missing 'access_token': {payload}") + await _TOKEN_CACHE.put(self.cache_key, access_token, expires_in) + return access_token + + +class OIDCNamespaceSAAuth(DatabricksAuth): + """OIDC Federation Model 2: mints a JWT for a workflow-namespace SA via TokenRequest. + + The (service_account, client_id, audience) tuple is supplied via :class:`_DiscoveredOIDCConfig` + rather than read from :class:`_Settings` - it is auto-discovered from the workflow + namespace's annotated SA in :func:`select_auth`, or rebuilt from persisted metadata in + :func:`build_auth`. 
+ """ + + auth_type = "oidc_federation" + strategy_name = "OIDCNamespaceSAAuth" + + def __init__(self, workspace_url: str, settings: _Settings, discovered: _DiscoveredOIDCConfig): + super().__init__(workspace_url, settings) + self.discovered = discovered + + @property + def cache_key(self) -> Tuple[str, str, str]: + ns = self.settings.namespace or "_" + return ( + self.workspace_url, + self.discovered.client_id, + f"sa:{ns}/{self.discovered.service_account}", + ) + + def describe(self) -> str: + cid = self.discovered.client_id + masked = f"{cid[:4]}...{cid[-4:]}" if len(cid) > 12 else cid + return ( + f"strategy={self.strategy_name} auth_type={self.auth_type} " + f"client_id={masked} audience={self.discovered.audience} " + f"namespace={self.settings.namespace or 'N/A'} " + f"service_account={self.discovered.service_account}" + ) + + async def get_bearer_token(self, session: "aiohttp.ClientSession") -> str: # type: ignore[name-defined] + cached = await _TOKEN_CACHE.get(self.cache_key) + if cached is not None: + return cached + if not self.settings.namespace: + raise DatabricksAuthError( + "OIDC Model 2 requires a workflow namespace, but task_execution_metadata.namespace is None." 
+ ) + loop = asyncio.get_event_loop() + subject_jwt: str = await loop.run_in_executor( + None, + _request_sa_token, + self.settings.namespace, + self.discovered.service_account, + self.discovered.audience, + ) + payload = await _post_oidc_token( + session, + self.workspace_url, + form={ + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "subject_token": subject_jwt, + "client_id": self.discovered.client_id, + "scope": "all-apis", + }, + ) + access_token = payload.get("access_token") + expires_in = int(payload.get("expires_in", 3600)) + if not access_token: + raise DatabricksAuthError(f"Databricks OIDC token response missing 'access_token': {payload}") + await _TOKEN_CACHE.put(self.cache_key, access_token, expires_in) + return access_token + + async def invalidate_cache(self) -> None: + # Wipe both the bearer cache and the SA-discovery cache so the next call re-reads + # the SA annotation in case the operator just rotated the federation policy. + await _TOKEN_CACHE.invalidate(self.cache_key) + if self.settings.namespace: + await _NS_DISCOVERY_CACHE.invalidate(self.settings.namespace) + + +# =========================================================================== # +# Selection # +# =========================================================================== # + + +async def select_auth( + task_template: Optional[TaskTemplate], + workspace_url: str, + namespace: Optional[str] = None, +) -> DatabricksAuth: + """Build the :class:`DatabricksAuth` strategy for a single task. + + Resolution order: + + 1. Explicit ``databricks_auth_type`` from task config, or + ``FLYTE_DATABRICKS_AUTH_TYPE`` env var on the connector. + 2. If unset, auto-detect in the order OIDC federation -> OAuth M2M -> PAT, + based on which credentials / tokens are reachable. PAT is the last-resort + default so existing deployments keep working with zero config. 
+ + When the resolved type is ``oidc_federation``: + + - Look in the workflow namespace for a ServiceAccount labelled + ``flyte.org/databricks-enabled=true`` and annotated with + ``flyte.org/databricks-client-id``. If exactly one is found, return + Model 2 (:class:`OIDCNamespaceSAAuth`) federated as that SA. + - If none are found, fall back to Model 1 (:class:`OIDCConnectorIRSAAuth`) + using the connector pod's own IRSA JWT - only when the connector has both + a reachable subject token file and a ``DATABRICKS_CLIENT_ID``. Otherwise + fail with :class:`DatabricksAuthError` so the misconfiguration is + surfaced rather than silently downgraded. + - If multiple annotated SAs exist in the namespace, fail loudly. + """ + settings = _Settings.from_task(task_template, namespace) + + auth_type = settings.auth_type or _auto_detect(settings) + + if auth_type not in VALID_AUTH_TYPES: + raise DatabricksAuthError( + f"Invalid databricks auth_type '{auth_type}'. Expected one of: {sorted(VALID_AUTH_TYPES)}." + ) + + if auth_type == "pat": + return PATAuth(workspace_url, settings) + if auth_type == "oauth_m2m": + return OAuthM2MAuth(workspace_url, settings) + + # auth_type == "oidc_federation": annotation-driven discovery picks Model 1 vs Model 2. + discovered: Optional[_DiscoveredOIDCConfig] = None + if namespace: + discovered = await _discover_namespace_oidc_sa(namespace, settings.oidc_audience) + if discovered is not None: + return OIDCNamespaceSAAuth(workspace_url, settings, discovered) + + # Fall back to Model 1 only when its config is complete; otherwise fail loudly. + if not settings.client_id: + raise DatabricksAuthError( + "OIDC federation selected but no Model 2 SA was discovered in namespace " + f"'{namespace or 'N/A'}' and Model 1 cannot run without a client_id. 
" + f"Either annotate a ServiceAccount with '{LABEL_DATABRICKS_ENABLED}=true' + " + f"'{ANNOTATION_DATABRICKS_CLIENT_ID}=' in the workflow namespace, " + f"or set {DATABRICKS_CLIENT_ID_ENV} on the connector for Model 1 fallback." + ) + if _resolve_oidc_token_file(settings) is None: + raise DatabricksAuthError( + "OIDC federation selected but no Model 2 SA was discovered in namespace " + f"'{namespace or 'N/A'}' and Model 1 cannot run without a subject token file. " + "Either annotate a workflow-namespace ServiceAccount, or configure IRSA on the " + "connector pod (AWS_WEB_IDENTITY_TOKEN_FILE) / mount a projected SA token at " + f"{DEFAULT_PROJECTED_SA_TOKEN_PATH}." + ) + return OIDCConnectorIRSAAuth(workspace_url, settings) + + +def build_auth( + workspace_url: str, + auth_type: str, + namespace: Optional[str] = None, + client_id: Optional[str] = None, + oauth_secret_name: Optional[str] = None, + token_secret_name: Optional[str] = None, + oidc_token_file: Optional[str] = None, + oidc_audience: Optional[str] = None, + oidc_service_account: Optional[str] = None, +) -> DatabricksAuth: + """Build a :class:`DatabricksAuth` directly from an explicit context. + + Used by the connector's ``get`` / ``delete`` paths to reconstruct the auth + strategy for long-running jobs from persisted metadata, without re-reading + the task template, environment, or running namespace discovery. + + For OIDC Model 2 the caller passes the previously-discovered SA name via + ``oidc_service_account``. When that is set, this function builds Model 2 + directly with the supplied (sa, client_id, audience) tuple. Without it, + Model 1 is built when ``auth_type=oidc_federation``. + """ + if auth_type not in VALID_AUTH_TYPES: + raise DatabricksAuthError( + f"Invalid databricks auth_type '{auth_type}'. Expected one of: {sorted(VALID_AUTH_TYPES)}." 
+ ) + settings = _Settings( + auth_type=auth_type, + client_id=client_id, + oauth_secret_name=oauth_secret_name or DEFAULT_OAUTH_SECRET_NAME, + token_secret_name=token_secret_name or DEFAULT_TOKEN_SECRET_NAME, + oidc_token_file=oidc_token_file, + oidc_audience=oidc_audience or DEFAULT_OIDC_AUDIENCE, + namespace=namespace, + ) + if auth_type == "pat": + return PATAuth(workspace_url, settings) + if auth_type == "oauth_m2m": + return OAuthM2MAuth(workspace_url, settings) + if oidc_service_account: + if not client_id: + raise DatabricksAuthError( + "Cannot rebuild OIDC Model 2 strategy without client_id " + "(persisted metadata is missing the discovered Databricks SP)." + ) + discovered = _DiscoveredOIDCConfig( + service_account=oidc_service_account, + client_id=client_id, + audience=settings.oidc_audience, + ) + return OIDCNamespaceSAAuth(workspace_url, settings, discovered) + return OIDCConnectorIRSAAuth(workspace_url, settings) + + +def _auto_detect(settings: _Settings) -> str: + """Return ``oidc_federation`` / ``oauth_m2m`` / ``pat`` based on reachable resources.""" + subject_file = _resolve_oidc_token_file(settings) + has_client_id = bool(settings.client_id or os.getenv(DATABRICKS_CLIENT_ID_ENV)) + if subject_file and has_client_id: + return "oidc_federation" + has_client_secret = bool(os.getenv(DATABRICKS_CLIENT_SECRET_ENV)) + if has_client_id and has_client_secret: + return "oauth_m2m" + return "pat" + + +def validate_connector_config() -> None: + """Fail-fast connector startup check. + + Emits a single structured log line describing the default auth mode. Raises + :class:`DatabricksAuthError` on clearly-invalid configurations (unknown + auth_type); emits warnings for configurations that are incomplete but may + still work via per-namespace discovery (OIDC) or per-namespace secrets (M2M). 
+ """ + env_auth = os.getenv(FLYTE_DATABRICKS_AUTH_TYPE_ENV) + if env_auth is None: + logger.info( + "Databricks connector auth: no %s set; per-task auth type will be auto-detected.", + FLYTE_DATABRICKS_AUTH_TYPE_ENV, + ) + return + if env_auth not in VALID_AUTH_TYPES: + raise DatabricksAuthError( + f"{FLYTE_DATABRICKS_AUTH_TYPE_ENV}='{env_auth}' is not valid. " + f"Expected one of: {sorted(VALID_AUTH_TYPES)}." + ) + cid = os.getenv(DATABRICKS_CLIENT_ID_ENV) + if env_auth == "oauth_m2m" and not cid: + logger.warning( + "Databricks connector auth: %s=oauth_m2m but %s is not set. " + "Each workflow namespace must supply 'client_id' in the OAuth secret.", + FLYTE_DATABRICKS_AUTH_TYPE_ENV, + DATABRICKS_CLIENT_ID_ENV, + ) + if env_auth == "oidc_federation" and not cid: + logger.info( + "Databricks connector auth: %s=oidc_federation with no %s. " + "Model 2 will be used when a workflow namespace has an annotated " + "ServiceAccount; Model 1 fallback is disabled (no connector-level client_id).", + FLYTE_DATABRICKS_AUTH_TYPE_ENV, + DATABRICKS_CLIENT_ID_ENV, + ) + logger.info( + "Databricks connector auth: default auth_type=%s (per-task overrides still apply).", + env_auth, + ) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 6c447e8bd4..68ae78371e 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -92,6 +92,22 @@ class DatabricksV2(Spark): notebook_path (Optional[str]): Path to Databricks notebook (e.g., "/Users/user@example.com/notebook"). notebook_base_parameters (Optional[Dict[str, str]]): Parameters to pass to the notebook. + databricks_auth_type (Optional[str]): Per-task override for the authentication mode. + One of ``"pat"``, ``"oauth_m2m"``, ``"oidc_federation"``. 
When unset the + connector-level ``FLYTE_DATABRICKS_AUTH_TYPE`` env var is used; when that is + also unset the connector auto-detects based on available credentials. + databricks_client_id (Optional[str]): Client ID of the Databricks Service Principal + to use for OAuth M2M and OIDC federation. Falls back to the ``DATABRICKS_CLIENT_ID`` + env var on the connector. + databricks_oauth_secret (Optional[str]): K8s secret name that holds ``client_id`` / + ``client_secret`` keys for OAuth M2M. Defaults to ``databricks-oauth``. + databricks_oidc_token_file (Optional[str]): File path to a subject JWT for OIDC + federation (Model 1). Falls back to ``AWS_WEB_IDENTITY_TOKEN_FILE`` (IRSA) and + ``/var/run/secrets/databricks/token``. + databricks_oidc_audience (Optional[str]): Audience for the OIDC subject token. + Defaults to ``"databricks"``. For Model 2, the per-namespace + ``flyte.org/databricks-audience`` annotation on the discovered ServiceAccount + takes precedence over this value. Compute Modes: The connector auto-detects the compute mode based on the databricks_conf contents: @@ -193,6 +209,62 @@ class DatabricksV2(Spark): notebook_path="/Users/user@example.com/my-notebook", notebook_base_parameters={"param1": "value1"}, ) + + Authentication: + Three auth types are supported. In most deployments you should **not** have to + change anything in the workflow itself: operators set the defaults at the + connector level (env vars + K8s resources) and the same workflow code works + across auth modes. Override fields on this class are only needed when a single + workflow wants to diverge from the connector defaults. + + Order of resolution for each field: task config -> connector env var -> default. + + Example - Default path (connector is configured for OIDC; workflow is unchanged):: + + # Operator configures the connector deployment with: + # FLYTE_DATABRICKS_AUTH_TYPE=oidc_federation + # DATABRICKS_CLIENT_ID=<service-principal-client-id> + # plus IRSA on the connector pod's ServiceAccount.
+ # + # Workflow code stays exactly as it was under PAT - no changes required. + DatabricksV2( + databricks_conf={"run_name": "my-job", "new_cluster": {...}}, + databricks_instance="my-workspace.cloud.databricks.com", + ) + + Example - Per-task OAuth M2M override:: + + DatabricksV2( + databricks_conf={...}, + databricks_instance="my-workspace.cloud.databricks.com", + databricks_auth_type="oauth_m2m", + databricks_client_id="00000000-0000-0000-0000-000000000000", + # client_secret is read from a 'databricks-oauth' k8s secret in the + # workflow namespace, or from the DATABRICKS_CLIENT_SECRET env var + # on the connector. + ) + + Example - OIDC federation with a workflow-namespace ServiceAccount (Model 2):: + + # No workflow code change needed. Operator sets: + # FLYTE_DATABRICKS_AUTH_TYPE=oidc_federation + # Then in the workflow namespace: + # apiVersion: v1 + # kind: ServiceAccount + # metadata: + # name: <workflow-sa-name> + # labels: { flyte.org/databricks-enabled: "true" } + # annotations: + # flyte.org/databricks-client-id: "<service-principal-client-id>" + # + # The connector auto-discovers the SA at submit time and federates as it. + DatabricksV2( + databricks_conf={...}, + databricks_instance="my-workspace.cloud.databricks.com", + ) + + See the plugin README's "Authentication" section for the full setup guide, + including RBAC for Model 2 and a migration guide from PAT to M2M/OIDC. """ databricks_conf: Optional[Dict[str, Union[str, dict]]] = None @@ -201,6 +273,15 @@ class DatabricksV2(Spark): databricks_token_secret: Optional[str] = None notebook_path: Optional[str] = None notebook_base_parameters: Optional[Dict[str, str]] = None + # --- Authentication overrides ------------------------------------------------- + # All optional. Each falls back to a connector-level env var (see README) and + # ultimately to an auto-detected default so existing workflows keep working + # with zero code change.
+ databricks_auth_type: Optional[str] = None + databricks_client_id: Optional[str] = None + databricks_oauth_secret: Optional[str] = None + databricks_oidc_token_file: Optional[str] = None + databricks_oidc_audience: Optional[str] = None # This method does not reset the SparkSession since it's a bit hard to handle multiple @@ -318,6 +399,18 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: custom_dict["notebookPath"] = cfg.notebook_path if cfg.notebook_base_parameters: custom_dict["notebookBaseParameters"] = cfg.notebook_base_parameters + # Authentication overrides (only serialise the ones explicitly set so that + # old connectors reading this custom dict see no unknown keys). + if cfg.databricks_auth_type: + custom_dict["databricksAuthType"] = cfg.databricks_auth_type + if cfg.databricks_client_id: + custom_dict["databricksClientId"] = cfg.databricks_client_id + if cfg.databricks_oauth_secret: + custom_dict["databricksOauthSecret"] = cfg.databricks_oauth_secret + if cfg.databricks_oidc_token_file: + custom_dict["databricksOidcTokenFile"] = cfg.databricks_oidc_token_file + if cfg.databricks_oidc_audience: + custom_dict["databricksOidcAudience"] = cfg.databricks_oidc_audience return custom_dict diff --git a/plugins/flytekit-spark/tests/test_connector.py b/plugins/flytekit-spark/tests/test_connector.py index 0fc4effbb9..c630589caa 100644 --- a/plugins/flytekit-spark/tests/test_connector.py +++ b/plugins/flytekit-spark/tests/test_connector.py @@ -125,6 +125,10 @@ async def test_databricks_agent(task_template: TaskTemplate): databricks_instance="test-account.cloud.databricks.com", run_id="123", auth_token=mocked_token, + auth_type="pat", + oauth_secret_name="databricks-oauth", + token_secret_name="databricks-token", + oidc_audience="databricks", ) mock_create_response = {"run_id": "123"} @@ -187,6 +191,10 @@ async def test_agent_create_with_default_instance(task_template: TaskTemplate): databricks_instance="test-account.cloud.databricks.com", 
run_id="123", auth_token=mocked_token, + auth_type="pat", + oauth_secret_name="databricks-oauth", + token_secret_name="databricks-token", + oidc_audience="databricks", ) mock_create_response = {"run_id": "123"} @@ -609,6 +617,10 @@ async def test_databricks_agent_serverless(serverless_task_template_with_env_key databricks_instance="test-account.cloud.databricks.com", run_id="456", auth_token=mocked_token, + auth_type="pat", + oauth_secret_name="databricks-oauth", + token_secret_name="databricks-token", + oidc_audience="databricks", ) mock_create_response = {"run_id": "456"} diff --git a/plugins/flytekit-spark/tests/test_databricks_auth.py b/plugins/flytekit-spark/tests/test_databricks_auth.py new file mode 100644 index 0000000000..446e93bf08 --- /dev/null +++ b/plugins/flytekit-spark/tests/test_databricks_auth.py @@ -0,0 +1,858 @@ +"""Tests for the flytekit-spark Databricks auth strategy module. + +Covers: + +- ``_Settings.from_task``: task-cfg + env-var + default resolution +- ``_resolve_oidc_token_file``: subject JWT discovery order +- ``_TokenCache``: async set/get/expire/invalidate +- ``_post_oidc_token``: retries on 429/5xx, fail-fast on other 4xx +- ``_auto_detect``: auto-detection precedence +- ``select_auth`` and ``build_auth``: strategy dispatch +- Each strategy's ``get_bearer_token`` happy-path + error paths +- ``OIDCNamespaceSAAuth``: Kubernetes ``TokenRequest`` success / 404 / 403 +- Backward compatibility: legacy metadata (``auth_type=None``) continues to work +""" + +import asyncio +import time +from dataclasses import dataclass +from typing import Any, Dict +from unittest.mock import MagicMock, patch + +import pytest +from aiohttp import ClientSession +from aioresponses import aioresponses +from flytekitplugins.spark.databricks_auth import ( + ANNOTATION_DATABRICKS_AUDIENCE, + ANNOTATION_DATABRICKS_CLIENT_ID, + DATABRICKS_ENABLED_LABEL_SELECTOR, + DEFAULT_OAUTH_SECRET_NAME, + DEFAULT_OIDC_AUDIENCE, + DEFAULT_TOKEN_SECRET_NAME, + LABEL_DATABRICKS_ENABLED, + 
VALID_AUTH_TYPES, + DatabricksAuthError, + OAuthM2MAuth, + OIDCConnectorIRSAAuth, + OIDCNamespaceSAAuth, + PATAuth, + _auto_detect, + _DiscoveredOIDCConfig, + _discover_namespace_oidc_sa_sync, + _NS_DISCOVERY_CACHE, + _post_oidc_token, + _resolve_oidc_token_file, + _Settings, + _TokenCache, + build_auth, + select_auth, + validate_connector_config, +) + + +# --------------------------------------------------------------------------- # +# Minimal TaskTemplate-compatible stub # +# --------------------------------------------------------------------------- # + + +@dataclass +class _FakeTaskTemplate: + custom: Dict[str, Any] + + +def _tt(**custom) -> _FakeTaskTemplate: + return _FakeTaskTemplate(custom=dict(custom)) + + +def _fake_sa(name: str, client_id: str = None, audience: str = None) -> MagicMock: + """Build a minimal stand-in for ``kubernetes.client.V1ServiceAccount``. + + Discovery only reads ``.metadata.name`` and ``.metadata.annotations`` - + construct just those two attributes. + """ + annotations = {} + if client_id is not None: + annotations[ANNOTATION_DATABRICKS_CLIENT_ID] = client_id + if audience is not None: + annotations[ANNOTATION_DATABRICKS_AUDIENCE] = audience + sa = MagicMock() + sa.metadata = MagicMock() + sa.metadata.name = name + sa.metadata.annotations = annotations or None + return sa + + +@pytest.fixture(autouse=True) +def _clear_ns_discovery_cache(): + """Reset the discovery cache between tests so cached results don't leak.""" + _NS_DISCOVERY_CACHE._store.clear() + yield + _NS_DISCOVERY_CACHE._store.clear() + + +# =========================================================================== # +# _Settings # +# =========================================================================== # + + +class TestSettings: + def test_task_cfg_wins_over_env_and_default(self, monkeypatch): + monkeypatch.setenv("FLYTE_DATABRICKS_AUTH_TYPE", "oauth_m2m") + monkeypatch.setenv("DATABRICKS_CLIENT_ID", "from-env") + s = _Settings.from_task( + 
_tt(databricksAuthType="pat", databricksClientId="from-cfg"), + namespace="ns-a", + ) + assert s.auth_type == "pat" + assert s.client_id == "from-cfg" + assert s.namespace == "ns-a" + + def test_env_used_when_cfg_absent(self, monkeypatch): + monkeypatch.setenv("FLYTE_DATABRICKS_AUTH_TYPE", "oidc_federation") + monkeypatch.setenv("DATABRICKS_CLIENT_ID", "env-cid") + s = _Settings.from_task(_tt(), namespace=None) + assert s.auth_type == "oidc_federation" + assert s.client_id == "env-cid" + + def test_defaults_populated(self, monkeypatch): + for v in [ + "FLYTE_DATABRICKS_AUTH_TYPE", + "FLYTE_DATABRICKS_OAUTH_SECRET_NAME", + "FLYTE_DATABRICKS_TOKEN_SECRET_NAME", + "FLYTE_DATABRICKS_OIDC_AUDIENCE", + "DATABRICKS_CLIENT_ID", + ]: + monkeypatch.delenv(v, raising=False) + s = _Settings.from_task(_tt(), namespace=None) + assert s.auth_type is None + assert s.client_id is None + assert s.oauth_secret_name == DEFAULT_OAUTH_SECRET_NAME + assert s.token_secret_name == DEFAULT_TOKEN_SECRET_NAME + assert s.oidc_audience == DEFAULT_OIDC_AUDIENCE + + +# =========================================================================== # +# _resolve_oidc_token_file # +# =========================================================================== # + + +class TestResolveOidcTokenFile: + def test_explicit_cfg_wins(self, tmp_path, monkeypatch): + f = tmp_path / "explicit.jwt" + f.write_text("jwt") + monkeypatch.delenv("AWS_WEB_IDENTITY_TOKEN_FILE", raising=False) + s = _Settings.from_task(_tt(databricksOidcTokenFile=str(f)), namespace=None) + assert _resolve_oidc_token_file(s) == str(f) + + def test_falls_back_to_irsa(self, tmp_path, monkeypatch): + f = tmp_path / "irsa.jwt" + f.write_text("irsa-jwt") + monkeypatch.setenv("AWS_WEB_IDENTITY_TOKEN_FILE", str(f)) + s = _Settings.from_task(_tt(), namespace=None) + assert _resolve_oidc_token_file(s) == str(f) + + def test_returns_none_when_nothing_found(self, monkeypatch): + monkeypatch.delenv("AWS_WEB_IDENTITY_TOKEN_FILE", raising=False) + 
monkeypatch.delenv("FLYTE_DATABRICKS_OIDC_TOKEN_FILE", raising=False) + s = _Settings.from_task(_tt(), namespace=None) + with patch("os.path.exists", return_value=False): + assert _resolve_oidc_token_file(s) is None + + +# =========================================================================== # +# _TokenCache # +# =========================================================================== # + + +class TestTokenCache: + @pytest.mark.asyncio + async def test_put_and_get(self): + cache = _TokenCache() + await cache.put(("w", "c", "s"), "tok", expires_in=3600) + assert await cache.get(("w", "c", "s")) == "tok" + + @pytest.mark.asyncio + async def test_miss(self): + cache = _TokenCache() + assert await cache.get(("none", "none", "none")) is None + + @pytest.mark.asyncio + async def test_expired_eviction(self): + cache = _TokenCache() + await cache.put(("w", "c", "s"), "tok", expires_in=1) # clamped to >=30s buffer + # Force the stored expiry into the past. + cache._store[("w", "c", "s")].expires_at_unix = time.time() - 1 + assert await cache.get(("w", "c", "s")) is None + + @pytest.mark.asyncio + async def test_invalidate(self): + cache = _TokenCache() + await cache.put(("w", "c", "s"), "tok", expires_in=3600) + await cache.invalidate(("w", "c", "s")) + assert await cache.get(("w", "c", "s")) is None + + +# =========================================================================== # +# _post_oidc_token # +# =========================================================================== # + + +class TestPostOidcToken: + @pytest.mark.asyncio + async def test_success(self): + async with ClientSession() as session: + with aioresponses() as m: + m.post( + "https://ws.cloud.databricks.com/oidc/v1/token", + status=200, + payload={"access_token": "A", "expires_in": 3600}, + ) + result = await _post_oidc_token(session, "ws.cloud.databricks.com", {"grant_type": "x"}) + assert result == {"access_token": "A", "expires_in": 3600} + + @pytest.mark.asyncio + async def 
test_fast_fail_on_401(self): + async with ClientSession() as session: + with aioresponses() as m: + m.post( + "https://ws.cloud.databricks.com/oidc/v1/token", + status=401, + body="bad-client", + ) + with pytest.raises(DatabricksAuthError, match="401"): + await _post_oidc_token(session, "ws.cloud.databricks.com", {"grant_type": "x"}) + + @pytest.mark.asyncio + async def test_retries_on_5xx_then_succeeds(self, monkeypatch): + async with ClientSession() as session: + with aioresponses() as m: + url = "https://ws.cloud.databricks.com/oidc/v1/token" + m.post(url, status=503, body="busy") + m.post(url, status=200, payload={"access_token": "B", "expires_in": 60}) + result = await _post_oidc_token(session, "ws.cloud.databricks.com", {"grant_type": "x"}) + assert result["access_token"] == "B" + + @pytest.mark.asyncio + async def test_exhausts_retries(self): + async with ClientSession() as session: + with aioresponses() as m: + url = "https://ws.cloud.databricks.com/oidc/v1/token" + for _ in range(3): + m.post(url, status=500, body="boom") + with pytest.raises(DatabricksAuthError, match="failed after"): + await _post_oidc_token(session, "ws.cloud.databricks.com", {"grant_type": "x"}) + + +# =========================================================================== # +# _auto_detect # +# =========================================================================== # + + +class TestAutoDetect: + def test_oidc_when_subject_file_and_client_id_present(self, tmp_path, monkeypatch): + f = tmp_path / "tok.jwt" + f.write_text("x") + monkeypatch.setenv("AWS_WEB_IDENTITY_TOKEN_FILE", str(f)) + monkeypatch.setenv("DATABRICKS_CLIENT_ID", "cid") + s = _Settings.from_task(_tt(), namespace=None) + assert _auto_detect(s) == "oidc_federation" + + def test_m2m_when_client_id_and_secret_present(self, monkeypatch): + monkeypatch.setenv("DATABRICKS_CLIENT_ID", "cid") + monkeypatch.setenv("DATABRICKS_CLIENT_SECRET", "csec") + monkeypatch.delenv("AWS_WEB_IDENTITY_TOKEN_FILE", raising=False) + 
monkeypatch.delenv("FLYTE_DATABRICKS_OIDC_TOKEN_FILE", raising=False) + with patch("os.path.exists", return_value=False): + s = _Settings.from_task(_tt(), namespace=None) + assert _auto_detect(s) == "oauth_m2m" + + def test_pat_when_nothing_else_configured(self, monkeypatch): + for v in [ + "DATABRICKS_CLIENT_ID", + "DATABRICKS_CLIENT_SECRET", + "AWS_WEB_IDENTITY_TOKEN_FILE", + "FLYTE_DATABRICKS_OIDC_TOKEN_FILE", + ]: + monkeypatch.delenv(v, raising=False) + with patch("os.path.exists", return_value=False): + s = _Settings.from_task(_tt(), namespace=None) + assert _auto_detect(s) == "pat" + + +# =========================================================================== # +# select_auth / build_auth # +# =========================================================================== # + + +class TestSelectAuth: + @pytest.mark.asyncio + async def test_explicit_pat(self, monkeypatch): + monkeypatch.delenv("AWS_WEB_IDENTITY_TOKEN_FILE", raising=False) + auth = await select_auth(_tt(databricksAuthType="pat"), "ws", namespace="ns") + assert isinstance(auth, PATAuth) + + @pytest.mark.asyncio + async def test_explicit_m2m(self): + auth = await select_auth(_tt(databricksAuthType="oauth_m2m"), "ws", namespace="ns") + assert isinstance(auth, OAuthM2MAuth) + + @pytest.mark.asyncio + async def test_oidc_model_1_when_no_sa_discovered(self, tmp_path, monkeypatch): + # Provide a reachable IRSA + client_id so Model 1 is a valid fallback. 
+ jwt_path = tmp_path / "irsa.jwt" + jwt_path.write_text("j") + monkeypatch.setenv("AWS_WEB_IDENTITY_TOKEN_FILE", str(jwt_path)) + monkeypatch.setenv("DATABRICKS_CLIENT_ID", "fallback-cid") + with patch( + "flytekitplugins.spark.databricks_auth._discover_namespace_oidc_sa_sync", + return_value=None, + ): + auth = await select_auth(_tt(databricksAuthType="oidc_federation"), "ws", namespace="ns") + assert isinstance(auth, OIDCConnectorIRSAAuth) + + @pytest.mark.asyncio + async def test_oidc_model_2_when_sa_discovered(self, monkeypatch): + discovered = _DiscoveredOIDCConfig( + service_account="dbx-runner", client_id="ns-cid", audience="databricks" + ) + with patch( + "flytekitplugins.spark.databricks_auth._discover_namespace_oidc_sa_sync", + return_value=discovered, + ): + auth = await select_auth( + _tt(databricksAuthType="oidc_federation"), "ws", namespace="team-a" + ) + assert isinstance(auth, OIDCNamespaceSAAuth) + assert auth.discovered.service_account == "dbx-runner" + assert auth.discovered.client_id == "ns-cid" + + @pytest.mark.asyncio + async def test_oidc_no_sa_and_no_fallback_errors(self, monkeypatch): + # Both discovery and Model 1 fallback are unavailable. 
+ monkeypatch.delenv("DATABRICKS_CLIENT_ID", raising=False) + monkeypatch.delenv("AWS_WEB_IDENTITY_TOKEN_FILE", raising=False) + with patch( + "flytekitplugins.spark.databricks_auth._discover_namespace_oidc_sa_sync", + return_value=None, + ), patch("os.path.exists", return_value=False): + with pytest.raises(DatabricksAuthError, match="no Model 2 SA was discovered"): + await select_auth(_tt(databricksAuthType="oidc_federation"), "ws", namespace="ns") + + @pytest.mark.asyncio + async def test_oidc_no_sa_and_no_subject_file_errors(self, monkeypatch): + monkeypatch.setenv("DATABRICKS_CLIENT_ID", "cid") + monkeypatch.delenv("AWS_WEB_IDENTITY_TOKEN_FILE", raising=False) + with patch( + "flytekitplugins.spark.databricks_auth._discover_namespace_oidc_sa_sync", + return_value=None, + ), patch("os.path.exists", return_value=False): + with pytest.raises(DatabricksAuthError, match="without a subject token file"): + await select_auth(_tt(databricksAuthType="oidc_federation"), "ws", namespace="ns") + + @pytest.mark.asyncio + async def test_invalid_auth_type_raises(self): + with pytest.raises(DatabricksAuthError, match="Invalid databricks auth_type"): + await select_auth(_tt(databricksAuthType="not-real"), "ws") + + def test_build_auth_dispatch(self): + assert isinstance(build_auth("ws", "pat"), PATAuth) + assert isinstance(build_auth("ws", "oauth_m2m"), OAuthM2MAuth) + assert isinstance(build_auth("ws", "oidc_federation"), OIDCConnectorIRSAAuth) + assert isinstance( + build_auth( + "ws", + "oidc_federation", + oidc_service_account="sa", + client_id="cid", + namespace="ns", + ), + OIDCNamespaceSAAuth, + ) + + def test_build_auth_model_2_without_client_id_errors(self): + with pytest.raises(DatabricksAuthError, match="without client_id"): + build_auth("ws", "oidc_federation", oidc_service_account="sa", namespace="ns") + + +# =========================================================================== # +# PATAuth # +# 
=========================================================================== # + + +class TestPATAuth: + @pytest.mark.asyncio + async def test_delegates_to_get_databricks_token(self): + auth = build_auth("ws", "pat", namespace="team-a", token_secret_name="custom-sec") + with patch("flytekitplugins.spark.connector.get_databricks_token") as m: + m.return_value = "dapi_xyz" + async with ClientSession() as session: + tok = await auth.get_bearer_token(session) + assert tok == "dapi_xyz" + m.assert_called_once_with(namespace="team-a", secret_name="custom-sec") + + +# =========================================================================== # +# OAuthM2MAuth # +# =========================================================================== # + + +class TestOAuthM2M: + @pytest.mark.asyncio + async def test_happy_path_with_env_creds(self, monkeypatch): + monkeypatch.setenv("DATABRICKS_CLIENT_SECRET", "csec") + auth = build_auth("ws.cloud.databricks.com", "oauth_m2m", client_id="cid", namespace="ns-a") + with patch("flytekitplugins.spark.connector.get_secret_from_k8s", return_value=None): + async with ClientSession() as session: + with aioresponses() as m: + m.post( + "https://ws.cloud.databricks.com/oidc/v1/token", + status=200, + payload={"access_token": "M2M-TOK", "expires_in": 3600}, + ) + tok = await auth.get_bearer_token(session) + assert tok == "M2M-TOK" + + @pytest.mark.asyncio + async def test_namespace_secret_overrides_env(self, monkeypatch): + monkeypatch.setenv("DATABRICKS_CLIENT_SECRET", "env-secret") + auth = build_auth("ws.cloud.databricks.com", "oauth_m2m", namespace="ns-a") + + def fake_secret(secret_name, secret_key, namespace): + return {"client_id": "ns-cid", "client_secret": "ns-secret"}[secret_key] + + with patch("flytekitplugins.spark.connector.get_secret_from_k8s", side_effect=fake_secret): + posted = {} + + async def _capture(session, url, form): + posted["form"] = form + return {"access_token": "X", "expires_in": 60} + + with patch( + 
"flytekitplugins.spark.databricks_auth._post_oidc_token", side_effect=_capture + ): + async with ClientSession() as session: + await auth.get_bearer_token(session) + assert posted["form"]["client_id"] == "ns-cid" + assert posted["form"]["client_secret"] == "ns-secret" + + @pytest.mark.asyncio + async def test_missing_client_id_errors(self, monkeypatch): + monkeypatch.delenv("DATABRICKS_CLIENT_ID", raising=False) + monkeypatch.setenv("DATABRICKS_CLIENT_SECRET", "csec") + auth = build_auth("ws", "oauth_m2m") + with patch("flytekitplugins.spark.connector.get_secret_from_k8s", return_value=None): + async with ClientSession() as session: + with pytest.raises(DatabricksAuthError, match="no client_id"): + await auth.get_bearer_token(session) + + @pytest.mark.asyncio + async def test_caches_token(self, monkeypatch): + # Unique cache key per invocation to avoid cross-test contamination. + monkeypatch.setenv("DATABRICKS_CLIENT_SECRET", "csec") + auth = build_auth( + "cache-test.cloud.databricks.com", "oauth_m2m", client_id="unique-cid", namespace="ns-cache" + ) + calls = {"n": 0} + + async def _fake_post(session, url, form): + calls["n"] += 1 + return {"access_token": "T", "expires_in": 3600} + + with patch( + "flytekitplugins.spark.databricks_auth._post_oidc_token", side_effect=_fake_post + ), patch("flytekitplugins.spark.connector.get_secret_from_k8s", return_value=None): + async with ClientSession() as session: + await auth.get_bearer_token(session) + await auth.get_bearer_token(session) + assert calls["n"] == 1 + await auth.invalidate_cache() + + +# =========================================================================== # +# OIDCConnectorIRSAAuth (Model 1) # +# =========================================================================== # + + +class TestOIDCConnectorIRSAAuth: + @pytest.mark.asyncio + async def test_token_exchange_happy_path(self, tmp_path, monkeypatch): + jwt_path = tmp_path / "irsa.jwt" + jwt_path.write_text("irsa-jwt-contents") + 
monkeypatch.setenv("AWS_WEB_IDENTITY_TOKEN_FILE", str(jwt_path)) + auth = build_auth( + "ws1.cloud.databricks.com", + "oidc_federation", + client_id="irsa-cid", + ) + posted = {} + + async def _capture(session, url, form): + posted["form"] = form + return {"access_token": "OIDC-TOK", "expires_in": 3600} + + with patch( + "flytekitplugins.spark.databricks_auth._post_oidc_token", side_effect=_capture + ): + async with ClientSession() as session: + tok = await auth.get_bearer_token(session) + assert tok == "OIDC-TOK" + assert posted["form"]["subject_token"] == "irsa-jwt-contents" + assert posted["form"]["grant_type"] == "urn:ietf:params:oauth:grant-type:token-exchange" + assert posted["form"]["client_id"] == "irsa-cid" + + @pytest.mark.asyncio + async def test_jwt_reread_each_refresh(self, tmp_path, monkeypatch): + jwt_path = tmp_path / "rotating.jwt" + jwt_path.write_text("v1") + monkeypatch.setenv("AWS_WEB_IDENTITY_TOKEN_FILE", str(jwt_path)) + auth = build_auth("ws-rotate.cloud.databricks.com", "oidc_federation", client_id="rot-cid") + + captured = [] + + async def _capture(session, url, form): + captured.append(form["subject_token"]) + return {"access_token": f"tok-{len(captured)}", "expires_in": 3600} + + with patch( + "flytekitplugins.spark.databricks_auth._post_oidc_token", side_effect=_capture + ): + async with ClientSession() as session: + await auth.get_bearer_token(session) + await auth.invalidate_cache() + jwt_path.write_text("v2") + await auth.get_bearer_token(session) + + assert captured == ["v1", "v2"] + + @pytest.mark.asyncio + async def test_missing_subject_file_errors(self, monkeypatch): + monkeypatch.delenv("AWS_WEB_IDENTITY_TOKEN_FILE", raising=False) + auth = build_auth("ws", "oidc_federation", client_id="cid") + with patch("os.path.exists", return_value=False): + async with ClientSession() as session: + with pytest.raises(DatabricksAuthError, match="no subject token file"): + await auth.get_bearer_token(session) + + @pytest.mark.asyncio + async 
def test_missing_client_id_errors(self, tmp_path, monkeypatch): + jwt_path = tmp_path / "irsa.jwt" + jwt_path.write_text("j") + monkeypatch.setenv("AWS_WEB_IDENTITY_TOKEN_FILE", str(jwt_path)) + auth = build_auth("ws-noid.cloud.databricks.com", "oidc_federation") + async with ClientSession() as session: + with pytest.raises(DatabricksAuthError, match="no client_id"): + await auth.get_bearer_token(session) + + +# =========================================================================== # +# OIDCNamespaceSAAuth (Model 2) # +# =========================================================================== # + + +class TestOIDCNamespaceSAAuth: + @pytest.mark.asyncio + async def test_happy_path(self): + auth = build_auth( + "ws-sa.cloud.databricks.com", + "oidc_federation", + client_id="sa-cid", + namespace="team-a", + oidc_service_account="team-sa", + oidc_audience="databricks", + ) + + async def _capture(session, url, form): + assert form["subject_token"] == "minted-jwt" + assert form["client_id"] == "sa-cid" + return {"access_token": "SA-OIDC", "expires_in": 3600} + + with patch( + "flytekitplugins.spark.databricks_auth._request_sa_token", return_value="minted-jwt" + ), patch( + "flytekitplugins.spark.databricks_auth._post_oidc_token", side_effect=_capture + ): + async with ClientSession() as session: + tok = await auth.get_bearer_token(session) + assert tok == "SA-OIDC" + + @pytest.mark.asyncio + async def test_missing_namespace_errors(self): + auth = build_auth( + "ws", "oidc_federation", client_id="cid", oidc_service_account="sa", namespace=None + ) + async with ClientSession() as session: + with pytest.raises(DatabricksAuthError, match="requires a workflow namespace"): + await auth.get_bearer_token(session) + + def test_sa_not_found_translates_404(self): + from flytekitplugins.spark.databricks_auth import _request_sa_token + from kubernetes.client.exceptions import ApiException + + with ( + patch("kubernetes.config.load_incluster_config"), + 
patch("kubernetes.client.CoreV1Api") as api_cls, + ): + api_cls.return_value.create_namespaced_service_account_token.side_effect = ApiException( + status=404 + ) + with pytest.raises(DatabricksAuthError, match="not found"): + _request_sa_token("ns-a", "missing-sa", "databricks") + + def test_forbidden_translates_403(self): + from flytekitplugins.spark.databricks_auth import _request_sa_token + from kubernetes.client.exceptions import ApiException + + with ( + patch("kubernetes.config.load_incluster_config"), + patch("kubernetes.client.CoreV1Api") as api_cls, + ): + api_cls.return_value.create_namespaced_service_account_token.side_effect = ApiException( + status=403 + ) + with pytest.raises(DatabricksAuthError, match="Forbidden"): + _request_sa_token("ns-a", "sa", "databricks") + + +# =========================================================================== # +# Namespace SA discovery (annotation-driven) # +# =========================================================================== # + + +class TestNamespaceDiscovery: + """Unit tests for the synchronous discovery helper (no caching).""" + + def test_zero_matches_returns_none(self): + with patch( + "flytekitplugins.spark.connector.list_serviceaccounts_in_k8s", + return_value=[], + ): + assert _discover_namespace_oidc_sa_sync("team-a", "databricks") is None + + def test_label_filter_passed_through(self): + with patch( + "flytekitplugins.spark.connector.list_serviceaccounts_in_k8s", return_value=[] + ) as mock_list: + _discover_namespace_oidc_sa_sync("team-a", "databricks") + mock_list.assert_called_once_with( + namespace="team-a", label_selector=DATABRICKS_ENABLED_LABEL_SELECTOR + ) + + def test_single_match_returns_config(self): + sa = _fake_sa("dbx-runner", client_id="cid-A") + with patch( + "flytekitplugins.spark.connector.list_serviceaccounts_in_k8s", + return_value=[sa], + ): + cfg = _discover_namespace_oidc_sa_sync("team-a", "databricks") + assert cfg is not None + assert cfg.service_account == "dbx-runner" 
+ assert cfg.client_id == "cid-A" + assert cfg.audience == "databricks" # default applied + + def test_per_sa_audience_annotation_overrides_default(self): + sa = _fake_sa("dbx-runner", client_id="cid-A", audience="custom-aud") + with patch( + "flytekitplugins.spark.connector.list_serviceaccounts_in_k8s", + return_value=[sa], + ): + cfg = _discover_namespace_oidc_sa_sync("team-a", "databricks") + assert cfg.audience == "custom-aud" + + def test_sa_without_client_id_annotation_skipped(self): + # Labelled but no client-id annotation -> not a candidate. + sa = _fake_sa("misconfigured-sa", client_id=None) + with patch( + "flytekitplugins.spark.connector.list_serviceaccounts_in_k8s", + return_value=[sa], + ): + assert _discover_namespace_oidc_sa_sync("team-a", "databricks") is None + + def test_empty_client_id_annotation_skipped(self): + sa = _fake_sa("misconfigured-sa", client_id=" ") + with patch( + "flytekitplugins.spark.connector.list_serviceaccounts_in_k8s", + return_value=[sa], + ): + assert _discover_namespace_oidc_sa_sync("team-a", "databricks") is None + + def test_multiple_matches_raises_with_names(self): + sa1 = _fake_sa("alpha", client_id="cid-1") + sa2 = _fake_sa("beta", client_id="cid-2") + with patch( + "flytekitplugins.spark.connector.list_serviceaccounts_in_k8s", + return_value=[sa1, sa2], + ): + with pytest.raises(DatabricksAuthError, match=r"Ambiguous.*alpha.*beta"): + _discover_namespace_oidc_sa_sync("team-a", "databricks") + + +class TestNamespaceDiscoveryCache: + """Caching wrapper around the synchronous helper.""" + + @pytest.mark.asyncio + async def test_cache_hit_skips_listing(self): + from flytekitplugins.spark.databricks_auth import _discover_namespace_oidc_sa + + sa = _fake_sa("dbx-runner", client_id="cid") + with patch( + "flytekitplugins.spark.connector.list_serviceaccounts_in_k8s", return_value=[sa] + ) as mock_list: + r1 = await _discover_namespace_oidc_sa("ns-cache-a", "databricks") + r2 = await _discover_namespace_oidc_sa("ns-cache-a", 
"databricks") + assert r1 is not None and r2 is not None + assert r1.service_account == r2.service_account + assert mock_list.call_count == 1 # second call served from cache + + @pytest.mark.asyncio + async def test_cache_miss_is_also_cached(self): + from flytekitplugins.spark.databricks_auth import _discover_namespace_oidc_sa + + with patch( + "flytekitplugins.spark.connector.list_serviceaccounts_in_k8s", return_value=[] + ) as mock_list: + r1 = await _discover_namespace_oidc_sa("ns-cache-b", "databricks") + r2 = await _discover_namespace_oidc_sa("ns-cache-b", "databricks") + assert r1 is None and r2 is None + assert mock_list.call_count == 1 + + @pytest.mark.asyncio + async def test_invalidate_forces_relookup(self): + from flytekitplugins.spark.databricks_auth import ( + _NS_DISCOVERY_CACHE, + _discover_namespace_oidc_sa, + ) + + sa = _fake_sa("dbx-runner", client_id="cid") + with patch( + "flytekitplugins.spark.connector.list_serviceaccounts_in_k8s", return_value=[sa] + ) as mock_list: + await _discover_namespace_oidc_sa("ns-cache-c", "databricks") + await _NS_DISCOVERY_CACHE.invalidate("ns-cache-c") + await _discover_namespace_oidc_sa("ns-cache-c", "databricks") + assert mock_list.call_count == 2 + + +class TestOIDCNamespaceSAAuthInvalidatesDiscovery: + """Model 2's invalidate_cache should evict the discovery cache too.""" + + @pytest.mark.asyncio + async def test_invalidate_clears_discovery(self): + from flytekitplugins.spark.databricks_auth import _NS_DISCOVERY_CACHE + + auth = build_auth( + "ws.cloud.databricks.com", + "oidc_federation", + client_id="cid", + namespace="team-a", + oidc_service_account="sa", + oidc_audience="databricks", + ) + # Pre-seed the discovery cache for that namespace. 
+ await _NS_DISCOVERY_CACHE.put( + "team-a", + _DiscoveredOIDCConfig(service_account="sa", client_id="cid", audience="databricks"), + ) + await auth.invalidate_cache() + cached, _ = await _NS_DISCOVERY_CACHE.get("team-a") + assert cached is False + + +# =========================================================================== # +# validate_connector_config # +# =========================================================================== # + + +class TestValidateConnectorConfig: + def test_no_default_logs_and_returns(self, monkeypatch, caplog): + monkeypatch.delenv("FLYTE_DATABRICKS_AUTH_TYPE", raising=False) + validate_connector_config() # must not raise + + def test_invalid_env_raises(self, monkeypatch): + monkeypatch.setenv("FLYTE_DATABRICKS_AUTH_TYPE", "bogus") + with pytest.raises(DatabricksAuthError): + validate_connector_config() + + def test_valid_env_returns(self, monkeypatch): + monkeypatch.setenv("FLYTE_DATABRICKS_AUTH_TYPE", "pat") + validate_connector_config() # no raise + + +# =========================================================================== # +# Connector integration: backward compat + 401 refresh # +# =========================================================================== # + + +class TestConnectorAuthBackwardCompat: + """Metadata written by older connectors (auth_type=None) must still work.""" + + @pytest.mark.asyncio + async def test_legacy_metadata_get_uses_stored_token_no_refresh(self): + import http as _http + + from flytekit.extend.backend.base_agent import AgentRegistry + from flytekitplugins.spark.connector import DATABRICKS_API_ENDPOINT, DatabricksJobMetadata + + meta = DatabricksJobMetadata( + databricks_instance="legacy.cloud.databricks.com", + run_id="111", + auth_token="legacy-pat", + # auth_type intentionally left None (simulating old connector) + ) + agent = AgentRegistry.get_agent("spark") + url = f"https://legacy.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=111" + with aioresponses() as m: + m.get( + 
url, + status=_http.HTTPStatus.OK, + payload={"job_id": "1", "run_id": "111", "state": {"life_cycle_state": "RUNNING"}}, + ) + await agent.get(meta) + call = list(m.requests.values())[0][0] + assert call.kwargs["headers"]["Authorization"] == "Bearer legacy-pat" + + +class TestConnectorAuthRefresh: + """OAuth/OIDC metadata should transparently refresh on 401.""" + + @pytest.mark.asyncio + async def test_get_refreshes_token_on_401(self, monkeypatch): + import http as _http + + from flytekit.extend.backend.base_agent import AgentRegistry + from flytekitplugins.spark.connector import DATABRICKS_API_ENDPOINT, DatabricksJobMetadata + + monkeypatch.setenv("DATABRICKS_CLIENT_SECRET", "csec") + meta = DatabricksJobMetadata( + databricks_instance="refresh-ws.cloud.databricks.com", + run_id="222", + auth_type="oauth_m2m", + client_id="refresh-cid", + namespace="ns-a", + oauth_secret_name="databricks-oauth", + ) + agent = AgentRegistry.get_agent("spark") + url = f"https://refresh-ws.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=222" + + tokens = iter(["stale-token", "fresh-token"]) + + async def _fake_post(session, workspace_url, form): + return {"access_token": next(tokens), "expires_in": 3600} + + with patch( + "flytekitplugins.spark.databricks_auth._post_oidc_token", side_effect=_fake_post + ), patch("flytekitplugins.spark.connector.get_secret_from_k8s", return_value=None): + with aioresponses() as m: + # First call: 401 with the stale token. + m.get(url, status=_http.HTTPStatus.UNAUTHORIZED, body="expired") + # Second call: 200 with the refreshed token. 
+ m.get( + url, + status=_http.HTTPStatus.OK, + payload={"job_id": "1", "run_id": "222", "state": {"life_cycle_state": "RUNNING"}}, + ) + await agent.get(meta) + + requests = [r for rs in m.requests.values() for r in rs] + assert len(requests) == 2 + assert requests[0].kwargs["headers"]["Authorization"] == "Bearer stale-token" + assert requests[1].kwargs["headers"]["Authorization"] == "Bearer fresh-token" diff --git a/plugins/flytekit-spark/tests/test_databricks_token.py b/plugins/flytekit-spark/tests/test_databricks_token.py index 9bed13bb56..c776e5ac4a 100644 --- a/plugins/flytekit-spark/tests/test_databricks_token.py +++ b/plugins/flytekit-spark/tests/test_databricks_token.py @@ -35,6 +35,21 @@ # --------------------------------------------------------------------------- +@pytest.fixture(autouse=True) +def _clear_databricks_auth_env(monkeypatch): + """Ensure PAT auto-detection is reached regardless of the dev machine's env.""" + for var in ( + "FLYTE_DATABRICKS_AUTH_TYPE", + "FLYTE_DATABRICKS_OAUTH_SECRET_NAME", + "FLYTE_DATABRICKS_OIDC_TOKEN_FILE", + "FLYTE_DATABRICKS_OIDC_AUDIENCE", + "DATABRICKS_CLIENT_ID", + "DATABRICKS_CLIENT_SECRET", + "AWS_WEB_IDENTITY_TOKEN_FILE", + ): + monkeypatch.delenv(var, raising=False) + + @pytest.fixture(scope="function") def task_template() -> TaskTemplate: """Standard Databricks task template for testing.""" @@ -405,10 +420,11 @@ async def test_create_uses_namespace_token(self, task_template): mock_token.assert_called_once_with( namespace="project-alpha", - task_template=task_template, - secret_name=None, + secret_name="databricks-token", ) assert result.auth_token == "dapi_project_alpha_token" + assert result.auth_type == "pat" + assert result.namespace == "project-alpha" assert result.run_id == "999" @pytest.mark.asyncio @@ -429,10 +445,10 @@ async def test_create_with_custom_secret_name(self, task_template_with_custom_se mock_token.assert_called_once_with( namespace="team-x", - task_template=tt, 
secret_name="my-team-databricks-token", ) assert result.auth_token == "dapi_team_x_token" + assert result.token_secret_name == "my-team-databricks-token" @pytest.mark.asyncio async def test_create_without_metadata_uses_no_namespace(self, task_template): @@ -450,8 +466,7 @@ async def test_create_without_metadata_uses_no_namespace(self, task_template): mock_token.assert_called_once_with( namespace=None, - task_template=task_template, - secret_name=None, + secret_name="databricks-token", ) @pytest.mark.asyncio