diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index a111839050..240672a571 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -57,9 +57,6 @@ "ai_agg", "ai_summarize_agg", "any_value", - "approximate_count_distinct", - "approximate_jaccard_index", - "approximate_similarity", "approx_count_distinct", "approx_percentile", "approx_percentile_accumulate", @@ -67,32 +64,38 @@ "approx_top_k", "approx_top_k_accumulate", "approx_top_k_combine", - "arrayagg", + "approximate_count_distinct", + "approximate_jaccard_index", + "approximate_similarity", "array_agg", "array_union_agg", "array_unique_agg", + "arrayagg", "avg", - "bitandagg", + "bit_and_agg", + "bit_andagg", + "bit_or_agg", + "bit_oragg", + "bit_xor_agg", + "bit_xoragg", "bitand_agg", + "bitandagg", + "bitmap_and_agg", "bitmap_construct_agg", "bitmap_or_agg", - "bitoragg", "bitor_agg", - "bitxoragg", + "bitoragg", "bitxor_agg", - "bit_andagg", - "bit_and_agg", - "bit_oragg", - "bit_or_agg", - "bit_xoragg", - "bit_xor_agg", + "bitxoragg", "booland_agg", "boolor_agg", "boolxor_agg", "corr", "count", + "count(*)", "count_if", "count_internal", + "count_internal(*)", "covar_pop", "covar_samp", "datasketches_hll", @@ -110,12 +113,12 @@ "max_by", "median", "min", + "min_by", "minhash", "minhash_combine", - "min_by", "mode", - "objectagg", "object_agg", + "objectagg", "percentile_cont", "percentile_disc", "regr_avgx", @@ -128,20 +131,21 @@ "regr_sxy", "regr_syy", "skew", + "st_intersection_agg_geography_internal", + "st_union_agg_geography_internal", "stddev", "stddev_pop", "stddev_samp", - "st_intersection_agg_geography_internal", - "st_union_agg_geography_internal", "sum", "sum_internal", "sum_internal_real", "sum_real", + "summarize_agg", + "var_pop", + "var_samp", "variance", "variance_pop", "variance_samp", - "var_pop", - "var_samp", "vector_avg", "vector_max", "vector_min", @@ -149,6 +153,7 @@ ] ) + _cte_error_threshold = 3 # 0 to disable auto-cte-disable, otherwise the number of times CTE optimization can fail before it is automatically disabled for the remainder of the session. # Following are internal-only global flags, used to enable development features. diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 457f28f95b..f943881e7c 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -856,8 +856,10 @@ def __init__( self._dataframe_profiler = DataframeProfiler(session=self) self._catalog = None self._client_telemetry = EventTableTelemetry(session=self) + self._agg_function_prefetch_job: Optional[AsyncJob] = None self._ast_batch = AstBatch(self) + self._start_async_aggregation_prefetch_if_needed() _logger.info("Snowpark Session information: %s", self._session_info) @@ -5055,43 +5057,89 @@ def _retrieve_aggregation_function_list(self) -> None: return retrieved_set = set() + system_fetch_succeeded = False - # User-defined aggregation functions - try: - retrieved_set.update( - { - r[0].lower() - for r in self.sql( - """select function_name from information_schema.functions where is_aggregate = 'YES'""" - ).collect() - } - ) - except Exception as e: + # Try async result first if prefetch was already started. + if self._agg_function_prefetch_job is not None: + try: + retrieved_set.update( + {r[0].lower() for r in self._agg_function_prefetch_job.result()} + ) + system_fetch_succeeded = True + except Exception as e: + _logger.debug( + "Unable to use async aggregation function prefetch: %s", + e, + ) + finally: + self._agg_function_prefetch_job = None + else: _logger.debug( - "Unable to get user-defined aggregation functions: %s", - e, + "Async aggregation function prefetch job is unavailable; using sync fallback." ) - # System built-in aggregation functions + # Sync fallback query. + if not system_fetch_succeeded: + try: + retrieved_set.update( + { + r[0].lower() + for r in self._conn.run_query( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", + _is_internal=True, + )["data"] + } + ) + system_fetch_succeeded = True + except Exception as e: + _logger.debug( + "Unable to get aggregation functions via sync fallback query: %s", + e, + ) + + # Fallback to the local hardcoded list only when metadata retrieval fails. + if not system_fetch_succeeded: + retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) + + with context._aggregation_function_set_lock: + context._aggregation_function_set.update(retrieved_set) + + def _start_async_aggregation_prefetch_if_needed(self) -> None: + """Start aggregation metadata prefetch only when not already in progress.""" + if not ( + context._is_snowpark_connect_compatible_mode + and context._snowpark_connect_flatten_select_after_sort + ): + return + if context._aggregation_function_set: + return + if self._agg_function_prefetch_job is not None: + return + try: - retrieved_set.update( - { - r[0].lower() - for r in self.sql( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" - ).collect() - } + self._agg_function_prefetch_job = self._submit_internal_async_prefetch_query( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y' +union +select function_name from information_schema.functions where is_aggregate = 'YES'""" ) - except Exception as e: + except Exception as e: # pragma: no cover _logger.debug( - "Unable to get system aggregation functions, " - "falling back to hardcoded list: %s", + "Unable to start async aggregation metadata prefetch: %s", e, ) - retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) + self._agg_function_prefetch_job = None - with context._aggregation_function_set_lock: - context._aggregation_function_set.update(retrieved_set) + def _submit_internal_async_prefetch_query(self, query: str) -> Optional[AsyncJob]: + """Submit a prefetch query as internal async and return an AsyncJob handle.""" + try: + result = self._conn.execute_async_and_notify_query_listener( + query, + _is_internal=True, + ) + return self.create_async_job(result["queryId"]) + except Exception as e: # pragma: no cover + _logger.debug("Unable to submit internal async prefetch query: %s", e) + return None def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index b446347d51..976035b1c7 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -2519,3 +2519,79 @@ def test_retrieving_aggregation_funcs(session, monkeypatch): assert not context._aggregation_function_set session._retrieve_aggregation_function_list() assert not context._aggregation_function_set + + +def test_internal_async_aggregation_prefetch_submission(session, monkeypatch): + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + session._agg_function_prefetch_job = None + + calls = [] + + def _fake_execute_async(query, **kwargs): + calls.append((query, kwargs)) + return {"queryId": "qid_combined"} + + monkeypatch.setattr( + session._conn, "execute_async_and_notify_query_listener", _fake_execute_async + ) + session._start_async_aggregation_prefetch_if_needed() + + assert len(calls) == 1 + assert calls[0][1].get("_is_internal") is True + assert "show functions" in calls[0][0] + assert "information_schema.functions" in calls[0][0] + assert session._agg_function_prefetch_job.query_id == "qid_combined" + + +def test_aggregation_fallback_not_used_when_combined_async_succeeds( + session, monkeypatch +): + import snowflake.snowpark.context as context + + class _FakeAsyncJob: + def __init__(self, rows=None, error=None) -> None: + self._rows = rows + self._error = error + + def result(self): + if self._error is not None: + raise self._error + return self._rows + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + session._agg_function_prefetch_job = _FakeAsyncJob(rows=[("SUM",)]) + + session._retrieve_aggregation_function_list() + + assert "sum" in context._aggregation_function_set + assert "sum_internal" not in context._aggregation_function_set + + +def test_internal_sync_aggregation_fallback_submission(session, monkeypatch): + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + session._agg_function_prefetch_job = None + + calls = [] + + def _fake_run_query(query, **kwargs): + calls.append((query, kwargs)) + return {"data": [("AVG",)]} + + monkeypatch.setattr(session._conn, "run_query", _fake_run_query) + session._retrieve_aggregation_function_list() + + assert len(calls) == 1 + assert calls[0][1].get("_is_internal") is True + assert "show functions" in calls[0][0] + assert "information_schema.functions" not in calls[0][0] + assert "avg" in context._aggregation_function_set diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0349618659..4b508c6dc9 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -808,9 +808,8 @@ def test_infer_is_return_table_uses_internal_describe(): assert mocked_run_query.call_count == 1 -def test_retrieve_aggregation_function_list_handles_user_defined_error(): - """When querying user-defined aggregation functions fails, the error is - swallowed and the method continues to query system functions.""" +def test_retrieve_aggregation_function_list_handles_async_error(): + """When async metadata prefetch fails, sync internal fallback is used.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -818,34 +817,37 @@ def test_retrieve_aggregation_function_list_handles_user_defined_error(): session = Session(fake_server_connection) original_compat = ctx._is_snowpark_connect_compatible_mode + original_flatten = ctx._snowpark_connect_flatten_select_after_sort original_agg_set = ctx._aggregation_function_set try: ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() - mock_df = MagicMock() - call_count = [0] + fake_async_job = MagicMock() + fake_async_job.result.side_effect = RuntimeError("async query failed") + session._agg_function_prefetch_job = fake_async_job - def sql_side_effect(query, **kwargs): - call_count[0] += 1 - if call_count[0] == 1: - raise RuntimeError("user-defined query failed") - mock_df.collect.return_value = [["SUM"], ["AVG"]] - return mock_df + def run_query_side_effect(query, **kwargs): + assert kwargs.get("_is_internal") is True + assert "show functions" in query + return {"data": [["SUM"], ["AVG"]]} - with mock.patch.object(session, "sql", side_effect=sql_side_effect): + with mock.patch.object( + fake_server_connection, "run_query", side_effect=run_query_side_effect + ): session._retrieve_aggregation_function_list() assert "sum" in ctx._aggregation_function_set assert "avg" in ctx._aggregation_function_set finally: ctx._is_snowpark_connect_compatible_mode = original_compat + ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set -def test_retrieve_aggregation_function_list_handles_system_error(): - """When querying system aggregation functions fails, the method falls back - to the hardcoded _KNOWN_AGGREGATION_FUNCTIONS set.""" +def test_retrieve_aggregation_function_list_handles_sync_error(): + """When sync metadata query fails, hardcoded fallback applies.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -853,26 +855,32 @@ def test_retrieve_aggregation_function_list_handles_system_error(): session = Session(fake_server_connection) original_compat = ctx._is_snowpark_connect_compatible_mode + original_flatten = ctx._snowpark_connect_flatten_select_after_sort original_agg_set = ctx._aggregation_function_set try: ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() - mock_df = MagicMock() - mock_df.collect.side_effect = RuntimeError("system query failed") + def run_query_side_effect(query, **kwargs): + assert kwargs.get("_is_internal") is True + assert "show functions" in query + raise RuntimeError("sync query failed") - with mock.patch.object(session, "sql", return_value=mock_df): + with mock.patch.object( + fake_server_connection, "run_query", side_effect=run_query_side_effect + ): session._retrieve_aggregation_function_list() assert ctx._KNOWN_AGGREGATION_FUNCTIONS.issubset(ctx._aggregation_function_set) finally: ctx._is_snowpark_connect_compatible_mode = original_compat + ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set -def test_retrieve_aggregation_function_list_handles_both_errors(): - """When both aggregation function queries fail, the hardcoded fallback - set is still populated.""" +def test_retrieve_aggregation_function_list_uses_single_internal_sync_query(): + """Sync fallback executes exactly one internal metadata query.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -880,17 +888,32 @@ def test_retrieve_aggregation_function_list_handles_both_errors(): session = Session(fake_server_connection) original_compat = ctx._is_snowpark_connect_compatible_mode + original_flatten = ctx._snowpark_connect_flatten_select_after_sort original_agg_set = ctx._aggregation_function_set try: ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() + called_queries = [] + + def run_query_side_effect(query, **kwargs): + called_queries.append(query) + assert kwargs.get("_is_internal") is True + return {"data": [["SUM"]]} + with mock.patch.object( - session, "sql", side_effect=RuntimeError("query failed") + fake_server_connection, + "run_query", + side_effect=run_query_side_effect, ): session._retrieve_aggregation_function_list() - assert ctx._KNOWN_AGGREGATION_FUNCTIONS.issubset(ctx._aggregation_function_set) + assert len(called_queries) == 1 + assert "show functions" in called_queries[0] + assert "information_schema.functions" not in called_queries[0] + assert "sum" in ctx._aggregation_function_set finally: ctx._is_snowpark_connect_compatible_mode = original_compat + ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set