Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,42 +57,45 @@
"ai_agg",
"ai_summarize_agg",
"any_value",
"approximate_count_distinct",
"approximate_jaccard_index",
"approximate_similarity",
"approx_count_distinct",
"approx_percentile",
"approx_percentile_accumulate",
"approx_percentile_combine",
"approx_top_k",
"approx_top_k_accumulate",
"approx_top_k_combine",
"arrayagg",
"approximate_count_distinct",
"approximate_jaccard_index",
"approximate_similarity",
"array_agg",
"array_union_agg",
"array_unique_agg",
"arrayagg",
"avg",
"bitandagg",
"bit_and_agg",
"bit_andagg",
"bit_or_agg",
"bit_oragg",
"bit_xor_agg",
"bit_xoragg",
"bitand_agg",
"bitandagg",
"bitmap_and_agg",
"bitmap_construct_agg",
"bitmap_or_agg",
"bitoragg",
"bitor_agg",
"bitxoragg",
"bitoragg",
"bitxor_agg",
"bit_andagg",
"bit_and_agg",
"bit_oragg",
"bit_or_agg",
"bit_xoragg",
"bit_xor_agg",
"bitxoragg",
"booland_agg",
"boolor_agg",
"boolxor_agg",
"corr",
"count",
"count(*)",
"count_if",
"count_internal",
"count_internal(*)",
"covar_pop",
"covar_samp",
"datasketches_hll",
Expand All @@ -110,12 +113,12 @@
"max_by",
"median",
"min",
"min_by",
"minhash",
"minhash_combine",
"min_by",
"mode",
"objectagg",
"object_agg",
"objectagg",
"percentile_cont",
"percentile_disc",
"regr_avgx",
Expand All @@ -128,27 +131,29 @@
"regr_sxy",
"regr_syy",
"skew",
"st_intersection_agg_geography_internal",
"st_union_agg_geography_internal",
"stddev",
"stddev_pop",
"stddev_samp",
"st_intersection_agg_geography_internal",
"st_union_agg_geography_internal",
"sum",
"sum_internal",
"sum_internal_real",
"sum_real",
"summarize_agg",
"var_pop",
"var_samp",
"variance",
"variance_pop",
"variance_samp",
"var_pop",
"var_samp",
"vector_avg",
"vector_max",
"vector_min",
"vector_sum",
]
)


_cte_error_threshold = 3 # 0 to disable auto-cte-disable, otherwise the number of times CTE optimization can fail before it is automatically disabled for the remainder of the session.

# Following are internal-only global flags, used to enable development features.
Expand Down
102 changes: 75 additions & 27 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,8 +856,10 @@ def __init__(
self._dataframe_profiler = DataframeProfiler(session=self)
self._catalog = None
self._client_telemetry = EventTableTelemetry(session=self)
self._agg_function_prefetch_job: Optional[AsyncJob] = None

self._ast_batch = AstBatch(self)
self._start_async_aggregation_prefetch_if_needed()

_logger.info("Snowpark Session information: %s", self._session_info)

Expand Down Expand Up @@ -5055,43 +5057,89 @@ def _retrieve_aggregation_function_list(self) -> None:
return

retrieved_set = set()
system_fetch_succeeded = False

# User-defined aggregation functions
try:
retrieved_set.update(
{
r[0].lower()
for r in self.sql(
"""select function_name from information_schema.functions where is_aggregate = 'YES'"""
).collect()
}
)
except Exception as e:
# Try async result first if prefetch was already started.
if self._agg_function_prefetch_job is not None:
try:
retrieved_set.update(
{r[0].lower() for r in self._agg_function_prefetch_job.result()}
)
system_fetch_succeeded = True
except Exception as e:
_logger.debug(
"Unable to use async aggregation function prefetch: %s",
e,
)
finally:
self._agg_function_prefetch_job = None
else:
_logger.debug(
"Unable to get user-defined aggregation functions: %s",
e,
"Async aggregation function prefetch job is unavailable; using sync fallback."
)

# System built-in aggregation functions
# Sync fallback query.
if not system_fetch_succeeded:
try:
retrieved_set.update(
{
r[0].lower()
for r in self._conn.run_query(
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""",
_is_internal=True,
Comment on lines +5088 to +5089
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical Bug: User-defined aggregation functions are missing in sync fallback

The sync fallback query only retrieves system aggregation functions via show functions, but the async prefetch query (lines 5121-5123) uses a UNION to retrieve both system AND user-defined aggregation functions from information_schema.functions. This inconsistency means:

  • If async succeeds: Both system and user-defined aggregation functions are available
  • If async fails and sync fallback is used: Only system functions are available, user-defined functions are lost

This will cause production issues where user-defined aggregation functions work inconsistently depending on whether the async query succeeded or failed.

Fix:

# Change the sync query to match the async query:
r in self._conn.run_query(
    """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'
union
select function_name from information_schema.functions where is_aggregate = 'YES'""",
    _is_internal=True,
)["data"]
Suggested change
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""",
_is_internal=True,
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'
union
select function_name from information_schema.functions where is_aggregate = 'YES'""",
_is_internal=True,

Spotted by Graphite

Fix in Graphite


Is this helpful? React 👍 or 👎 to let us know.

)["data"]
}
)
system_fetch_succeeded = True
except Exception as e:
_logger.debug(
"Unable to get aggregation functions via sync fallback query: %s",
e,
)

# Fallback to the local hardcoded list only when metadata retrieval fails.
if not system_fetch_succeeded:
retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS)

with context._aggregation_function_set_lock:
context._aggregation_function_set.update(retrieved_set)

def _start_async_aggregation_prefetch_if_needed(self) -> None:
"""Start aggregation metadata prefetch only when not already in progress."""
if not (
context._is_snowpark_connect_compatible_mode
and context._snowpark_connect_flatten_select_after_sort
):
return
if context._aggregation_function_set:
return
if self._agg_function_prefetch_job is not None:
return

try:
retrieved_set.update(
{
r[0].lower()
for r in self.sql(
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'"""
Comment on lines -5080 to -5081
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that queries like this ensure the server is always the source of truth, and we don't have to make client code changes that are aligned with whatever Snowflake version is being run on the server. Hard-coding it to a local change is something of a philosophical change, and may unexpectedly break workloads in some circumstances.

It looks to only ever get run once per workload anyway, so I think the benefits of this would be pretty minimal.

Copy link
Copy Markdown
Collaborator Author

@sfc-gh-yuwang sfc-gh-yuwang May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a discussion with Adam and Yijun this morning. One of the issues is that this is always run when there are potentially aggregation functions in the query, which means users would still see this even if they issued just one query in SCOS.
This is also a pretty expensive operation that takes 3–4 seconds, which is a big portion of the runtime when the user workflow is small.

To solve the problem that this list may diverge from server, I created a followup ticket: https://snowflakecomputing.atlassian.net/browse/SNOW-3489271

Currently my thought is to add a step to the release pipeline to ensure that each Snowpark release's system function list is up to date at the point it is released. Thus users would only need to use the latest Snowpark to avoid the misalignment issue.

But I think it also makes sense to add parameter protection for this. Will add one.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parameter-protecting makes sense. I'm slightly worried about server divergence because the same client release might get run against different backend versions (I'm not sure if the N-1 version policy still exists, but at the very least testing vs. prod environments, and customers on different release cadences, might get different versions).

).collect()
}
self._agg_function_prefetch_job = self._submit_internal_async_prefetch_query(
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'
union
select function_name from information_schema.functions where is_aggregate = 'YES'"""
)
except Exception as e:
except Exception as e: # pragma: no cover
_logger.debug(
"Unable to get system aggregation functions, "
"falling back to hardcoded list: %s",
"Unable to start async aggregation metadata prefetch: %s",
e,
)
retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS)
self._agg_function_prefetch_job = None

with context._aggregation_function_set_lock:
context._aggregation_function_set.update(retrieved_set)
def _submit_internal_async_prefetch_query(self, query: str) -> Optional[AsyncJob]:
    """Run *query* asynchronously as an internal query.

    Returns an :class:`AsyncJob` handle for the submitted query, or ``None``
    when submission fails for any reason (the failure is only logged at
    debug level — prefetch is best-effort).
    """
    try:
        response = self._conn.execute_async_and_notify_query_listener(
            query,
            _is_internal=True,
        )
        return self.create_async_job(response["queryId"])
    except Exception as e:  # pragma: no cover
        _logger.debug("Unable to submit internal async prefetch query: %s", e)
        return None

def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame:
"""
Expand Down
76 changes: 76 additions & 0 deletions tests/integ/test_simplifier_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2519,3 +2519,79 @@ def test_retrieving_aggregation_funcs(session, monkeypatch):
assert not context._aggregation_function_set
session._retrieve_aggregation_function_list()
assert not context._aggregation_function_set


def test_internal_async_aggregation_prefetch_submission(session, monkeypatch):
    """Prefetch should submit exactly one internal async query that combines
    the SHOW FUNCTIONS and information_schema sources."""
    import snowflake.snowpark.context as context

    # Enable the compatibility-mode gates and clear any cached state.
    monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True)
    monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True)
    monkeypatch.setattr(context, "_aggregation_function_set", set())
    session._agg_function_prefetch_job = None

    recorded = []

    def _capture(query, **kwargs):
        recorded.append((query, kwargs))
        return {"queryId": "qid_combined"}

    monkeypatch.setattr(
        session._conn, "execute_async_and_notify_query_listener", _capture
    )
    session._start_async_aggregation_prefetch_if_needed()

    assert len(recorded) == 1
    query, kwargs = recorded[0]
    assert kwargs.get("_is_internal") is True
    assert "show functions" in query
    assert "information_schema.functions" in query
    assert session._agg_function_prefetch_job.query_id == "qid_combined"


def test_aggregation_fallback_not_used_when_combined_async_succeeds(
    session, monkeypatch
):
    """When the async prefetch result is usable, neither the sync query nor
    the hardcoded fallback list should be consulted."""
    import snowflake.snowpark.context as context

    class _StubJob:
        # Minimal AsyncJob stand-in: result() returns canned rows or raises.
        def __init__(self, rows=None, error=None) -> None:
            self._rows = rows
            self._error = error

        def result(self):
            if self._error is not None:
                raise self._error
            return self._rows

    monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True)
    monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True)
    monkeypatch.setattr(context, "_aggregation_function_set", set())
    session._agg_function_prefetch_job = _StubJob(rows=[("SUM",)])

    session._retrieve_aggregation_function_list()

    # "sum" comes from the stubbed async result; "sum_internal" exists only in
    # the hardcoded list, so its absence proves the fallback was skipped.
    assert "sum" in context._aggregation_function_set
    assert "sum_internal" not in context._aggregation_function_set


def test_internal_sync_aggregation_fallback_submission(session, monkeypatch):
    """With no async prefetch job available, retrieval should fall back to a
    single internal sync SHOW FUNCTIONS query."""
    import snowflake.snowpark.context as context

    monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True)
    monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True)
    monkeypatch.setattr(context, "_aggregation_function_set", set())
    session._agg_function_prefetch_job = None

    seen = []

    def _record(query, **kwargs):
        seen.append((query, kwargs))
        return {"data": [("AVG",)]}

    monkeypatch.setattr(session._conn, "run_query", _record)
    session._retrieve_aggregation_function_list()

    assert len(seen) == 1
    query, kwargs = seen[0]
    assert kwargs.get("_is_internal") is True
    assert "show functions" in query
    # NOTE(review): this pins the sync fallback to SHOW FUNCTIONS only, i.e.
    # without the information_schema union used by the async path — see the
    # open review discussion about user-defined functions in the fallback.
    assert "information_schema.functions" not in query
    assert "avg" in context._aggregation_function_set
Loading
Loading