Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#### New Features

- Added `artifact_repository` support to `udtf_configs` in `session.read.dbapi()`, enabling users to specify a custom artifact repository (e.g. PyPI) for packages used by the internal UDTF during distributed ingestion.
- Added `get_wif_token` to `snowflake.snowpark.secrets` for workload identity federation tokens on the Snowflake server (not available in SPCS file-based secret environments).

#### Bug Fixes

Expand Down
1 change: 1 addition & 0 deletions docs/source/snowpark/secrets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ Snowpark Secrets
get_secret_type
get_username_password
get_cloud_provider_token
get_wif_token
39 changes: 39 additions & 0 deletions src/snowflake/snowpark/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"get_secret_type",
"get_username_password",
"get_cloud_provider_token",
"get_wif_token",
"UsernamePassword",
"CloudProviderToken",
]
Expand Down Expand Up @@ -61,6 +62,10 @@ def get_username_password(self, secret_name: str) -> UsernamePassword:
def get_cloud_provider_token(self, secret_name: str) -> CloudProviderToken:
pass

@abstractmethod
def get_wif_token(self, secret_name: str, audience: str) -> str:
pass


class _SnowflakeSecretsServer(_SnowflakeSecrets):
"""Secret instance for Snowflake server environment (using _snowflake module)."""
Expand Down Expand Up @@ -89,6 +94,9 @@ def get_cloud_provider_token(self, secret_name: str) -> CloudProviderToken:
secret_object.token,
)

def get_wif_token(self, secret_name: str, audience: str) -> str:
return self._snowflake.get_wif_token(secret_name, audience)


class _SnowflakeSecretsSPCS(_SnowflakeSecrets):
"""Secret instance for SPCS container environment (file-based secrets)."""
Expand Down Expand Up @@ -173,6 +181,11 @@ def get_cloud_provider_token(self, secret_name: str) -> CloudProviderToken:
"Cloud provider token secrets are not supported in SPCS container environments."
)

def get_wif_token(self, secret_name: str, audience: str) -> str:
raise NotImplementedError(
"WIF token secrets are not supported in SPCS container environments."
)


def _is_spcs_environment() -> bool:
return os.getenv(_SCLS_SPCS_SECRET_ENV_NAME, None) is not None
Expand Down Expand Up @@ -259,3 +272,29 @@ def get_cloud_provider_token(secret_name: str) -> CloudProviderToken:
NotImplementedError: If running outside Snowflake server environment.
"""
return _get_secrets_instance().get_cloud_provider_token(secret_name)


def get_wif_token(secret_name: str, audience: str) -> str:
"""Get a workload identity federation (WIF) token from Snowflake.

Note:
Requires a Snowflake environment with a WIF secret configured and an
external access integration that allows the UDF or stored procedure to
use that secret. The ``audience`` must match the token audience expected
by the external system (for example, an OAuth token endpoint URL).

Args:
secret_name: The secret reference name bound to the WIF secret.
audience: The intended audience (``aud``) for the issued token.

Returns:
The issued token as a string (typically a JWT).

Raises:
NotImplementedError: If running outside the Snowflake server environment
(including SPCS file-based secret environments, where WIF tokens cannot
be minted).
ValueError: If the secret does not exist or is not authorized (when
applicable in supported environments).
"""
return _get_secrets_instance().get_wif_token(secret_name, audience)
34 changes: 34 additions & 0 deletions tests/integ/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def set_up_external_access_integration_resources(
integration1,
integration2,
integration3,
key4,
integration4,
wif_audience,
):
try:
# IMPORTANT SETUP NOTES: the test role needs to be granted the creation privilege
Expand Down Expand Up @@ -127,6 +130,12 @@ def set_up_external_access_integration_resources(
).collect()
session.sql(
f"""
CREATE SECRET IF NOT EXISTS {key4}
TYPE = WORKLOAD_IDENTITY_FEDERATION;
"""
).collect()
session.sql(
f"""
CREATE IF NOT EXISTS EXTERNAL ACCESS INTEGRATION {integration1}
ALLOWED_NETWORK_RULES = ({rule1})
ALLOWED_AUTHENTICATION_SECRETS = ({key1})
Expand All @@ -147,6 +156,13 @@ def set_up_external_access_integration_resources(
ALLOWED_NETWORK_RULES = ({rule3})
ALLOWED_AUTHENTICATION_SECRETS = ({key3})
ENABLED = true;
"""
).collect()
session.sql(
f"""
CREATE IF NOT EXISTS EXTERNAL ACCESS INTEGRATION {integration4}
ALLOWED_AUTHENTICATION_SECRETS = ({key4})
ENABLED = true;
"""
).collect()
CONNECTION_PARAMETERS["external_access_rule1"] = rule1
Expand All @@ -155,9 +171,12 @@ def set_up_external_access_integration_resources(
CONNECTION_PARAMETERS["external_access_key1"] = key1
CONNECTION_PARAMETERS["external_access_key2"] = key2
CONNECTION_PARAMETERS["external_access_key3"] = key3
CONNECTION_PARAMETERS["external_access_key4"] = key4
CONNECTION_PARAMETERS["external_access_integration1"] = integration1
CONNECTION_PARAMETERS["external_access_integration2"] = integration2
CONNECTION_PARAMETERS["external_access_integration3"] = integration3
CONNECTION_PARAMETERS["external_access_integration4"] = integration4
CONNECTION_PARAMETERS["wif_audience"] = wif_audience
except SnowparkSQLException:
# GCP currently does not support external access integration
# we can remove the exception once the integration is available on GCP
Expand All @@ -183,9 +202,12 @@ def clean_up_external_access_integration_resources():
CONNECTION_PARAMETERS.pop("external_access_key1", None)
CONNECTION_PARAMETERS.pop("external_access_key2", None)
CONNECTION_PARAMETERS.pop("external_access_key3", None)
CONNECTION_PARAMETERS.pop("external_access_key4", None)
CONNECTION_PARAMETERS.pop("external_access_integration1", None)
CONNECTION_PARAMETERS.pop("external_access_integration2", None)
CONNECTION_PARAMETERS.pop("external_access_integration3", None)
CONNECTION_PARAMETERS.pop("external_access_integration4", None)
CONNECTION_PARAMETERS.pop("wif_audience", None)


def set_up_dataframe_processor_parameters(
Expand Down Expand Up @@ -311,9 +333,12 @@ def session(
key1 = "snowpark_python_test_key1"
key2 = "snowpark_python_test_key2"
key3 = "snowpark_python_test_key3"
key4 = "snowpark_python_test_key4"
integration1 = "snowpark_python_test_integration1"
integration2 = "snowpark_python_test_integration2"
integration3 = "snowpark_python_test_integration3"
integration4 = "snowpark_python_test_integration4"
wif_audience = "https://replace-with-your-wif-audience"

session = (
Session.builder.configs(db_parameters)
Expand Down Expand Up @@ -347,6 +372,9 @@ def session(
integration1,
integration2,
integration3,
key4,
integration4,
wif_audience,
)

if validate_ast:
Expand Down Expand Up @@ -389,9 +417,12 @@ def profiler_session(
key1 = "snowpark_python_profiler_test_key1"
key2 = "snowpark_python_profiler_test_key2"
key3 = "snowpark_python_profiler_test_key3"
key4 = "snowpark_python_profiler_test_key4"
integration1 = "snowpark_python_profiler_test_integration1"
integration2 = "snowpark_python_profiler_test_integration2"
integration3 = "snowpark_python_profiler_test_integration3"
integration4 = "snowpark_python_profiler_test_integration4"
wif_audience = "https://replace-with-your-wif-audience"
session = (
Session.builder.configs(db_parameters)
.config("local_testing", local_testing_mode)
Expand All @@ -411,6 +442,9 @@ def profiler_session(
integration1,
integration2,
integration3,
key4,
integration4,
wif_audience,
)
try:
yield session
Expand Down
53 changes: 53 additions & 0 deletions tests/integ/test_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
get_secret_type,
get_cloud_provider_token,
get_oauth_access_token,
get_wif_token,
)
from snowflake.snowpark.types import BooleanType, StringType
from tests.utils import IS_NOT_ON_GITHUB, RUNNING_ON_JENKINS, IS_IN_STORED_PROC, Utils
Expand Down Expand Up @@ -152,6 +153,56 @@ def get_secret():
)


@pytest.mark.skipif(
IS_NOT_ON_GITHUB or not RUNNING_ON_JENKINS,
reason="Secret API is only supported on Snowflake server environment",
)
def test_get_wif_token_udf(session, db_parameters):
def get_wif():
token = get_wif_token("cred", db_parameters["wif_audience"])
return len(token) > 0

try:
get_wif_udf = session.udf.register(
get_wif,
return_type=BooleanType(),
packages=["snowflake-snowpark-python"],
external_access_integrations=[
db_parameters["external_access_integration4"]
],
secrets={"cred": f"{db_parameters['external_access_key4']}"},
)
df = session.create_dataframe([[1], [2]]).to_df("x")
Utils.check_answer(df.select(get_wif_udf()), [Row(True), Row(True)])
except KeyError:
pytest.skip("External Access Integration is not supported on the deployment.")


@pytest.mark.skipif(
IS_NOT_ON_GITHUB or not RUNNING_ON_JENKINS,
reason="Secret API is only supported on Snowflake server environment",
)
def test_get_wif_token_sproc(session, db_parameters):
def get_wif_in_sproc(session_):
token = get_wif_token("cred", db_parameters["wif_audience"])
return len(token) > 0

try:
get_wif_sp = session.sproc.register(
get_wif_in_sproc,
return_type=BooleanType(),
packages=["snowflake-snowpark-python"],
external_access_integrations=[
db_parameters["external_access_integration4"]
],
secrets={"cred": f"{db_parameters['external_access_key4']}"},
anonymous=True,
)
assert get_wif_sp()
except KeyError:
pytest.skip("External Access Integration is not supported on the deployment.")


@pytest.mark.skipif(
IS_IN_STORED_PROC,
reason="Run only outside Snowflake server to validate NotImplementedError",
Expand All @@ -169,3 +220,5 @@ def test_secrets_import_error():
get_cloud_provider_token("c1")
with pytest.raises(NotImplementedError):
get_oauth_access_token("o1")
with pytest.raises(NotImplementedError):
get_wif_token("w1", "https://audience")
12 changes: 12 additions & 0 deletions tests/unit/test_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_secret_type,
get_username_password,
get_cloud_provider_token,
get_wif_token,
UsernamePassword,
CloudProviderToken,
_SCLS_SPCS_SECRET_ENV_NAME,
Expand All @@ -31,6 +32,7 @@ def _build_fake_snowflake_module() -> object:
get_secret_type=lambda secret_name: "PASSWORD",
get_username_password=lambda secret_name: fake_username_password,
get_cloud_provider_token=lambda secret_name: fake_cloud_token,
get_wif_token=lambda secret_name, audience: f"wif:{secret_name}:{audience}",
)


Expand All @@ -52,6 +54,11 @@ def test_secrets_mock_server_paths():
assert cloud.secret_access_key == "SECRET_TEST"
assert cloud.token == "STS_TOKEN_TEST"

assert (
get_wif_token("w1", "https://example.com/aud")
== "wif:w1:https://example.com/aud"
)


@pytest.fixture
def scls_spcs_mock_env(tmp_path):
Expand Down Expand Up @@ -135,6 +142,9 @@ def test_secrets_mock_scls_spcs_error_cases(scls_spcs_mock_env):
with pytest.raises(NotImplementedError):
get_cloud_provider_token("any_secret")

with pytest.raises(NotImplementedError):
get_wif_token("any_secret", "https://audience")

with pytest.raises(ValueError, match="Unknown secret type"):
get_secret_type("unknown_secret")

Expand All @@ -159,6 +169,8 @@ def test_secrets_import_error_paths():
get_username_password("p1")
with pytest.raises(NotImplementedError):
get_cloud_provider_token("c1")
with pytest.raises(NotImplementedError):
get_wif_token("w1", "https://audience")
finally:
if original_env is not None:
os.environ[_SCLS_SPCS_SECRET_ENV_NAME] = original_env
Loading