From 92b4d7087a6ee046538248676be82720245dad78 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 7 May 2026 11:35:14 -0700 Subject: [PATCH 1/5] use local list instead of fetch from snowflake --- src/snowflake/snowpark/context.py | 42 +++++++++++++++++-------------- src/snowflake/snowpark/session.py | 17 +------------ 2 files changed, 24 insertions(+), 35 deletions(-) diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index a111839050..e2f9a465bc 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -57,9 +57,6 @@ "ai_agg", "ai_summarize_agg", "any_value", - "approximate_count_distinct", - "approximate_jaccard_index", - "approximate_similarity", "approx_count_distinct", "approx_percentile", "approx_percentile_accumulate", @@ -67,32 +64,38 @@ "approx_top_k", "approx_top_k_accumulate", "approx_top_k_combine", - "arrayagg", + "approximate_count_distinct", + "approximate_jaccard_index", + "approximate_similarity", "array_agg", "array_union_agg", "array_unique_agg", + "arrayagg", "avg", - "bitandagg", + "bit_and_agg", + "bit_andagg", + "bit_or_agg", + "bit_oragg", + "bit_xor_agg", + "bit_xoragg", "bitand_agg", + "bitandagg", + "bitmap_and_agg", "bitmap_construct_agg", "bitmap_or_agg", - "bitoragg", "bitor_agg", - "bitxoragg", + "bitoragg", "bitxor_agg", - "bit_andagg", - "bit_and_agg", - "bit_oragg", - "bit_or_agg", - "bit_xoragg", - "bit_xor_agg", + "bitxoragg", "booland_agg", "boolor_agg", "boolxor_agg", "corr", "count", + "count(*)", "count_if", "count_internal", + "count_internal(*)", "covar_pop", "covar_samp", "datasketches_hll", @@ -110,12 +113,12 @@ "max_by", "median", "min", + "min_by", "minhash", "minhash_combine", - "min_by", "mode", - "objectagg", "object_agg", + "objectagg", "percentile_cont", "percentile_disc", "regr_avgx", @@ -128,20 +131,21 @@ "regr_sxy", "regr_syy", "skew", + "st_intersection_agg_geography_internal", + "st_union_agg_geography_internal", "stddev", "stddev_pop", 
"stddev_samp", - "st_intersection_agg_geography_internal", - "st_union_agg_geography_internal", "sum", "sum_internal", "sum_internal_real", "sum_real", + "summarize_agg", + "var_pop", + "var_samp", "variance", "variance_pop", "variance_samp", - "var_pop", - "var_samp", "vector_avg", "vector_max", "vector_min", diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 457f28f95b..c6da487c7d 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -5073,22 +5073,7 @@ def _retrieve_aggregation_function_list(self) -> None: ) # System built-in aggregation functions - try: - retrieved_set.update( - { - r[0].lower() - for r in self.sql( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" - ).collect() - } - ) - except Exception as e: - _logger.debug( - "Unable to get system aggregation functions, " - "falling back to hardcoded list: %s", - e, - ) - retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) + retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) with context._aggregation_function_set_lock: context._aggregation_function_set.update(retrieved_set) From 937650ac96d3668d008cd53fe8ffa17190551a3f Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 11 May 2026 16:14:36 -0700 Subject: [PATCH 2/5] test --- src/snowflake/snowpark/context.py | 1 + src/snowflake/snowpark/session.py | 115 +++++++++++++++++++++++++----- 2 files changed, 100 insertions(+), 16 deletions(-) diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index e2f9a465bc..240672a571 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -153,6 +153,7 @@ ] ) + _cte_error_threshold = 3 # 0 to disable auto-cte-disable, otherwise the number of times CTE optimization can fail before it is automatically disabled for the remainder of the session. # Following are internal-only global flags, used to enable development features. 
diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index c6da487c7d..a2b4cec3dd 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -856,8 +856,11 @@ def __init__( self._dataframe_profiler = DataframeProfiler(session=self) self._catalog = None self._client_telemetry = EventTableTelemetry(session=self) + self._system_agg_function_prefetch_job: Optional[AsyncJob] = None + self._user_agg_function_prefetch_job: Optional[AsyncJob] = None self._ast_batch = AstBatch(self) + self._start_async_aggregation_prefetch() _logger.info("Snowpark Session information: %s", self._session_info) @@ -5056,28 +5059,108 @@ def _retrieve_aggregation_function_list(self) -> None: retrieved_set = set() - # User-defined aggregation functions - try: - retrieved_set.update( - { - r[0].lower() - for r in self.sql( - """select function_name from information_schema.functions where is_aggregate = 'YES'""" - ).collect() - } - ) - except Exception as e: - _logger.debug( - "Unable to get user-defined aggregation functions: %s", - e, - ) + # User-defined aggregation functions. + # If init has already issued the async query, wait and use it. + # Otherwise, execute synchronously now for select-statement correctness. 
+ if self._user_agg_function_prefetch_job is not None: + try: + retrieved_set.update( + { + r[0].lower() + for r in self._user_agg_function_prefetch_job.result() + } + ) + except Exception as e: + _logger.debug( + "Unable to use async user-defined aggregation function prefetch: %s", + e, + ) + finally: + self._user_agg_function_prefetch_job = None + else: + try: + retrieved_set.update( + { + r[0].lower() + for r in self.sql( + """select function_name from information_schema.functions where is_aggregate = 'YES'""" + ).collect() + } + ) + except Exception as e: + _logger.debug( + "Unable to get user-defined aggregation functions: %s", + e, + ) - # System built-in aggregation functions + # System aggregation functions from metadata query. + if self._system_agg_function_prefetch_job is not None: + try: + retrieved_set.update( + { + r[0].lower() + for r in self._system_agg_function_prefetch_job.result() + } + ) + except Exception as e: + _logger.debug( + "Unable to use async system aggregation function prefetch: %s", + e, + ) + finally: + self._system_agg_function_prefetch_job = None + else: + try: + retrieved_set.update( + { + r[0].lower() + for r in self.sql( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" + ).collect() + } + ) + except Exception as e: + _logger.debug( + "Unable to get system aggregation functions: %s", + e, + ) + + # Keep hardcoded fallback behavior. 
retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) with context._aggregation_function_set_lock: context._aggregation_function_set.update(retrieved_set) + def _start_async_aggregation_prefetch(self) -> None: + """Issue async prefetch query for aggregation metadata once.""" + if not ( + context._is_snowpark_connect_compatible_mode + and context._snowpark_connect_flatten_select_after_sort + ): + return + + try: + self._user_agg_function_prefetch_job = self.sql( + """select function_name from information_schema.functions where is_aggregate = 'YES'""" + ).collect_nowait() + except Exception as e: # pragma: no cover + _logger.debug( + "Unable to start async user-defined aggregation metadata prefetch: %s", + e, + ) + self._user_agg_function_prefetch_job = None + + try: + self._system_agg_function_prefetch_job = self.sql( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" + ).collect_nowait() + except Exception as e: # pragma: no cover + _logger.debug( + "Unable to start async system aggregation metadata prefetch: %s", + e, + ) + self._system_agg_function_prefetch_job = None + def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ Returns a DataFrame representing the results of a directory table query on the specified stage. 
From 1650ef6415afd78bbc8df075f12285c63c641c1b Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 12 May 2026 10:08:37 -0700 Subject: [PATCH 3/5] async update --- src/snowflake/snowpark/session.py | 63 +++++++++++++++++++------------ 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index a2b4cec3dd..55c656efcb 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -860,7 +860,7 @@ def __init__( self._user_agg_function_prefetch_job: Optional[AsyncJob] = None self._ast_batch = AstBatch(self) - self._start_async_aggregation_prefetch() + self._start_async_aggregation_prefetch_if_needed() _logger.info("Snowpark Session information: %s", self._session_info) @@ -5058,6 +5058,7 @@ def _retrieve_aggregation_function_list(self) -> None: return retrieved_set = set() + system_fetch_succeeded = False # User-defined aggregation functions. # If init has already issued the async query, wait and use it. @@ -5102,6 +5103,7 @@ def _retrieve_aggregation_function_list(self) -> None: for r in self._system_agg_function_prefetch_job.result() } ) + system_fetch_succeeded = True except Exception as e: _logger.debug( "Unable to use async system aggregation function prefetch: %s", @@ -5119,47 +5121,58 @@ def _retrieve_aggregation_function_list(self) -> None: ).collect() } ) + system_fetch_succeeded = True except Exception as e: _logger.debug( "Unable to get system aggregation functions: %s", e, ) - # Keep hardcoded fallback behavior. - retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) + # Fallback to the local hardcoded list only when both metadata fetches fail. 
+ if not system_fetch_succeeded: + retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) with context._aggregation_function_set_lock: context._aggregation_function_set.update(retrieved_set) - def _start_async_aggregation_prefetch(self) -> None: - """Issue async prefetch query for aggregation metadata once.""" + def _start_async_aggregation_prefetch_if_needed(self) -> None: + """Start aggregation metadata prefetch only when not already in progress.""" if not ( context._is_snowpark_connect_compatible_mode and context._snowpark_connect_flatten_select_after_sort ): return + if context._aggregation_function_set: + return + if ( + self._user_agg_function_prefetch_job is not None + and self._system_agg_function_prefetch_job is not None + ): + return - try: - self._user_agg_function_prefetch_job = self.sql( - """select function_name from information_schema.functions where is_aggregate = 'YES'""" - ).collect_nowait() - except Exception as e: # pragma: no cover - _logger.debug( - "Unable to start async user-defined aggregation metadata prefetch: %s", - e, - ) - self._user_agg_function_prefetch_job = None + if self._user_agg_function_prefetch_job is None: + try: + self._user_agg_function_prefetch_job = self.sql( + """select function_name from information_schema.functions where is_aggregate = 'YES'""" + ).collect_nowait() + except Exception as e: # pragma: no cover + _logger.debug( + "Unable to start async user-defined aggregation metadata prefetch: %s", + e, + ) + self._user_agg_function_prefetch_job = None - try: - self._system_agg_function_prefetch_job = self.sql( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" - ).collect_nowait() - except Exception as e: # pragma: no cover - _logger.debug( - "Unable to start async system aggregation metadata prefetch: %s", - e, - ) - self._system_agg_function_prefetch_job = None + if self._system_agg_function_prefetch_job is None: + try: + self._system_agg_function_prefetch_job = self.sql( + """show 
functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" + ).collect_nowait() + except Exception as e: # pragma: no cover + _logger.debug( + "Unable to start async system aggregation metadata prefetch: %s", + e, + ) + self._system_agg_function_prefetch_job = None def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ From 09b3946c908f27d9429ec71163973c46d96b68d0 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 12 May 2026 11:07:15 -0700 Subject: [PATCH 4/5] add test --- src/snowflake/snowpark/session.py | 34 ++++++++---- tests/integ/test_simplifier_suite.py | 79 ++++++++++++++++++++++++++++ tests/unit/test_session.py | 48 +++++++++++------ 3 files changed, 136 insertions(+), 25 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 55c656efcb..441bbb23d0 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -5083,9 +5083,10 @@ def _retrieve_aggregation_function_list(self) -> None: retrieved_set.update( { r[0].lower() - for r in self.sql( - """select function_name from information_schema.functions where is_aggregate = 'YES'""" - ).collect() + for r in self._conn.run_query( + """select function_name from information_schema.functions where is_aggregate = 'YES'""", + _is_internal=True, + )["data"] } ) except Exception as e: @@ -5116,9 +5117,10 @@ def _retrieve_aggregation_function_list(self) -> None: retrieved_set.update( { r[0].lower() - for r in self.sql( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" - ).collect() + for r in self._conn.run_query( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", + _is_internal=True, + )["data"] } ) system_fetch_succeeded = True @@ -5152,9 +5154,9 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: if self._user_agg_function_prefetch_job is None: try: - self._user_agg_function_prefetch_job = self.sql( + self._user_agg_function_prefetch_job = 
self._submit_internal_async_prefetch_query( """select function_name from information_schema.functions where is_aggregate = 'YES'""" - ).collect_nowait() + ) except Exception as e: # pragma: no cover _logger.debug( "Unable to start async user-defined aggregation metadata prefetch: %s", @@ -5164,9 +5166,9 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: if self._system_agg_function_prefetch_job is None: try: - self._system_agg_function_prefetch_job = self.sql( + self._system_agg_function_prefetch_job = self._submit_internal_async_prefetch_query( """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" - ).collect_nowait() + ) except Exception as e: # pragma: no cover _logger.debug( "Unable to start async system aggregation metadata prefetch: %s", @@ -5174,6 +5176,18 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: ) self._system_agg_function_prefetch_job = None + def _submit_internal_async_prefetch_query(self, query: str) -> Optional[AsyncJob]: + """Submit a prefetch query as internal async and return an AsyncJob handle.""" + try: + result = self._conn.execute_async_and_notify_query_listener( + query, + _is_internal=True, + ) + return self.create_async_job(result["queryId"]) + except Exception as e: # pragma: no cover + _logger.debug("Unable to submit internal async prefetch query: %s", e) + return None + def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ Returns a DataFrame representing the results of a directory table query on the specified stage. 
diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index b446347d51..76ed0502c4 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -2519,3 +2519,82 @@ def test_retrieving_aggregation_funcs(session, monkeypatch): assert not context._aggregation_function_set session._retrieve_aggregation_function_list() assert not context._aggregation_function_set + + +def test_internal_async_aggregation_prefetch_submission(session, monkeypatch): + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + session._user_agg_function_prefetch_job = None + session._system_agg_function_prefetch_job = None + + call_kwargs = [] + + def _fake_execute_async(query, **kwargs): + call_kwargs.append(kwargs) + return {"queryId": f"qid_{len(call_kwargs)}"} + + monkeypatch.setattr( + session._conn, "execute_async_and_notify_query_listener", _fake_execute_async + ) + session._start_async_aggregation_prefetch_if_needed() + + assert len(call_kwargs) == 2 + assert all(kwargs.get("_is_internal") is True for kwargs in call_kwargs) + assert session._user_agg_function_prefetch_job.query_id == "qid_1" + assert session._system_agg_function_prefetch_job.query_id == "qid_2" + + +def test_aggregation_fallback_used_when_system_source_fails(session, monkeypatch): + import snowflake.snowpark.context as context + + class _FakeAsyncJob: + def __init__(self, rows=None, error=None) -> None: + self._rows = rows + self._error = error + + def result(self): + if self._error is not None: + raise self._error + return self._rows + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, 
"_aggregation_function_set", set()) + session._user_agg_function_prefetch_job = _FakeAsyncJob(rows=[("SUM",)]) + session._system_agg_function_prefetch_job = _FakeAsyncJob( + error=RuntimeError("system fetch failed") + ) + + session._retrieve_aggregation_function_list() + + assert "sum" in context._aggregation_function_set + assert "sum_internal" in context._aggregation_function_set + + +def test_internal_sync_aggregation_fallback_submission(session, monkeypatch): + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + session._user_agg_function_prefetch_job = None + session._system_agg_function_prefetch_job = None + + call_kwargs = [] + + def _fake_run_query(query, **kwargs): + call_kwargs.append(kwargs) + if "information_schema.functions" in query: + return {"data": [("SUM",)]} + return {"data": [("AVG",)]} + + monkeypatch.setattr(session._conn, "run_query", _fake_run_query) + session._retrieve_aggregation_function_list() + + assert len(call_kwargs) == 2 + assert all(kwargs.get("_is_internal") is True for kwargs in call_kwargs) + assert "sum" in context._aggregation_function_set + assert "avg" in context._aggregation_function_set diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0349618659..6b44195b0e 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -818,34 +818,34 @@ def test_retrieve_aggregation_function_list_handles_user_defined_error(): session = Session(fake_server_connection) original_compat = ctx._is_snowpark_connect_compatible_mode + original_flatten = ctx._snowpark_connect_flatten_select_after_sort original_agg_set = ctx._aggregation_function_set try: ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = 
set() - mock_df = MagicMock() - call_count = [0] - - def sql_side_effect(query, **kwargs): - call_count[0] += 1 - if call_count[0] == 1: + def run_query_side_effect(query, **kwargs): + assert kwargs.get("_is_internal") is True + if "information_schema.functions" in query: raise RuntimeError("user-defined query failed") - mock_df.collect.return_value = [["SUM"], ["AVG"]] - return mock_df + return {"data": [["SUM"], ["AVG"]]} - with mock.patch.object(session, "sql", side_effect=sql_side_effect): + with mock.patch.object( + fake_server_connection, "run_query", side_effect=run_query_side_effect + ): session._retrieve_aggregation_function_list() assert "sum" in ctx._aggregation_function_set assert "avg" in ctx._aggregation_function_set finally: ctx._is_snowpark_connect_compatible_mode = original_compat + ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set def test_retrieve_aggregation_function_list_handles_system_error(): - """When querying system aggregation functions fails, the method falls back - to the hardcoded _KNOWN_AGGREGATION_FUNCTIONS set.""" + """When system aggregation metadata retrieval fails, hardcoded fallback applies.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -853,20 +853,29 @@ def test_retrieve_aggregation_function_list_handles_system_error(): session = Session(fake_server_connection) original_compat = ctx._is_snowpark_connect_compatible_mode + original_flatten = ctx._snowpark_connect_flatten_select_after_sort original_agg_set = ctx._aggregation_function_set try: ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() - mock_df = MagicMock() - mock_df.collect.side_effect = RuntimeError("system query failed") + def run_query_side_effect(query, **kwargs): + assert kwargs.get("_is_internal") is True + if "show functions" in query: + raise 
RuntimeError("system query failed") + return {"data": [["SUM"]]} - with mock.patch.object(session, "sql", return_value=mock_df): + with mock.patch.object( + fake_server_connection, "run_query", side_effect=run_query_side_effect + ): session._retrieve_aggregation_function_list() + assert "sum" in ctx._aggregation_function_set assert ctx._KNOWN_AGGREGATION_FUNCTIONS.issubset(ctx._aggregation_function_set) finally: ctx._is_snowpark_connect_compatible_mode = original_compat + ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set @@ -880,17 +889,26 @@ def test_retrieve_aggregation_function_list_handles_both_errors(): session = Session(fake_server_connection) original_compat = ctx._is_snowpark_connect_compatible_mode + original_flatten = ctx._snowpark_connect_flatten_select_after_sort original_agg_set = ctx._aggregation_function_set try: ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() + def run_query_side_effect(query, **kwargs): + assert kwargs.get("_is_internal") is True + raise RuntimeError("query failed") + with mock.patch.object( - session, "sql", side_effect=RuntimeError("query failed") + fake_server_connection, + "run_query", + side_effect=run_query_side_effect, ): session._retrieve_aggregation_function_list() assert ctx._KNOWN_AGGREGATION_FUNCTIONS.issubset(ctx._aggregation_function_set) finally: ctx._is_snowpark_connect_compatible_mode = original_compat + ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set From c606f3f846b4db56aba2c6edbf4420e68f285b3a Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 12 May 2026 16:17:55 -0700 Subject: [PATCH 5/5] update change --- src/snowflake/snowpark/session.py | 101 +++++++-------------------- tests/integ/test_simplifier_suite.py | 45 ++++++------ tests/unit/test_session.py | 37 +++++----- 3 files 
changed, 69 insertions(+), 114 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 441bbb23d0..f943881e7c 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -856,8 +856,7 @@ def __init__( self._dataframe_profiler = DataframeProfiler(session=self) self._catalog = None self._client_telemetry = EventTableTelemetry(session=self) - self._system_agg_function_prefetch_job: Optional[AsyncJob] = None - self._user_agg_function_prefetch_job: Optional[AsyncJob] = None + self._agg_function_prefetch_job: Optional[AsyncJob] = None self._ast_batch = AstBatch(self) self._start_async_aggregation_prefetch_if_needed() @@ -5060,59 +5059,27 @@ def _retrieve_aggregation_function_list(self) -> None: retrieved_set = set() system_fetch_succeeded = False - # User-defined aggregation functions. - # If init has already issued the async query, wait and use it. - # Otherwise, execute synchronously now for select-statement correctness. - if self._user_agg_function_prefetch_job is not None: + # Try async result first if prefetch was already started. + if self._agg_function_prefetch_job is not None: try: retrieved_set.update( - { - r[0].lower() - for r in self._user_agg_function_prefetch_job.result() - } - ) - except Exception as e: - _logger.debug( - "Unable to use async user-defined aggregation function prefetch: %s", - e, - ) - finally: - self._user_agg_function_prefetch_job = None - else: - try: - retrieved_set.update( - { - r[0].lower() - for r in self._conn.run_query( - """select function_name from information_schema.functions where is_aggregate = 'YES'""", - _is_internal=True, - )["data"] - } - ) - except Exception as e: - _logger.debug( - "Unable to get user-defined aggregation functions: %s", - e, - ) - - # System aggregation functions from metadata query. 
- if self._system_agg_function_prefetch_job is not None: - try: - retrieved_set.update( - { - r[0].lower() - for r in self._system_agg_function_prefetch_job.result() - } + {r[0].lower() for r in self._agg_function_prefetch_job.result()} ) system_fetch_succeeded = True except Exception as e: _logger.debug( - "Unable to use async system aggregation function prefetch: %s", + "Unable to use async aggregation function prefetch: %s", e, ) finally: - self._system_agg_function_prefetch_job = None + self._agg_function_prefetch_job = None else: + _logger.debug( + "Async aggregation function prefetch job is unavailable; using sync fallback." + ) + + # Sync fallback query. + if not system_fetch_succeeded: try: retrieved_set.update( { @@ -5126,11 +5093,11 @@ def _retrieve_aggregation_function_list(self) -> None: system_fetch_succeeded = True except Exception as e: _logger.debug( - "Unable to get system aggregation functions: %s", + "Unable to get aggregation functions via sync fallback query: %s", e, ) - # Fallback to the local hardcoded list only when both metadata fetches fail. + # Fallback to the local hardcoded list only when metadata retrieval fails. 
if not system_fetch_succeeded: retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) @@ -5146,35 +5113,21 @@ def _start_async_aggregation_prefetch_if_needed(self) -> None: return if context._aggregation_function_set: return - if ( - self._user_agg_function_prefetch_job is not None - and self._system_agg_function_prefetch_job is not None - ): + if self._agg_function_prefetch_job is not None: return - if self._user_agg_function_prefetch_job is None: - try: - self._user_agg_function_prefetch_job = self._submit_internal_async_prefetch_query( - """select function_name from information_schema.functions where is_aggregate = 'YES'""" - ) - except Exception as e: # pragma: no cover - _logger.debug( - "Unable to start async user-defined aggregation metadata prefetch: %s", - e, - ) - self._user_agg_function_prefetch_job = None - - if self._system_agg_function_prefetch_job is None: - try: - self._system_agg_function_prefetch_job = self._submit_internal_async_prefetch_query( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" - ) - except Exception as e: # pragma: no cover - _logger.debug( - "Unable to start async system aggregation metadata prefetch: %s", - e, - ) - self._system_agg_function_prefetch_job = None + try: + self._agg_function_prefetch_job = self._submit_internal_async_prefetch_query( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y' +union +select function_name from information_schema.functions where is_aggregate = 'YES'""" + ) + except Exception as e: # pragma: no cover + _logger.debug( + "Unable to start async aggregation metadata prefetch: %s", + e, + ) + self._agg_function_prefetch_job = None def _submit_internal_async_prefetch_query(self, query: str) -> Optional[AsyncJob]: """Submit a prefetch query as internal async and return an AsyncJob handle.""" diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index 76ed0502c4..976035b1c7 100644 --- 
a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -2527,27 +2527,29 @@ def test_internal_async_aggregation_prefetch_submission(session, monkeypatch): monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) monkeypatch.setattr(context, "_aggregation_function_set", set()) - session._user_agg_function_prefetch_job = None - session._system_agg_function_prefetch_job = None + session._agg_function_prefetch_job = None - call_kwargs = [] + calls = [] def _fake_execute_async(query, **kwargs): - call_kwargs.append(kwargs) - return {"queryId": f"qid_{len(call_kwargs)}"} + calls.append((query, kwargs)) + return {"queryId": "qid_combined"} monkeypatch.setattr( session._conn, "execute_async_and_notify_query_listener", _fake_execute_async ) session._start_async_aggregation_prefetch_if_needed() - assert len(call_kwargs) == 2 - assert all(kwargs.get("_is_internal") is True for kwargs in call_kwargs) - assert session._user_agg_function_prefetch_job.query_id == "qid_1" - assert session._system_agg_function_prefetch_job.query_id == "qid_2" + assert len(calls) == 1 + assert calls[0][1].get("_is_internal") is True + assert "show functions" in calls[0][0] + assert "information_schema.functions" in calls[0][0] + assert session._agg_function_prefetch_job.query_id == "qid_combined" -def test_aggregation_fallback_used_when_system_source_fails(session, monkeypatch): +def test_aggregation_fallback_not_used_when_combined_async_succeeds( + session, monkeypatch +): import snowflake.snowpark.context as context class _FakeAsyncJob: @@ -2563,15 +2565,12 @@ def result(self): monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) monkeypatch.setattr(context, "_aggregation_function_set", set()) - session._user_agg_function_prefetch_job = _FakeAsyncJob(rows=[("SUM",)]) 
- session._system_agg_function_prefetch_job = _FakeAsyncJob( - error=RuntimeError("system fetch failed") - ) + session._agg_function_prefetch_job = _FakeAsyncJob(rows=[("SUM",)]) session._retrieve_aggregation_function_list() assert "sum" in context._aggregation_function_set - assert "sum_internal" in context._aggregation_function_set + assert "sum_internal" not in context._aggregation_function_set def test_internal_sync_aggregation_fallback_submission(session, monkeypatch): @@ -2580,21 +2579,19 @@ def test_internal_sync_aggregation_fallback_submission(session, monkeypatch): monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) monkeypatch.setattr(context, "_aggregation_function_set", set()) - session._user_agg_function_prefetch_job = None - session._system_agg_function_prefetch_job = None + session._agg_function_prefetch_job = None - call_kwargs = [] + calls = [] def _fake_run_query(query, **kwargs): - call_kwargs.append(kwargs) - if "information_schema.functions" in query: - return {"data": [("SUM",)]} + calls.append((query, kwargs)) return {"data": [("AVG",)]} monkeypatch.setattr(session._conn, "run_query", _fake_run_query) session._retrieve_aggregation_function_list() - assert len(call_kwargs) == 2 - assert all(kwargs.get("_is_internal") is True for kwargs in call_kwargs) - assert "sum" in context._aggregation_function_set + assert len(calls) == 1 + assert calls[0][1].get("_is_internal") is True + assert "show functions" in calls[0][0] + assert "information_schema.functions" not in calls[0][0] assert "avg" in context._aggregation_function_set diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 6b44195b0e..4b508c6dc9 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -808,9 +808,8 @@ def test_infer_is_return_table_uses_internal_describe(): assert mocked_run_query.call_count == 1 -def 
test_retrieve_aggregation_function_list_handles_user_defined_error(): - """When querying user-defined aggregation functions fails, the error is - swallowed and the method continues to query system functions.""" +def test_retrieve_aggregation_function_list_handles_async_error(): + """When async metadata prefetch fails, sync internal fallback is used.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -825,10 +824,13 @@ def test_retrieve_aggregation_function_list_handles_user_defined_error(): ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() + fake_async_job = MagicMock() + fake_async_job.result.side_effect = RuntimeError("async query failed") + session._agg_function_prefetch_job = fake_async_job + def run_query_side_effect(query, **kwargs): assert kwargs.get("_is_internal") is True - if "information_schema.functions" in query: - raise RuntimeError("user-defined query failed") + assert "show functions" in query return {"data": [["SUM"], ["AVG"]]} with mock.patch.object( @@ -844,8 +846,8 @@ def run_query_side_effect(query, **kwargs): ctx._aggregation_function_set = original_agg_set -def test_retrieve_aggregation_function_list_handles_system_error(): - """When system aggregation metadata retrieval fails, hardcoded fallback applies.""" +def test_retrieve_aggregation_function_list_handles_sync_error(): + """When sync metadata query fails, hardcoded fallback applies.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -862,16 +864,14 @@ def test_retrieve_aggregation_function_list_handles_system_error(): def run_query_side_effect(query, **kwargs): assert kwargs.get("_is_internal") is True - if "show functions" in query: - raise RuntimeError("system query failed") - return {"data": [["SUM"]]} + assert "show functions" in query + raise RuntimeError("sync query failed") with mock.patch.object( fake_server_connection, 
"run_query", side_effect=run_query_side_effect ): session._retrieve_aggregation_function_list() - assert "sum" in ctx._aggregation_function_set assert ctx._KNOWN_AGGREGATION_FUNCTIONS.issubset(ctx._aggregation_function_set) finally: ctx._is_snowpark_connect_compatible_mode = original_compat @@ -879,9 +879,8 @@ def run_query_side_effect(query, **kwargs): ctx._aggregation_function_set = original_agg_set -def test_retrieve_aggregation_function_list_handles_both_errors(): - """When both aggregation function queries fail, the hardcoded fallback - set is still populated.""" +def test_retrieve_aggregation_function_list_uses_single_internal_sync_query(): + """Sync fallback executes exactly one internal metadata query.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -896,9 +895,12 @@ def test_retrieve_aggregation_function_list_handles_both_errors(): ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() + called_queries = [] + def run_query_side_effect(query, **kwargs): + called_queries.append(query) assert kwargs.get("_is_internal") is True - raise RuntimeError("query failed") + return {"data": [["SUM"]]} with mock.patch.object( fake_server_connection, @@ -907,7 +909,10 @@ def run_query_side_effect(query, **kwargs): ): session._retrieve_aggregation_function_list() - assert ctx._KNOWN_AGGREGATION_FUNCTIONS.issubset(ctx._aggregation_function_set) + assert len(called_queries) == 1 + assert "show functions" in called_queries[0] + assert "information_schema.functions" not in called_queries[0] + assert "sum" in ctx._aggregation_function_set finally: ctx._is_snowpark_connect_compatible_mode = original_compat ctx._snowpark_connect_flatten_select_after_sort = original_flatten