From 3467b7b4b2a6a9f274e8ca413859c0c0c3a7916e Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 21 Apr 2026 10:57:47 -0700 Subject: [PATCH 01/25] reapply change --- CHANGELOG.md | 8 + src/snowflake/snowpark/catalog.py | 229 +++++++++++++++++++-------- src/snowflake/snowpark/exceptions.py | 6 + tests/integ/test_catalog.py | 17 +- 4 files changed, 185 insertions(+), 75 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77092b7da1..e3e7f005d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Release History +## 1.51.0 (TBD) + +### Snowpark Python API Updates + +#### Improvements + +- Catalog API now uses SQL commands instead of SnowAPI calls to improve stability. + ## 1.50.0 (TBD) ### Snowpark Python API Updates diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 4572809ea2..730ae184a7 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -2,15 +2,24 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from ctypes import ArgumentError import re -from typing import List, Optional, Union +from typing import ( + List, + Optional, + Union, + TYPE_CHECKING, +) + +from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted +from snowflake.snowpark.exceptions import SnowparkSQLException, NotFoundError try: - from snowflake.core import Root # type: ignore from snowflake.core.database import Database # type: ignore - from snowflake.core.exceptions import NotFoundError + from snowflake.core.database._generated.models import Database as ModelDatabase # type: ignore from snowflake.core.procedure import Procedure from snowflake.core.schema import Schema # type: ignore + from snowflake.core.schema._generated.models import Schema as ModelSchema # type: ignore from snowflake.core.table import Table, TableColumn from snowflake.core.user_defined_function import UserDefinedFunction from snowflake.core.view import View @@ -19,27 +28,28 @@ "Missing optional dependency: 'snowflake.core'." ) from e # pragma: no cover - -import snowflake.snowpark -from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type +from snowflake.snowpark._internal.type_utils import ( + convert_sp_to_sf_type, + type_string_to_type_object, +) from snowflake.snowpark.functions import lit, parse_json from snowflake.snowpark.types import DataType +if TYPE_CHECKING: + from snowflake.snowpark.session import Session -class Catalog: """The Catalog class provides methods to interact with and manage the Snowflake objects. It allows users to list, get, and drop various database objects such as databases, schemas, tables, views, functions, etc. """ - def __init__(self, session: "snowflake.snowpark.session.Session") -> None: # type: ignore + def __init__(self, session: "Session") -> None: self._session = session - self._root = Root(session) self._python_regex_udf = None def _parse_database( self, - database: Optional[Union[str, Database]], + database: object, model_obj: Optional[ Union[str, Schema, Table, View, Procedure, UserDefinedFunction] ] = None, @@ -66,7 +76,7 @@ def _parse_database( def _parse_schema( self, - schema: Optional[Union[str, Schema]], + schema: object, model_obj: Optional[ Union[str, Table, View, Procedure, UserDefinedFunction] ] = None, @@ -166,11 +176,28 @@ def list_databases( pattern: the python regex pattern of name to match. Defaults to None. like: the sql style pattern for name to match. Default to None. """ - iter = self._root.databases.iter(like=like) + like_str = f"LIKE '{like}'" if like else "" + df = self._session.sql(f"SHOW AS RESOURCE DATABASES {like_str}") if pattern: - iter = filter(lambda x: re.match(pattern, x.name), iter) + # initialize udf + self._initialize_regex_udf() + assert self._python_regex_udf is not None # pyright - return list(iter) + # The result of SHOW AS RESOURCE query is a json string which contains + # key 'name' to store the name of the object. We parse json for the returned + # result and apply the filter on name. + df = df.filter( + self._python_regex_udf( + lit(pattern), parse_json('"As Resource"')["name"] + ) + ) + + return list( + map( + lambda row: Database._from_model(ModelDatabase.from_json(str(row[0]))), + df.collect(), + ) + ) def list_schemas( self, @@ -188,10 +215,28 @@ def list_schemas( like: the sql style pattern for name to match. Default to None. """ db_name = self._parse_database(database) - iter = self._root.databases[db_name].schemas.iter(like=like) + like_str = f"LIKE '{like}'" if like else "" + df = self._session.sql(f"SHOW AS RESOURCE SCHEMAS {like_str} IN {db_name}") if pattern: - iter = filter(lambda x: re.match(pattern, x.name), iter) - return list(iter) + # initialize udf + self._initialize_regex_udf() + assert self._python_regex_udf is not None # pyright + + # The result of SHOW AS RESOURCE query is a json string which contains + # key 'name' to store the name of the object. We parse json for the returned + # result and apply the filter on name. + df = df.filter( + self._python_regex_udf( + lit(pattern), parse_json('"As Resource"')["name"] + ) + ) + + return list( + map( + lambda row: Schema._from_model(ModelSchema.from_json(str(row[0]))), + df.collect(), + ) + ) def list_tables( self, @@ -329,14 +374,27 @@ def get_current_schema(self) -> Optional[str]: def get_database(self, database: str) -> Database: """Name of the database to get""" - return self._root.databases[database].fetch() + try: + return self.list_databases(like=unquote_if_quoted(database))[0] + except IndexError: + raise NotFoundError(f"Database with name {database} could not be found") def get_schema( self, schema: str, *, database: Optional[Union[str, Database]] = None ) -> Schema: """Name of the schema to get.""" db_name = self._parse_database(database) - return self._root.databases[db_name].schemas[schema].fetch() + try: + return self.list_schemas(database=db_name, like=unquote_if_quoted(schema))[ + 0 + ] + except ( + IndexError, # schema with this name doesn't exist + SnowparkSQLException, # database in which we are looking doesn't exist + ): + raise NotFoundError( + f"Schema with name {schema} could not be found in database '{db_name}'" + ) def get_table( self, @@ -355,12 +413,16 @@ def get_table( """ db_name = self._parse_database(database) schema_name = self._parse_schema(schema) - return ( - self._root.databases[db_name] - .schemas[schema_name] - .tables[table_name] - .fetch() - ) + try: + return self.listTables( + database=db_name, + schema=schema_name, + like=unquote_if_quoted(table_name), + )[0] + except IndexError: + raise NotFoundError( + f"Table with name {table_name} could not be found in schema '{db_name}.{schema_name}'" + ) def get_view( self, @@ -379,9 +441,16 @@ def get_view( """ db_name = self._parse_database(database) schema_name = self._parse_schema(schema) - return ( - self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() - ) + try: + return self.list_views( + database=db_name, + schema=schema_name, + like=unquote_if_quoted(view_name), + )[0] + except IndexError: + raise NotFoundError( + f"View with name {view_name} could not be found in schema '{db_name}.{schema_name}'" + ) def get_procedure( self, @@ -403,12 +472,19 @@ def get_procedure( db_name = self._parse_database(database) schema_name = self._parse_schema(schema) procedure_id = self._parse_function_or_procedure(procedure_name, arg_types) - return ( - self._root.databases[db_name] - .schemas[schema_name] - .procedures[procedure_id] - .fetch() - ) + + try: + procedures = self._session.sql( + f"DESCRIBE AS RESOURCE PROCEDURE {db_name}.{schema_name}.{procedure_id}" + ).collect() + return Procedure.from_json(str(procedures[0][0])) + except ( + IndexError, # when sql returned no results + SnowparkSQLException, # when database, or schema doesn't exist + ): + raise NotFoundError( + f"Procedure with name {procedure_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" + ) def get_user_defined_function( self, @@ -431,12 +507,19 @@ def get_user_defined_function( db_name = self._parse_database(database) schema_name = self._parse_schema(schema) function_id = self._parse_function_or_procedure(udf_name, arg_types) - return ( - self._root.databases[db_name] - .schemas[schema_name] - .user_defined_functions[function_id] - .fetch() - ) + + try: + procedures = self._session.sql( + f"DESCRIBE AS RESOURCE FUNCTION {db_name}.{schema_name}.{function_id}" + ).collect() + return UserDefinedFunction.from_json(str(procedures[0][0])) + except ( + IndexError, # when sql returned no results + SnowparkSQLException, # when database, or schema doesn't exist + ): + raise NotFoundError( + f"Function with name {udf_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" + ) # set methods def set_current_database(self, database: Union[str, Database]) -> None: @@ -466,7 +549,7 @@ def database_exists(self, database: Union[str, Database]) -> bool: """ db_name = self._parse_database(database) try: - self._root.databases[db_name].fetch() + self.get_database(db_name) return True except NotFoundError: return False @@ -487,7 +570,7 @@ def schema_exists( db_name = self._parse_database(database, schema) schema_name = self._parse_schema(schema) try: - self._root.databases[db_name].schemas[schema_name].fetch() + self.get_schema(schema=schema_name, database=db_name) return True except NotFoundError: return False @@ -511,9 +594,7 @@ def table_exists( schema_name = self._parse_schema(schema, table) table_name = table if isinstance(table, str) else table.name try: - self._root.databases[db_name].schemas[schema_name].tables[ - table_name - ].fetch() + self.get_table(table_name=table_name, database=db_name, schema=schema_name) return True except NotFoundError: return False @@ -537,7 +618,7 @@ def view_exists( schema_name = self._parse_schema(schema, view) view_name = view if isinstance(view, str) else view.name try: - self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() + self.get_view(view_name=view_name, database=db_name, schema=schema_name) return True except NotFoundError: return False @@ -559,14 +640,24 @@ def procedure_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, procedure) - schema_name = self._parse_schema(schema, procedure) - procedure_id = self._parse_function_or_procedure(procedure, arg_types) - try: - self._root.databases[db_name].schemas[schema_name].procedures[ - procedure_id - ].fetch() + if isinstance(procedure, Procedure): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided procedure is a Procedure class no other arguments can be provided" + ) + database = procedure.database_name + schema = procedure.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in procedure.arguments + ] + procedure = procedure.name + self.get_procedure( + procedure_name=procedure, + arg_types=arg_types, + database=database, + schema=schema, + ) return True except NotFoundError: return False @@ -590,14 +681,24 @@ def user_defined_function_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, udf) - schema_name = self._parse_schema(schema, udf) - function_id = self._parse_function_or_procedure(udf, arg_types) - try: - self._root.databases[db_name].schemas[schema_name].user_defined_functions[ - function_id - ].fetch() + if isinstance(udf, UserDefinedFunction): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided udf is a UserDefinedFunction class no other arguments can be provided" + ) + database = udf.database_name + schema = udf.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in udf.arguments + ] + udf = udf.name + self.get_user_defined_function( + udf_name=udf, + arg_types=arg_types, + database=database, + schema=schema, + ) return True except NotFoundError: return False @@ -610,7 +711,7 @@ def drop_database(self, database: Union[str, Database]) -> None: database: database name or ``Database`` object. """ db_name = self._parse_database(database) - self._root.databases[db_name].drop() + self._session.sql(f"DROP DATABASE {db_name}").collect() def drop_schema( self, @@ -627,7 +728,7 @@ def drop_schema( """ db_name = self._parse_database(database, schema) schema_name = self._parse_schema(schema) - self._root.databases[db_name].schemas[schema_name].drop() + self._session.sql(f"DROP SCHEMA {db_name}.{schema_name}").collect() def drop_table( self, @@ -648,7 +749,7 @@ def drop_table( schema_name = self._parse_schema(schema, table) table_name = table if isinstance(table, str) else table.name - self._root.databases[db_name].schemas[schema_name].tables[table_name].drop() + self._session.sql(f"DROP TABLE {db_name}.{schema_name}.{table_name}").collect() def drop_view( self, @@ -669,7 +770,7 @@ def drop_view( schema_name = self._parse_schema(schema, view) view_name = view if isinstance(view, str) else view.name - self._root.databases[db_name].schemas[schema_name].views[view_name].drop() + self._session.sql(f"DROP VIEW {db_name}.{schema_name}.{view_name}").collect() # aliases listDatabases = list_databases diff --git a/src/snowflake/snowpark/exceptions.py b/src/snowflake/snowpark/exceptions.py index 1142e9545e..d31fe178a6 100644 --- a/src/snowflake/snowpark/exceptions.py +++ b/src/snowflake/snowpark/exceptions.py @@ -283,3 +283,9 @@ class SnowparkInvalidObjectNameException(SnowparkGeneralException): """ pass + + +class NotFoundError(SnowparkClientException): + """Raised when we encounter an object is not found.""" + + pass diff --git a/tests/integ/test_catalog.py b/tests/integ/test_catalog.py index 11643e4005..e8bd173e21 100644 --- a/tests/integ/test_catalog.py +++ b/tests/integ/test_catalog.py @@ -10,7 +10,6 @@ from snowflake.snowpark.catalog import Catalog from snowflake.snowpark.session import Session from snowflake.snowpark.types import IntegerType -from snowflake.core.exceptions import APIError pytestmark = [ @@ -19,10 +18,6 @@ reason="deepcopy is not supported and required by local testing", run=False, ), - pytest.mark.xfail( - raises=APIError, - reason="Failure due to warehouse overload", - ), ] CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP" @@ -412,8 +407,8 @@ def test_exists_db_schema(session, temp_db1, temp_schema1): def test_exists_table_view(session, temp_db1, temp_schema1, temp_table1, temp_view1): catalog = session.catalog - db1_obj = catalog._root.databases[temp_db1].fetch() - schema1_obj = catalog._root.databases[temp_db1].schemas[temp_schema1].fetch() + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(database=temp_db1, schema=temp_schema1) assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1) assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj) @@ -437,8 +432,8 @@ def test_exists_function_procedure_udf( session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 ): catalog = session.catalog - db1_obj = catalog._root.databases[temp_db1].fetch() - schema1_obj = catalog._root.databases[temp_db1].schemas[temp_schema1].fetch() + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(temp_schema1, database=temp_db1) assert catalog.procedure_exists( temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 @@ -481,8 +476,8 @@ def test_drop(session, use_object): temp_table = create_temp_table(session, temp_db, temp_schema) temp_view = create_temp_view(session, temp_db, temp_schema) if use_object: - temp_schema = catalog._root.databases[temp_db].schemas[temp_schema].fetch() - temp_db = catalog._root.databases[temp_db].fetch() + temp_schema = catalog.get_schema(temp_schema, database=temp_db) + temp_db = catalog.get_database(temp_db) assert catalog.database_exists(temp_db) assert catalog.schema_exists(temp_schema, database=temp_db) From 143f30a08cbab3086281fef69bb7d35b239be353 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 21 Apr 2026 11:09:41 -0700 Subject: [PATCH 02/25] fix test --- src/snowflake/snowpark/catalog.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 730ae184a7..515091731b 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -38,6 +38,8 @@ if TYPE_CHECKING: from snowflake.snowpark.session import Session + +class Catalog: """The Catalog class provides methods to interact with and manage the Snowflake objects. It allows users to list, get, and drop various database objects such as databases, schemas, tables, views, functions, etc. From e45239a8a1448bc2b347aa7ab8545f4250b75699 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 22 Apr 2026 15:39:39 -0700 Subject: [PATCH 03/25] avoid regress --- src/snowflake/snowpark/catalog.py | 33 ++++++++++++++++++------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 515091731b..8333d85ae4 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -404,27 +404,32 @@ def get_table( *, database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, - ) -> Table: - """Get the table by name in given database and schema. If database or schema are not - provided, get the table in the current database and schema. + ) -> Union[Table, View]: + """Get the table or permanent view by name in the given database and schema. + + If database or schema are not provided, resolve the name in the current database + and schema. Matches :meth:`pyspark.sql.Catalog.getTable`, which returns metadata + for base tables and for views. Args: - table_name: name of the table. + table_name: name of the table or view. database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ db_name = self._parse_database(database) schema_name = self._parse_schema(schema) - try: - return self.listTables( - database=db_name, - schema=schema_name, - like=unquote_if_quoted(table_name), - )[0] - except IndexError: - raise NotFoundError( - f"Table with name {table_name} could not be found in schema '{db_name}.{schema_name}'" - ) + like_arg = unquote_if_quoted(table_name) + tables = self.listTables(database=db_name, schema=schema_name, like=like_arg) + views: List[View] = [] + if not tables: + views = self.list_views(database=db_name, schema=schema_name, like=like_arg) + if tables: + return tables[0] + if views: + return views[0] + raise NotFoundError( + f"Table with name {table_name} could not be found in schema '{db_name}.{schema_name}'" + ) def get_view( self, From 0692a09991cfab6b2902ee8dbc5a5848424b3dd1 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 27 Apr 2026 12:00:48 -0700 Subject: [PATCH 04/25] dual mode of catalog --- src/snowflake/snowpark/catalog.py | 894 ++++++++++++++++++++------ tests/integ/conftest.py | 2 + tests/integ/test_catalog.py | 349 +--------- tests/integ/test_catalog_rest_mode.py | 253 ++++++++ tests/integ/test_catalog_sql_mode.py | 243 +++++++ 5 files changed, 1203 insertions(+), 538 deletions(-) create mode 100644 tests/integ/test_catalog_rest_mode.py create mode 100644 tests/integ/test_catalog_sql_mode.py diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 8333d85ae4..0a2af60e58 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from abc import ABC, abstractmethod from ctypes import ArgumentError import re from typing import ( @@ -11,12 +12,15 @@ TYPE_CHECKING, ) +from snowflake.snowpark import context from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted from snowflake.snowpark.exceptions import SnowparkSQLException, NotFoundError try: + from snowflake.core import Root # type: ignore from snowflake.core.database import Database # type: ignore from snowflake.core.database._generated.models import Database as ModelDatabase # type: ignore + from snowflake.core.exceptions import NotFoundError as CoreNotFoundError # type: ignore from snowflake.core.procedure import Procedure from snowflake.core.schema import Schema # type: ignore from snowflake.core.schema._generated.models import Schema as ModelSchema # type: ignore @@ -39,6 +43,662 @@ from snowflake.snowpark.session import Session +class _CatalogBackend(ABC): + """Internal catalog implementation selected by ``context._is_snowpark_connect_compatible_mode``.""" + + def __init__(self, catalog: "Catalog") -> None: + self._catalog = catalog + + @abstractmethod + def list_databases( + self, + *, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Database]: + pass + + @abstractmethod + def list_schemas( + self, + *, + database: Optional[Union[str, Database]] = None, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Schema]: + pass + + @abstractmethod + def get_database(self, database: str) -> Database: + pass + + @abstractmethod + def get_schema( + self, schema: str, *, database: Optional[Union[str, Database]] = None + ) -> Schema: + pass + + @abstractmethod + def get_table( + self, + table_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Union[Table, View]: + pass + + @abstractmethod + def get_view( + self, + view_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> View: + pass + + @abstractmethod + def get_procedure( + self, + procedure_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Procedure: + pass + + @abstractmethod + def get_user_defined_function( + self, + udf_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> UserDefinedFunction: + pass + + @abstractmethod + def database_exists(self, database: Union[str, Database]) -> bool: + pass + + @abstractmethod + def schema_exists( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> bool: + pass + + @abstractmethod + def table_exists( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + pass + + @abstractmethod + def view_exists( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + pass + + @abstractmethod + def procedure_exists( + self, + procedure: Union[str, Procedure], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + pass + + @abstractmethod + def user_defined_function_exists( + self, + udf: Union[str, UserDefinedFunction], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + pass + + +class _SqlCatalogBackend(_CatalogBackend): + def list_databases( + self, + *, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Database]: + c = self._catalog + like_str = f"LIKE '{like}'" if like else "" + df = c._session.sql(f"SHOW AS RESOURCE DATABASES {like_str}") + if pattern: + c._initialize_regex_udf() + assert c._python_regex_udf is not None # pyright + df = df.filter( + c._python_regex_udf(lit(pattern), parse_json('"As Resource"')["name"]) + ) + + return list( + map( + lambda row: Database._from_model(ModelDatabase.from_json(str(row[0]))), + df.collect(), + ) + ) + + def list_schemas( + self, + *, + database: Optional[Union[str, Database]] = None, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Schema]: + c = self._catalog + db_name = c._parse_database(database) + like_str = f"LIKE '{like}'" if like else "" + df = c._session.sql(f"SHOW AS RESOURCE SCHEMAS {like_str} IN {db_name}") + if pattern: + c._initialize_regex_udf() + assert c._python_regex_udf is not None # pyright + df = df.filter( + c._python_regex_udf(lit(pattern), parse_json('"As Resource"')["name"]) + ) + + return list( + map( + lambda row: Schema._from_model(ModelSchema.from_json(str(row[0]))), + df.collect(), + ) + ) + + def get_database(self, database: str) -> Database: + try: + return self.list_databases(like=unquote_if_quoted(database))[0] + except IndexError: + raise NotFoundError(f"Database with name {database} could not be found") + + def get_schema( + self, schema: str, *, database: Optional[Union[str, Database]] = None + ) -> Schema: + c = self._catalog + db_name = c._parse_database(database) + try: + return self.list_schemas(database=db_name, like=unquote_if_quoted(schema))[ + 0 + ] + except ( + IndexError, + SnowparkSQLException, + ): + raise NotFoundError( + f"Schema with name {schema} could not be found in database '{db_name}'" + ) + + def get_table( + self, + table_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Union[Table, View]: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + like_arg = unquote_if_quoted(table_name) + tables = c.list_tables(database=db_name, schema=schema_name, like=like_arg) + views: List[View] = [] + if not tables: + views = c.list_views(database=db_name, schema=schema_name, like=like_arg) + if tables: + return tables[0] + if views: + return views[0] + raise NotFoundError( + f"Table with name {table_name} could not be found in schema '{db_name}.{schema_name}'" + ) + + def get_view( + self, + view_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> View: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + try: + return c.list_views( + database=db_name, + schema=schema_name, + like=unquote_if_quoted(view_name), + )[0] + except IndexError: + raise NotFoundError( + f"View with name {view_name} could not be found in schema '{db_name}.{schema_name}'" + ) + + def get_procedure( + self, + procedure_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Procedure: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + procedure_id = c._parse_function_or_procedure(procedure_name, arg_types) + + try: + procedures = c._session.sql( + f"DESCRIBE AS RESOURCE PROCEDURE {db_name}.{schema_name}.{procedure_id}" + ).collect() + return Procedure.from_json(str(procedures[0][0])) + except ( + IndexError, + SnowparkSQLException, + ): + raise NotFoundError( + f"Procedure with name {procedure_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" + ) + + def get_user_defined_function( + self, + udf_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> UserDefinedFunction: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + function_id = c._parse_function_or_procedure(udf_name, arg_types) + + try: + rows = c._session.sql( + f"DESCRIBE AS RESOURCE FUNCTION {db_name}.{schema_name}.{function_id}" + ).collect() + return UserDefinedFunction.from_json(str(rows[0][0])) + except ( + IndexError, + SnowparkSQLException, + ): + raise NotFoundError( + f"Function with name {udf_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" + ) + + def database_exists(self, database: Union[str, Database]) -> bool: + c = self._catalog + db_name = c._parse_database(database) + try: + self.get_database(db_name) + return True + except NotFoundError: + return False + + def schema_exists( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, schema) + schema_name = c._parse_schema(schema) + try: + self.get_schema(schema=schema_name, database=db_name) + return True + except NotFoundError: + return False + + def table_exists( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, table) + schema_name = c._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + try: + self.get_table(table_name=table_name, database=db_name, schema=schema_name) + return True + except NotFoundError: + return False + + def view_exists( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, view) + schema_name = c._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + try: + self.get_view(view_name=view_name, database=db_name, schema=schema_name) + return True + except NotFoundError: + return False + + def procedure_exists( + self, + procedure: Union[str, Procedure], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + try: + if isinstance(procedure, Procedure): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided procedure is a Procedure class no other arguments can be provided" + ) + database = procedure.database_name + schema = procedure.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in procedure.arguments + ] + procedure = procedure.name + self.get_procedure( + procedure_name=procedure, + arg_types=arg_types, + database=database, + schema=schema, + ) + return True + except NotFoundError: + return False + + def user_defined_function_exists( + self, + udf: Union[str, UserDefinedFunction], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + try: + if isinstance(udf, UserDefinedFunction): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided udf is a UserDefinedFunction class no other arguments can be provided" + ) + database = udf.database_name + schema = udf.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in udf.arguments + ] + udf = udf.name + self.get_user_defined_function( + udf_name=udf, + arg_types=arg_types, + database=database, + schema=schema, + ) + return True + except NotFoundError: + return False + + +class _RestCatalogBackend(_CatalogBackend): + def __init__(self, catalog: "Catalog") -> None: + super().__init__(catalog) + self._root_obj: Optional[Root] = None + + @property + def _root(self) -> Root: + if self._root_obj is None: + self._root_obj = Root(self._catalog._session) + return self._root_obj + + def list_databases( + self, + *, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Database]: + it = self._root.databases.iter(like=like) + if pattern: + it = filter(lambda x: re.match(pattern, x.name), it) + return list(it) + + def list_schemas( + self, + *, + database: Optional[Union[str, Database]] = None, + pattern: Optional[str] = None, + like: Optional[str] = None, + ) -> List[Schema]: + db_name = self._catalog._parse_database(database) + it = self._root.databases[db_name].schemas.iter(like=like) + if pattern: + it = filter(lambda x: re.match(pattern, x.name), it) + return list(it) + + def get_database(self, database: str) -> Database: + return self._root.databases[database].fetch() + + def get_schema( + self, schema: str, *, database: Optional[Union[str, Database]] = None + ) -> Schema: + db_name = self._catalog._parse_database(database) + return self._root.databases[db_name].schemas[schema].fetch() + + def get_table( + self, + table_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Union[Table, View]: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + return ( + self._root.databases[db_name] + .schemas[schema_name] + .tables[table_name] + .fetch() + ) + + def get_view( + self, + view_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> View: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + return ( + self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() + ) + + def get_procedure( + self, + procedure_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Procedure: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + procedure_id = c._parse_function_or_procedure(procedure_name, arg_types) + return ( + self._root.databases[db_name] + .schemas[schema_name] + .procedures[procedure_id] + .fetch() + ) + + def get_user_defined_function( + self, + udf_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> UserDefinedFunction: + c = self._catalog + db_name = c._parse_database(database) + schema_name = c._parse_schema(schema) + function_id = c._parse_function_or_procedure(udf_name, arg_types) + return ( + self._root.databases[db_name] + .schemas[schema_name] + .user_defined_functions[function_id] + .fetch() + ) + + def database_exists(self, database: Union[str, Database]) -> bool: + c = self._catalog + db_name = c._parse_database(database) + try: + self._root.databases[db_name].fetch() + return True + except CoreNotFoundError: + return False + + def schema_exists( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, schema) + schema_name = c._parse_schema(schema) + try: + self._root.databases[db_name].schemas[schema_name].fetch() + return True + except CoreNotFoundError: + return False + + def table_exists( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, table) + schema_name = c._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + try: + self._root.databases[db_name].schemas[schema_name].tables[ + table_name + ].fetch() + return True + except CoreNotFoundError: + return False + + def view_exists( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + db_name = c._parse_database(database, view) + schema_name = c._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + try: + self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() + return True + except CoreNotFoundError: + return False + + def procedure_exists( + self, + procedure: Union[str, Procedure], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + try: + if isinstance(procedure, Procedure): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided procedure is a Procedure class no other arguments can be provided" + ) + database = procedure.database_name + schema = procedure.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in procedure.arguments + ] + procedure = procedure.name + db_name = c._parse_database(database, procedure) + schema_name = c._parse_schema(schema, procedure) + procedure_id = c._parse_function_or_procedure(procedure, arg_types) + self._root.databases[db_name].schemas[schema_name].procedures[ + procedure_id + ].fetch() + return True + except CoreNotFoundError: + return False + + def user_defined_function_exists( + self, + udf: Union[str, UserDefinedFunction], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + c = self._catalog + try: + if isinstance(udf, UserDefinedFunction): + if arg_types is not None or database is not None or schema is not None: + raise ArgumentError( + "When provided udf is a UserDefinedFunction class no other arguments can be provided" + ) + database = udf.database_name + schema = udf.schema_name + arg_types = [ + type_string_to_type_object(a.datatype) for a in udf.arguments + ] + udf = udf.name + db_name = c._parse_database(database, udf) + schema_name = c._parse_schema(schema, udf) + function_id = c._parse_function_or_procedure(udf, arg_types) + self._root.databases[db_name].schemas[schema_name].user_defined_functions[ + function_id + ].fetch() + return True + except CoreNotFoundError: + return False + + class Catalog: """The Catalog class provides methods to interact with and manage the Snowflake objects. It allows users to list, get, and drop various database objects such as databases, schemas, tables, @@ -48,6 +708,15 @@ class Catalog: def __init__(self, session: "Session") -> None: self._session = session self._python_regex_udf = None + self._sql_backend = _SqlCatalogBackend(self) + self._rest_backend: Optional[_RestCatalogBackend] = None + + def _backend(self) -> _CatalogBackend: + if context._is_snowpark_connect_compatible_mode: + return self._sql_backend + if self._rest_backend is None: + self._rest_backend = _RestCatalogBackend(self) + return self._rest_backend def _parse_database( self, @@ -150,13 +819,9 @@ def _list_objects( f"SHOW AS RESOURCE {object_name} {like_str} IN {db_name}.{schema_name} -- catalog api" ) if pattern: - # initialize udf self._initialize_regex_udf() assert self._python_regex_udf is not None # pyright - # The result of SHOW AS RESOURCE query is a json string which contains - # key 'name' to store the name of the object. We parse json for the returned - # result and apply the filter on name. df = df.filter( self._python_regex_udf( lit(pattern), parse_json('"As Resource"')["name"] @@ -165,7 +830,6 @@ def _list_objects( return list(map(lambda row: object_class.from_json(row[0]), df.collect())) - # List methods def list_databases( self, *, @@ -178,28 +842,7 @@ def list_databases( pattern: the python regex pattern of name to match. Defaults to None. like: the sql style pattern for name to match. Default to None. """ - like_str = f"LIKE '{like}'" if like else "" - df = self._session.sql(f"SHOW AS RESOURCE DATABASES {like_str}") - if pattern: - # initialize udf - self._initialize_regex_udf() - assert self._python_regex_udf is not None # pyright - - # The result of SHOW AS RESOURCE query is a json string which contains - # key 'name' to store the name of the object. We parse json for the returned - # result and apply the filter on name. - df = df.filter( - self._python_regex_udf( - lit(pattern), parse_json('"As Resource"')["name"] - ) - ) - - return list( - map( - lambda row: Database._from_model(ModelDatabase.from_json(str(row[0]))), - df.collect(), - ) - ) + return self._backend().list_databases(pattern=pattern, like=like) def list_schemas( self, @@ -216,28 +859,8 @@ def list_schemas( pattern: the python regex pattern of name to match. Defaults to None. like: the sql style pattern for name to match. Default to None. """ - db_name = self._parse_database(database) - like_str = f"LIKE '{like}'" if like else "" - df = self._session.sql(f"SHOW AS RESOURCE SCHEMAS {like_str} IN {db_name}") - if pattern: - # initialize udf - self._initialize_regex_udf() - assert self._python_regex_udf is not None # pyright - - # The result of SHOW AS RESOURCE query is a json string which contains - # key 'name' to store the name of the object. We parse json for the returned - # result and apply the filter on name. - df = df.filter( - self._python_regex_udf( - lit(pattern), parse_json('"As Resource"')["name"] - ) - ) - - return list( - map( - lambda row: Schema._from_model(ModelSchema.from_json(str(row[0]))), - df.collect(), - ) + return self._backend().list_schemas( + database=database, pattern=pattern, like=like ) def list_tables( @@ -365,7 +988,6 @@ def list_user_defined_functions( like=like, ) - # get methods def get_current_database(self) -> Optional[str]: """Get the current database.""" return self._session.get_current_database() @@ -376,27 +998,13 @@ def get_current_schema(self) -> Optional[str]: def get_database(self, database: str) -> Database: """Name of the database to get""" - try: - return self.list_databases(like=unquote_if_quoted(database))[0] - except IndexError: - raise NotFoundError(f"Database with name {database} could not be found") + return self._backend().get_database(database) def get_schema( self, schema: str, *, database: Optional[Union[str, Database]] = None ) -> Schema: """Name of the schema to get.""" - db_name = self._parse_database(database) - try: - return self.list_schemas(database=db_name, like=unquote_if_quoted(schema))[ - 0 - ] - except ( - IndexError, # schema with this name doesn't exist - SnowparkSQLException, # database in which we are looking doesn't exist - ): - raise NotFoundError( - f"Schema with name {schema} could not be found in database '{db_name}'" - ) + return self._backend().get_schema(schema, database=database) def get_table( self, @@ -411,25 +1019,15 @@ def get_table( and schema. Matches :meth:`pyspark.sql.Catalog.getTable`, which returns metadata for base tables and for views. + When ``context._is_snowpark_connect_compatible_mode`` is False (legacy REST path), + only base tables are returned; use :meth:`get_view` for views. + Args: table_name: name of the table or view. database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - like_arg = unquote_if_quoted(table_name) - tables = self.listTables(database=db_name, schema=schema_name, like=like_arg) - views: List[View] = [] - if not tables: - views = self.list_views(database=db_name, schema=schema_name, like=like_arg) - if tables: - return tables[0] - if views: - return views[0] - raise NotFoundError( - f"Table with name {table_name} could not be found in schema '{db_name}.{schema_name}'" - ) + return self._backend().get_table(table_name, database=database, schema=schema) def get_view( self, @@ -446,18 +1044,7 @@ def get_view( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - try: - return self.list_views( - database=db_name, - schema=schema_name, - like=unquote_if_quoted(view_name), - )[0] - except IndexError: - raise NotFoundError( - f"View with name {view_name} could not be found in schema '{db_name}.{schema_name}'" - ) + return self._backend().get_view(view_name, database=database, schema=schema) def get_procedure( self, @@ -476,22 +1063,9 @@ def get_procedure( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - procedure_id = self._parse_function_or_procedure(procedure_name, arg_types) - - try: - procedures = self._session.sql( - f"DESCRIBE AS RESOURCE PROCEDURE {db_name}.{schema_name}.{procedure_id}" - ).collect() - return Procedure.from_json(str(procedures[0][0])) - except ( - IndexError, # when sql returned no results - SnowparkSQLException, # when database, or schema doesn't exist - ): - raise NotFoundError( - f"Procedure with name {procedure_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" - ) + return self._backend().get_procedure( + procedure_name, arg_types, database=database, schema=schema + ) def get_user_defined_function( self, @@ -511,24 +1085,10 @@ def get_user_defined_function( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - function_id = self._parse_function_or_procedure(udf_name, arg_types) - - try: - procedures = self._session.sql( - f"DESCRIBE AS RESOURCE FUNCTION {db_name}.{schema_name}.{function_id}" - ).collect() - return UserDefinedFunction.from_json(str(procedures[0][0])) - except ( - IndexError, # when sql returned no results - SnowparkSQLException, # when database, or schema doesn't exist - ): - raise NotFoundError( - f"Function with name {udf_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" - ) + return self._backend().get_user_defined_function( + udf_name, arg_types, database=database, schema=schema + ) - # set methods def set_current_database(self, database: Union[str, Database]) -> None: """Set the current default database for the session. @@ -547,19 +1107,13 @@ def set_current_schema(self, schema: Union[str, Schema]) -> None: schema_name = self._parse_schema(schema) self._session.use_schema(schema_name) - # exists methods def database_exists(self, database: Union[str, Database]) -> bool: """Check if the given database exists. Args: database: database name or ``Database`` object. """ - db_name = self._parse_database(database) - try: - self.get_database(db_name) - return True - except NotFoundError: - return False + return self._backend().database_exists(database) def schema_exists( self, @@ -574,13 +1128,7 @@ def schema_exists( schema: schema name or ``Schema`` object. database: database name or ``Database`` object. Defaults to None. """ - db_name = self._parse_database(database, schema) - schema_name = self._parse_schema(schema) - try: - self.get_schema(schema=schema_name, database=db_name) - return True - except NotFoundError: - return False + return self._backend().schema_exists(schema, database=database) def table_exists( self, @@ -597,14 +1145,7 @@ def table_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, table) - schema_name = self._parse_schema(schema, table) - table_name = table if isinstance(table, str) else table.name - try: - self.get_table(table_name=table_name, database=db_name, schema=schema_name) - return True - except NotFoundError: - return False + return self._backend().table_exists(table, database=database, schema=schema) def view_exists( self, @@ -621,14 +1162,7 @@ def view_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, view) - schema_name = self._parse_schema(schema, view) - view_name = view if isinstance(view, str) else view.name - try: - self.get_view(view_name=view_name, database=db_name, schema=schema_name) - return True - except NotFoundError: - return False + return self._backend().view_exists(view, database=database, schema=schema) def procedure_exists( self, @@ -647,27 +1181,9 @@ def procedure_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - try: - if isinstance(procedure, Procedure): - if arg_types is not None or database is not None or schema is not None: - raise ArgumentError( - "When provided procedure is a Procedure class no other arguments can be provided" - ) - database = procedure.database_name - schema = procedure.schema_name - arg_types = [ - type_string_to_type_object(a.datatype) for a in procedure.arguments - ] - procedure = procedure.name - self.get_procedure( - procedure_name=procedure, - arg_types=arg_types, - database=database, - schema=schema, - ) - return True - except NotFoundError: - return False + return self._backend().procedure_exists( + procedure, arg_types, database=database, schema=schema + ) def user_defined_function_exists( self, @@ -688,29 +1204,10 @@ def user_defined_function_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - try: - if isinstance(udf, UserDefinedFunction): - if arg_types is not None or database is not None or schema is not None: - raise ArgumentError( - "When provided udf is a UserDefinedFunction class no other arguments can be provided" - ) - database = udf.database_name - schema = udf.schema_name - arg_types = [ - type_string_to_type_object(a.datatype) for a in udf.arguments - ] - udf = udf.name - self.get_user_defined_function( - udf_name=udf, - arg_types=arg_types, - database=database, - schema=schema, - ) - return True - except NotFoundError: - return False + return self._backend().user_defined_function_exists( + udf, arg_types, database=database, schema=schema + ) - # drop methods def drop_database(self, database: Union[str, Database]) -> None: """Drop the given database. @@ -779,7 +1276,6 @@ def drop_view( self._session.sql(f"DROP VIEW {db_name}.{schema_name}.{view_name}").collect() - # aliases listDatabases = list_databases listSchemas = list_schemas listTables = list_tables diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index fc1835e923..cbed543fbe 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -30,6 +30,8 @@ RUNNING_ON_GH = os.getenv("GITHUB_ACTIONS") == "true" RUNNING_ON_JENKINS = "JENKINS_HOME" in os.environ +pytest_plugins = ("tests.integ.catalog_integ_common",) + test_dir = os.path.dirname(__file__) test_data_dir = os.path.join(test_dir, "cassettes") diff --git a/tests/integ/test_catalog.py b/tests/integ/test_catalog.py index e8bd173e21..fa8940bd38 100644 --- a/tests/integ/test_catalog.py +++ b/tests/integ/test_catalog.py @@ -1,16 +1,21 @@ # # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +"""Mode-agnostic catalog integration tests. + +Only tests whose call paths are identical between the SQL-based and REST-based +catalog backends live here. Backend-specific behavior is covered in +``test_catalog_sql_mode.py`` and ``test_catalog_rest_mode.py``. +""" from unittest.mock import patch -import uuid import pytest -from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted from snowflake.snowpark.catalog import Catalog -from snowflake.snowpark.session import Session -from snowflake.snowpark.types import IntegerType - +from tests.integ.catalog_integ_common import ( + CATALOG_TEMP_OBJECT_PREFIX, + DOES_NOT_EXIST_PATTERN, +) pytestmark = [ pytest.mark.xfail( @@ -20,192 +25,6 @@ ), ] -CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP" - - -def get_temp_name(type: str) -> str: - return f"{CATALOG_TEMP_OBJECT_PREFIX}_{type}_{uuid.uuid4().hex[:6]}".upper() - - -def create_temp_db(session) -> str: - original_db = session.get_current_database() - temp_db = get_temp_name("DB") - session._run_query(f"create or replace database {temp_db}") - session.use_database(original_db) - return temp_db - - -@pytest.fixture(scope="module") -def temp_db1(session): - temp_db = create_temp_db(session) - yield temp_db - session._run_query(f"drop database if exists {temp_db}") - - -@pytest.fixture(scope="module") -def temp_db2(session): - temp_db = create_temp_db(session) - yield temp_db - session._run_query(f"drop database if exists {temp_db}") - - -def create_temp_schema(session, db: str) -> str: - original_db = session.get_current_database() - original_schema = session.get_current_schema() - temp_schema = get_temp_name("SCHEMA") - session._run_query(f"create or replace schema {db}.{temp_schema}") - - session.use_database(original_db) - session.use_schema(original_schema) - return temp_schema - - -@pytest.fixture(scope="module") -def temp_schema1(session, temp_db1): - temp_schema = create_temp_schema(session, temp_db1) - yield temp_schema - session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") - - -@pytest.fixture(scope="module") -def temp_schema2(session, temp_db1): - temp_schema = create_temp_schema(session, temp_db1) - yield temp_schema - session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") - - -def create_temp_table(session, db: str, schema: str) -> str: - temp_table = get_temp_name("TABLE") - session._run_query( - f"create or replace temp table {db}.{schema}.{temp_table} (a int, b string)" - ) - return temp_table - - -@pytest.fixture(scope="module") -def temp_table1(session, temp_db1, temp_schema1): - temp_table = create_temp_table(session, temp_db1, temp_schema1) - yield temp_table - session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") - - -@pytest.fixture(scope="module") -def temp_table2(session, temp_db1, temp_schema1): - temp_table = create_temp_table(session, temp_db1, temp_schema1) - yield temp_table - session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") - - -def create_temp_view(session, db: str, schema: str) -> str: - temp_schema = get_temp_name("VIEW") - session._run_query( - f"create or replace temp view {db}.{schema}.{temp_schema} as select 1 as a, '2' as b" - ) - return temp_schema - - -@pytest.fixture(scope="module") -def temp_view1(session, temp_db1, temp_schema1): - temp_view = create_temp_view(session, temp_db1, temp_schema1) - yield temp_view - session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") - - -@pytest.fixture(scope="module") -def temp_view2(session, temp_db1, temp_schema1): - temp_view = create_temp_view(session, temp_db1, temp_schema1) - yield temp_view - session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") - - -def create_temp_procedure(session: Session, db, schema) -> str: - temp_procedure = get_temp_name("PROCEDURE") - session.sproc.register( - lambda _, x: x + 1, - return_type=IntegerType(), - input_types=[IntegerType()], - name=f"{db}.{schema}.{temp_procedure}", - packages=["snowflake-snowpark-python"], - ) - return temp_procedure - - -@pytest.fixture(scope="module") -def temp_procedure1(session, temp_db1, temp_schema1): - temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) - yield temp_procedure - session._run_query( - f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" - ) - - -@pytest.fixture(scope="module") -def temp_procedure2(session, temp_db1, temp_schema1): - temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) - yield temp_procedure - session._run_query( - f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" - ) - - -def create_temp_udf(session: Session, db, schema) -> str: - temp_udf = get_temp_name("UDF") - session.udf.register( - lambda x: x + 1, - return_type=IntegerType(), - input_types=[IntegerType()], - name=f"{db}.{schema}.{temp_udf}", - ) - return temp_udf - - -@pytest.fixture(scope="module") -def temp_udf1(session, temp_db1, temp_schema1): - temp_udf = create_temp_udf(session, temp_db1, temp_schema1) - yield temp_udf - session._run_query( - f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" - ) - - -@pytest.fixture(scope="module") -def temp_udf2(session, temp_db1, temp_schema1): - temp_udf = create_temp_udf(session, temp_db1, temp_schema1) - yield temp_udf - session._run_query( - f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" - ) - - -DOES_NOT_EXIST_PATTERN = "does_not_exist_.*" - - -def test_list_db(session, temp_db1, temp_db2): - catalog: Catalog = session.catalog - db_list = catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_*") - assert {db.name for db in db_list} >= {temp_db1, temp_db2} - - db_list = catalog.list_databases(like=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_%") - assert {db.name for db in db_list} >= {temp_db1, temp_db2} - - -def test_list_schema(session, temp_db1, temp_schema1, temp_schema2): - catalog: Catalog = session.catalog - assert ( - len(catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*")) - == 0 - ) - - schema_list = catalog.list_schemas( - pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*", database=temp_db1 - ) - assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} - - schema_list = catalog.list_schemas( - like=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_%", database=temp_db1 - ) - assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} - def test_list_tables(session, temp_db1, temp_schema1, temp_table1, temp_table2): catalog: Catalog = session.catalog @@ -333,48 +152,6 @@ def test_list_udfs(session, temp_db1, temp_schema1, temp_udf1, temp_udf2): assert {udf.name for udf in udf_list} >= {temp_udf1, temp_udf2} -def test_get_db_schema(session): - catalog: Catalog = session.catalog - current_db = session.get_current_database() - current_schema = session.get_current_schema() - assert catalog.get_database(current_db).name == unquote_if_quoted(current_db) - assert catalog.get_schema(current_schema).name == unquote_if_quoted(current_schema) - - -def test_get_table_view(session, temp_db1, temp_schema1, temp_table1, temp_view1): - catalog: Catalog = session.catalog - table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) - assert table.name == temp_table1 - assert table.database_name == temp_db1 - assert table.schema_name == temp_schema1 - - view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) - assert view.name == temp_view1 - assert view.database_name == temp_db1 - assert view.schema_name == temp_schema1 - - -@pytest.mark.udf -def test_get_function_procedure_udf( - session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 -): - catalog: Catalog = session.catalog - - procedure = catalog.get_procedure( - temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert procedure.name == temp_procedure1 - assert procedure.database_name == temp_db1 - assert procedure.schema_name == temp_schema1 - - udf = catalog.get_user_defined_function( - temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert udf.name == temp_udf1 - assert udf.database_name == temp_db1 - assert udf.schema_name == temp_schema1 - - def test_set_db_schema(session, temp_db1, temp_db2, temp_schema1, temp_schema2): catalog = session.catalog @@ -396,112 +173,6 @@ def test_set_db_schema(session, temp_db1, temp_db2, temp_schema1, temp_schema2): session.use_schema(original_schema) -def test_exists_db_schema(session, temp_db1, temp_schema1): - catalog = session.catalog - assert catalog.database_exists(temp_db1) - assert not catalog.database_exists("does_not_exist") - - assert catalog.schema_exists(temp_schema1, database=temp_db1) - assert not catalog.schema_exists(temp_schema1, database="does_not_exist") - - -def test_exists_table_view(session, temp_db1, temp_schema1, temp_table1, temp_view1): - catalog = session.catalog - db1_obj = catalog.get_database(temp_db1) - schema1_obj = catalog.get_schema(database=temp_db1, schema=temp_schema1) - - assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1) - assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj) - table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) - assert catalog.table_exists(table) - assert not catalog.table_exists( - "does_not_exist", database=temp_db1, schema=temp_schema1 - ) - - assert catalog.view_exists(temp_view1, database=temp_db1, schema=temp_schema1) - assert catalog.view_exists(temp_view1, database=db1_obj, schema=schema1_obj) - view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) - assert catalog.view_exists(view) - assert not catalog.view_exists( - "does_not_exist", database=temp_db1, schema=temp_schema1 - ) - - -@pytest.mark.udf -def test_exists_function_procedure_udf( - session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 -): - catalog = session.catalog - db1_obj = catalog.get_database(temp_db1) - schema1_obj = catalog.get_schema(temp_schema1, database=temp_db1) - - assert catalog.procedure_exists( - temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert catalog.procedure_exists( - temp_procedure1, [IntegerType()], database=db1_obj, schema=schema1_obj - ) - proc = catalog.get_procedure( - temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert catalog.procedure_exists(proc) - assert not catalog.procedure_exists( - "does_not_exist", [], database=temp_db1, schema=temp_schema1 - ) - - assert catalog.user_defined_function_exists( - temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert catalog.user_defined_function_exists( - temp_udf1, [IntegerType()], database=db1_obj, schema=schema1_obj - ) - udf = catalog.get_user_defined_function( - temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 - ) - assert catalog.user_defined_function_exists(udf) - assert not catalog.user_defined_function_exists( - "does_not_exist", [], database=temp_db1, schema=temp_schema1 - ) - - -@pytest.mark.parametrize("use_object", [True, False]) -def test_drop(session, use_object): - catalog = session.catalog - - original_db = session.get_current_database() - original_schema = session.get_current_schema() - try: - temp_db = create_temp_db(session) - temp_schema = create_temp_schema(session, temp_db) - temp_table = create_temp_table(session, temp_db, temp_schema) - temp_view = create_temp_view(session, temp_db, temp_schema) - if use_object: - temp_schema = catalog.get_schema(temp_schema, database=temp_db) - temp_db = catalog.get_database(temp_db) - - assert catalog.database_exists(temp_db) - assert catalog.schema_exists(temp_schema, database=temp_db) - assert catalog.table_exists(temp_table, database=temp_db, schema=temp_schema) - assert catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) - - catalog.drop_table(temp_table, database=temp_db, schema=temp_schema) - catalog.drop_view(temp_view, database=temp_db, schema=temp_schema) - - assert not catalog.table_exists( - temp_table, database=temp_db, schema=temp_schema - ) - assert not catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) - - catalog.drop_schema(temp_schema, database=temp_db) - assert not catalog.schema_exists(temp_schema, database=temp_db) - - catalog.drop_database(temp_db) - assert not catalog.database_exists(temp_db) - finally: - session.use_database(original_db) - session.use_schema(original_schema) - - def test_parse_names_negative(session): catalog = session.catalog with pytest.raises( diff --git a/tests/integ/test_catalog_rest_mode.py b/tests/integ/test_catalog_rest_mode.py new file mode 100644 index 0000000000..89de818eab --- /dev/null +++ b/tests/integ/test_catalog_rest_mode.py @@ -0,0 +1,253 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# +"""Catalog integration tests with ``context._is_snowpark_connect_compatible_mode`` False (REST / Root backend). + +Keep this file separate from ``test_catalog_sql_mode.py`` so removing one backend path +deletes only the matching test module. +""" + +import pytest + +from snowflake.core.exceptions import NotFoundError as CoreNotFoundError +from snowflake.snowpark import context +from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted +from snowflake.snowpark.catalog import Catalog +from snowflake.snowpark.types import IntegerType +from tests.integ.catalog_integ_common import ( + CATALOG_TEMP_OBJECT_PREFIX, + create_temp_db, + create_temp_schema, + create_temp_table, + create_temp_view, +) + +pytestmark = [ + pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="deepcopy is not supported and required by local testing", + run=False, + ), +] + + +@pytest.fixture(autouse=True) +def _catalog_rest_backend_mode(monkeypatch): + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", False) + + +def test_list_db_rest_mode(session, temp_db1, temp_db2): + catalog: Catalog = session.catalog + db_list = catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_*") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + db_list = catalog.list_databases(like=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_%") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + +def test_list_schema_rest_mode(session, temp_db1, temp_schema1, temp_schema2): + catalog: Catalog = session.catalog + assert ( + len(catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*")) + == 0 + ) + + schema_list = catalog.list_schemas( + pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + schema_list = catalog.list_schemas( + like=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_%", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + +def test_get_db_schema_rest_mode(session): + catalog: Catalog = session.catalog + current_db = session.get_current_database() + current_schema = session.get_current_schema() + assert catalog.get_database(current_db).name == unquote_if_quoted(current_db) + assert catalog.get_schema(current_schema).name == unquote_if_quoted(current_schema) + + +def test_get_database_missing_raises_core_not_found_rest_mode(session): + catalog: Catalog = session.catalog + with pytest.raises(CoreNotFoundError): + catalog.get_database("NONEXISTENT_DB_XYZ_12345") + + +def test_get_table_does_not_resolve_view_rest_mode( + session, temp_db1, temp_schema1, temp_view1 +): + catalog: Catalog = session.catalog + with pytest.raises(CoreNotFoundError): + catalog.get_table(temp_view1, database=temp_db1, schema=temp_schema1) + + +def test_get_view_rest_mode(session, temp_db1, temp_schema1, temp_view1): + catalog: Catalog = session.catalog + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert view.name == temp_view1 + assert view.database_name == temp_db1 + assert view.schema_name == temp_schema1 + + +def test_table_exists_false_for_view_name_rest_mode( + session, temp_db1, temp_schema1, temp_view1 +): + catalog: Catalog = session.catalog + assert not catalog.table_exists(temp_view1, database=temp_db1, schema=temp_schema1) + + +@pytest.mark.udf +def test_get_procedure_rest_mode(session, temp_db1, temp_schema1, temp_procedure1): + catalog: Catalog = session.catalog + procedure = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert procedure.name == temp_procedure1 + assert procedure.database_name == temp_db1 + assert procedure.schema_name == temp_schema1 + + +@pytest.mark.udf +def test_get_user_defined_function_rest_mode( + session, temp_db1, temp_schema1, temp_udf1 +): + catalog: Catalog = session.catalog + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert udf.name == temp_udf1 + assert udf.database_name == temp_db1 + assert udf.schema_name == temp_schema1 + + +def test_database_exists_rest_mode(session, temp_db1): + catalog: Catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + +def test_get_table_view_rest_mode( + session, temp_db1, temp_schema1, temp_table1, temp_view1 +): + catalog: Catalog = session.catalog + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert table.name == temp_table1 + assert table.database_name == temp_db1 + assert table.schema_name == temp_schema1 + + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert view.name == temp_view1 + assert view.database_name == temp_db1 + assert view.schema_name == temp_schema1 + + +def test_exists_db_schema_rest_mode(session, temp_db1, temp_schema1): + catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + assert catalog.schema_exists(temp_schema1, database=temp_db1) + assert not catalog.schema_exists(temp_schema1, database="does_not_exist") + + +def test_exists_table_view_rest_mode( + session, temp_db1, temp_schema1, temp_table1, temp_view1 +): + catalog = session.catalog + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(database=temp_db1, schema=temp_schema1) + + assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj) + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(table) + assert not catalog.table_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + assert catalog.view_exists(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(temp_view1, database=db1_obj, schema=schema1_obj) + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(view) + assert not catalog.view_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.udf +def test_exists_function_procedure_udf_rest_mode( + session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 +): + catalog = session.catalog + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(temp_schema1, database=temp_db1) + + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + proc = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists(proc) + assert not catalog.procedure_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists(udf) + assert not catalog.user_defined_function_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.parametrize("use_object", [True, False]) +def test_drop_rest_mode(session, use_object): + catalog = session.catalog + + original_db = session.get_current_database() + original_schema = session.get_current_schema() + try: + temp_db = create_temp_db(session) + temp_schema = create_temp_schema(session, temp_db) + temp_table = create_temp_table(session, temp_db, temp_schema) + temp_view = create_temp_view(session, temp_db, temp_schema) + if use_object: + temp_schema = catalog.get_schema(temp_schema, database=temp_db) + temp_db = catalog.get_database(temp_db) + + assert catalog.database_exists(temp_db) + assert catalog.schema_exists(temp_schema, database=temp_db) + assert catalog.table_exists(temp_table, database=temp_db, schema=temp_schema) + assert catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_table(temp_table, database=temp_db, schema=temp_schema) + catalog.drop_view(temp_view, database=temp_db, schema=temp_schema) + + assert not catalog.table_exists( + temp_table, database=temp_db, schema=temp_schema + ) + assert not catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_schema(temp_schema, database=temp_db) + assert not catalog.schema_exists(temp_schema, database=temp_db) + + catalog.drop_database(temp_db) + assert not catalog.database_exists(temp_db) + finally: + session.use_database(original_db) + session.use_schema(original_schema) diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py new file mode 100644 index 0000000000..b7d166b211 --- /dev/null +++ b/tests/integ/test_catalog_sql_mode.py @@ -0,0 +1,243 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# +"""Catalog integration tests with ``context._is_snowpark_connect_compatible_mode`` True (SQL backend). + +Keep this file separate from ``test_catalog_rest_mode.py`` so removing one backend path +deletes only the matching test module. +""" + +import pytest + +from snowflake.core.view import View +from snowflake.snowpark import context +from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted +from snowflake.snowpark.catalog import Catalog +from snowflake.snowpark.exceptions import NotFoundError +from snowflake.snowpark.types import IntegerType +from tests.integ.catalog_integ_common import ( + CATALOG_TEMP_OBJECT_PREFIX, + create_temp_db, + create_temp_schema, + create_temp_table, + create_temp_view, +) + +pytestmark = [ + pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="deepcopy is not supported and required by local testing", + run=False, + ), +] + + +@pytest.fixture(autouse=True) +def _catalog_sql_backend_mode(monkeypatch): + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + + +def test_list_db_sql_mode(session, temp_db1, temp_db2): + catalog: Catalog = session.catalog + db_list = catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_*") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + db_list = catalog.list_databases(like=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_%") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + +def test_list_schema_sql_mode(session, temp_db1, temp_schema1, temp_schema2): + catalog: Catalog = session.catalog + assert ( + len(catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*")) + == 0 + ) + + schema_list = catalog.list_schemas( + pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_.*", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + schema_list = catalog.list_schemas( + like=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_%", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + +def test_get_db_schema_sql_mode(session): + catalog: Catalog = session.catalog + current_db = session.get_current_database() + current_schema = session.get_current_schema() + assert catalog.get_database(current_db).name == unquote_if_quoted(current_db) + assert catalog.get_schema(current_schema).name == unquote_if_quoted(current_schema) + + +def test_get_database_missing_raises_snowpark_not_found_sql_mode(session): + catalog: Catalog = session.catalog + with pytest.raises(NotFoundError, match="could not be found"): + catalog.get_database("NONEXISTENT_DB_XYZ_12345") + + +def test_get_table_resolves_view_sql_mode(session, temp_db1, temp_schema1, temp_view1): + catalog: Catalog = session.catalog + obj = catalog.get_table(temp_view1, database=temp_db1, schema=temp_schema1) + assert isinstance(obj, View) + assert obj.name == temp_view1 + + +def test_table_exists_true_for_view_name_sql_mode( + session, temp_db1, temp_schema1, temp_view1 +): + catalog: Catalog = session.catalog + assert catalog.table_exists(temp_view1, database=temp_db1, schema=temp_schema1) + + +@pytest.mark.udf +def test_get_procedure_sql_mode(session, temp_db1, temp_schema1, temp_procedure1): + catalog: Catalog = session.catalog + procedure = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert procedure.name == temp_procedure1 + assert procedure.database_name == temp_db1 + assert procedure.schema_name == temp_schema1 + + +@pytest.mark.udf +def test_get_user_defined_function_sql_mode(session, temp_db1, temp_schema1, temp_udf1): + catalog: Catalog = session.catalog + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert udf.name == temp_udf1 + assert udf.database_name == temp_db1 + assert udf.schema_name == temp_schema1 + + +def test_database_exists_sql_mode(session, temp_db1): + catalog: Catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + +def test_get_table_view_sql_mode( + session, temp_db1, temp_schema1, temp_table1, temp_view1 +): + catalog: Catalog = session.catalog + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert table.name == temp_table1 + assert table.database_name == temp_db1 + assert table.schema_name == temp_schema1 + + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert view.name == temp_view1 + assert view.database_name == temp_db1 + assert view.schema_name == temp_schema1 + + +def test_exists_db_schema_sql_mode(session, temp_db1, temp_schema1): + catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + assert catalog.schema_exists(temp_schema1, database=temp_db1) + assert not catalog.schema_exists(temp_schema1, database="does_not_exist") + + +def test_exists_table_view_sql_mode( + session, temp_db1, temp_schema1, temp_table1, temp_view1 +): + catalog = session.catalog + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(database=temp_db1, schema=temp_schema1) + + assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj) + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(table) + assert not catalog.table_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + assert catalog.view_exists(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(temp_view1, database=db1_obj, schema=schema1_obj) + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(view) + assert not catalog.view_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.udf +def test_exists_function_procedure_udf_sql_mode( + session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 +): + catalog = session.catalog + db1_obj = catalog.get_database(temp_db1) + schema1_obj = catalog.get_schema(temp_schema1, database=temp_db1) + + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + proc = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists(proc) + assert not catalog.procedure_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists(udf) + assert not catalog.user_defined_function_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.parametrize("use_object", [True, False]) +def test_drop_sql_mode(session, use_object): + catalog = session.catalog + + original_db = session.get_current_database() + original_schema = session.get_current_schema() + try: + temp_db = create_temp_db(session) + temp_schema = create_temp_schema(session, temp_db) + temp_table = create_temp_table(session, temp_db, temp_schema) + temp_view = create_temp_view(session, temp_db, temp_schema) + if use_object: + temp_schema = catalog.get_schema(temp_schema, database=temp_db) + temp_db = catalog.get_database(temp_db) + + assert catalog.database_exists(temp_db) + assert catalog.schema_exists(temp_schema, database=temp_db) + assert catalog.table_exists(temp_table, database=temp_db, schema=temp_schema) + assert catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_table(temp_table, database=temp_db, schema=temp_schema) + catalog.drop_view(temp_view, database=temp_db, schema=temp_schema) + + assert not catalog.table_exists( + temp_table, database=temp_db, schema=temp_schema + ) + assert not catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_schema(temp_schema, database=temp_db) + assert not catalog.schema_exists(temp_schema, database=temp_db) + + catalog.drop_database(temp_db) + assert not catalog.database_exists(temp_db) + finally: + session.use_database(original_db) + session.use_schema(original_schema) From f6108350b0fffd05fd5aa7c3bac8ee335cab437b Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 27 Apr 2026 15:21:27 -0700 Subject: [PATCH 05/25] init back end at catalog init --- src/snowflake/snowpark/catalog.py | 49 ++++++++++----------------- tests/integ/test_catalog_rest_mode.py | 13 +++++-- tests/integ/test_catalog_sql_mode.py | 13 +++++-- 3 files changed, 38 insertions(+), 37 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 0a2af60e58..fedf1672ac 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -466,13 +466,7 @@ def user_defined_function_exists( class _RestCatalogBackend(_CatalogBackend): def __init__(self, catalog: "Catalog") -> None: super().__init__(catalog) - self._root_obj: Optional[Root] = None - - @property - def _root(self) -> Root: - if self._root_obj is None: - self._root_obj = Root(self._catalog._session) - return self._root_obj + self._root = Root(catalog._session) def list_databases( self, @@ -708,15 +702,10 @@ class Catalog: def __init__(self, session: "Session") -> None: self._session = session self._python_regex_udf = None - self._sql_backend = _SqlCatalogBackend(self) - self._rest_backend: Optional[_RestCatalogBackend] = None - - def _backend(self) -> _CatalogBackend: if context._is_snowpark_connect_compatible_mode: - return self._sql_backend - if self._rest_backend is None: - self._rest_backend = _RestCatalogBackend(self) - return self._rest_backend + self._backend: _CatalogBackend = _SqlCatalogBackend(self) + else: + self._backend = _RestCatalogBackend(self) def _parse_database( self, @@ -842,7 +831,7 @@ def list_databases( pattern: the python regex pattern of name to match. Defaults to None. like: the sql style pattern for name to match. Default to None. """ - return self._backend().list_databases(pattern=pattern, like=like) + return self._backend.list_databases(pattern=pattern, like=like) def list_schemas( self, @@ -859,9 +848,7 @@ def list_schemas( pattern: the python regex pattern of name to match. Defaults to None. like: the sql style pattern for name to match. Default to None. """ - return self._backend().list_schemas( - database=database, pattern=pattern, like=like - ) + return self._backend.list_schemas(database=database, pattern=pattern, like=like) def list_tables( self, @@ -998,13 +985,13 @@ def get_current_schema(self) -> Optional[str]: def get_database(self, database: str) -> Database: """Name of the database to get""" - return self._backend().get_database(database) + return self._backend.get_database(database) def get_schema( self, schema: str, *, database: Optional[Union[str, Database]] = None ) -> Schema: """Name of the schema to get.""" - return self._backend().get_schema(schema, database=database) + return self._backend.get_schema(schema, database=database) def get_table( self, @@ -1027,7 +1014,7 @@ def get_table( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().get_table(table_name, database=database, schema=schema) + return self._backend.get_table(table_name, database=database, schema=schema) def get_view( self, @@ -1044,7 +1031,7 @@ def get_view( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().get_view(view_name, database=database, schema=schema) + return self._backend.get_view(view_name, database=database, schema=schema) def get_procedure( self, @@ -1063,7 +1050,7 @@ def get_procedure( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().get_procedure( + return self._backend.get_procedure( procedure_name, arg_types, database=database, schema=schema ) @@ -1085,7 +1072,7 @@ def get_user_defined_function( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().get_user_defined_function( + return self._backend.get_user_defined_function( udf_name, arg_types, database=database, schema=schema ) @@ -1113,7 +1100,7 @@ def database_exists(self, database: Union[str, Database]) -> bool: Args: database: database name or ``Database`` object. """ - return self._backend().database_exists(database) + return self._backend.database_exists(database) def schema_exists( self, @@ -1128,7 +1115,7 @@ def schema_exists( schema: schema name or ``Schema`` object. database: database name or ``Database`` object. Defaults to None. """ - return self._backend().schema_exists(schema, database=database) + return self._backend.schema_exists(schema, database=database) def table_exists( self, @@ -1145,7 +1132,7 @@ def table_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().table_exists(table, database=database, schema=schema) + return self._backend.table_exists(table, database=database, schema=schema) def view_exists( self, @@ -1162,7 +1149,7 @@ def view_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().view_exists(view, database=database, schema=schema) + return self._backend.view_exists(view, database=database, schema=schema) def procedure_exists( self, @@ -1181,7 +1168,7 @@ def procedure_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().procedure_exists( + return self._backend.procedure_exists( procedure, arg_types, database=database, schema=schema ) @@ -1204,7 +1191,7 @@ def user_defined_function_exists( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - return self._backend().user_defined_function_exists( + return self._backend.user_defined_function_exists( udf, arg_types, database=database, schema=schema ) diff --git a/tests/integ/test_catalog_rest_mode.py b/tests/integ/test_catalog_rest_mode.py index 89de818eab..be4bba4a39 100644 --- a/tests/integ/test_catalog_rest_mode.py +++ b/tests/integ/test_catalog_rest_mode.py @@ -31,9 +31,16 @@ ] -@pytest.fixture(autouse=True) -def _catalog_rest_backend_mode(monkeypatch): - monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", False) +@pytest.fixture(autouse=True, scope="module") +def _catalog_rest_backend_mode(session): + mp = pytest.MonkeyPatch() + mp.setattr(context, "_is_snowpark_connect_compatible_mode", False) + mp.setattr(session, "_catalog", None) + try: + yield + finally: + mp.undo() + session._catalog = None def test_list_db_rest_mode(session, temp_db1, temp_db2): diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py index b7d166b211..ac97b435ce 100644 --- a/tests/integ/test_catalog_sql_mode.py +++ b/tests/integ/test_catalog_sql_mode.py @@ -32,9 +32,16 @@ ] -@pytest.fixture(autouse=True) -def _catalog_sql_backend_mode(monkeypatch): - monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) +@pytest.fixture(autouse=True, scope="module") +def _catalog_sql_backend_mode(session): + mp = pytest.MonkeyPatch() + mp.setattr(context, "_is_snowpark_connect_compatible_mode", True) + mp.setattr(session, "_catalog", None) + try: + yield + finally: + mp.undo() + session._catalog = None def test_list_db_sql_mode(session, temp_db1, temp_db2): From 674d83171cada68968dd7aaaa8266030a53813d3 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 28 Apr 2026 10:06:31 -0700 Subject: [PATCH 06/25] push missed test fixture --- tests/integ/catalog_integ_common.py | 170 ++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 tests/integ/catalog_integ_common.py diff --git a/tests/integ/catalog_integ_common.py b/tests/integ/catalog_integ_common.py new file mode 100644 index 0000000000..bb459ac8a6 --- /dev/null +++ b/tests/integ/catalog_integ_common.py @@ -0,0 +1,170 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# +"""Shared pytest fixtures for catalog integration tests (see ``test_catalog*.py``).""" + +import uuid + +import pytest + +from snowflake.snowpark.session import Session +from snowflake.snowpark.types import IntegerType + +CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP" + + +def get_temp_name(type: str) -> str: + return f"{CATALOG_TEMP_OBJECT_PREFIX}_{type}_{uuid.uuid4().hex[:6]}".upper() + + +def create_temp_db(session) -> str: + original_db = session.get_current_database() + temp_db = get_temp_name("DB") + session._run_query(f"create or replace database {temp_db}") + session.use_database(original_db) + return temp_db + + +@pytest.fixture(scope="module") +def temp_db1(session): + temp_db = create_temp_db(session) + yield temp_db + session._run_query(f"drop database if exists {temp_db}") + + +@pytest.fixture(scope="module") +def temp_db2(session): + temp_db = create_temp_db(session) + yield temp_db + session._run_query(f"drop database if exists {temp_db}") + + +def create_temp_schema(session, db: str) -> str: + original_db = session.get_current_database() + original_schema = session.get_current_schema() + temp_schema = get_temp_name("SCHEMA") + session._run_query(f"create or replace schema {db}.{temp_schema}") + + session.use_database(original_db) + session.use_schema(original_schema) + return temp_schema + + +@pytest.fixture(scope="module") +def temp_schema1(session, temp_db1): + temp_schema = create_temp_schema(session, temp_db1) + yield temp_schema + session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") + + +@pytest.fixture(scope="module") +def temp_schema2(session, temp_db1): + temp_schema = create_temp_schema(session, temp_db1) + yield temp_schema + session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") + + +def create_temp_table(session, db: str, schema: str) -> str: + temp_table = get_temp_name("TABLE") + session._run_query( + f"create or replace temp table {db}.{schema}.{temp_table} (a int, b string)" + ) + return temp_table + + +@pytest.fixture(scope="module") +def temp_table1(session, temp_db1, temp_schema1): + temp_table = create_temp_table(session, temp_db1, temp_schema1) + yield temp_table + session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") + + +@pytest.fixture(scope="module") +def temp_table2(session, temp_db1, temp_schema1): + temp_table = create_temp_table(session, temp_db1, temp_schema1) + yield temp_table + session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") + + +def create_temp_view(session, db: str, schema: str) -> str: + temp_schema = get_temp_name("VIEW") + session._run_query( + f"create or replace temp view {db}.{schema}.{temp_schema} as select 1 as a, '2' as b" + ) + return temp_schema + + +@pytest.fixture(scope="module") +def temp_view1(session, temp_db1, temp_schema1): + temp_view = create_temp_view(session, temp_db1, temp_schema1) + yield temp_view + session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") + + +@pytest.fixture(scope="module") +def temp_view2(session, temp_db1, temp_schema1): + temp_view = create_temp_view(session, temp_db1, temp_schema1) + yield temp_view + session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") + + +def create_temp_procedure(session: Session, db, schema) -> str: + temp_procedure = get_temp_name("PROCEDURE") + session.sproc.register( + lambda _, x: x + 1, + return_type=IntegerType(), + input_types=[IntegerType()], + name=f"{db}.{schema}.{temp_procedure}", + packages=["snowflake-snowpark-python"], + ) + return temp_procedure + + +@pytest.fixture(scope="module") +def temp_procedure1(session, temp_db1, temp_schema1): + temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) + yield temp_procedure + session._run_query( + f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" + ) + + +@pytest.fixture(scope="module") +def temp_procedure2(session, temp_db1, temp_schema1): + temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) + yield temp_procedure + session._run_query( + f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" + ) + + +def create_temp_udf(session: Session, db, schema) -> str: + temp_udf = get_temp_name("UDF") + session.udf.register( + lambda x: x + 1, + return_type=IntegerType(), + input_types=[IntegerType()], + name=f"{db}.{schema}.{temp_udf}", + ) + return temp_udf + + +@pytest.fixture(scope="module") +def temp_udf1(session, temp_db1, temp_schema1): + temp_udf = create_temp_udf(session, temp_db1, temp_schema1) + yield temp_udf + session._run_query( + f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" + ) + + +@pytest.fixture(scope="module") +def temp_udf2(session, temp_db1, temp_schema1): + temp_udf = create_temp_udf(session, temp_db1, temp_schema1) + yield temp_udf + session._run_query( + f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" + ) + + +DOES_NOT_EXIST_PATTERN = "does_not_exist_.*" From 5933508fc926c9b7ad49f7f750007069cc3e7192 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 28 Apr 2026 10:29:54 -0700 Subject: [PATCH 07/25] fix lint --- tests/integ/test_catalog.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/integ/test_catalog.py b/tests/integ/test_catalog.py index 801ecf5c36..115aabf36a 100644 --- a/tests/integ/test_catalog.py +++ b/tests/integ/test_catalog.py @@ -16,9 +16,6 @@ CATALOG_TEMP_OBJECT_PREFIX, DOES_NOT_EXIST_PATTERN, ) -from snowflake.snowpark.session import Session -from snowflake.snowpark.types import IntegerType -from snowflake.core.exceptions import APIError from snowflake.snowpark.context import _DEFAULT_ARTIFACT_REPOSITORY From 02ec3b8ab1c005b67920b519fd665201fb4d7526 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 28 Apr 2026 11:40:12 -0700 Subject: [PATCH 08/25] remove changelog and use notimplementederror --- CHANGELOG.md | 8 ----- src/snowflake/snowpark/catalog.py | 57 +++++++++++++++++++++++-------- 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0e2e8b621..959915450f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,13 +1,5 @@ # Release History -## 1.51.0 (TBD) - -### Snowpark Python API Updates - -#### Improvements - -- Catalog API now uses SQL commands instead of SnowAPI calls to improve stability. - ## 1.50.0 (2026-04-23) ### Snowpark Python API Updates diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index fedf1672ac..4078ebc340 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -56,7 +56,9 @@ def list_databases( pattern: Optional[str] = None, like: Optional[str] = None, ) -> List[Database]: - pass + raise NotImplementedError( + "_CatalogBackend.list_databases must be implemented by a concrete subclass." + ) @abstractmethod def list_schemas( @@ -66,17 +68,23 @@ def list_schemas( pattern: Optional[str] = None, like: Optional[str] = None, ) -> List[Schema]: - pass + raise NotImplementedError( + "_CatalogBackend.list_schemas must be implemented by a concrete subclass." + ) @abstractmethod def get_database(self, database: str) -> Database: - pass + raise NotImplementedError( + "_CatalogBackend.get_database must be implemented by a concrete subclass." + ) @abstractmethod def get_schema( self, schema: str, *, database: Optional[Union[str, Database]] = None ) -> Schema: - pass + raise NotImplementedError( + "_CatalogBackend.get_schema must be implemented by a concrete subclass." + ) @abstractmethod def get_table( @@ -86,7 +94,9 @@ def get_table( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> Union[Table, View]: - pass + raise NotImplementedError( + "_CatalogBackend.get_table must be implemented by a concrete subclass." + ) @abstractmethod def get_view( @@ -96,7 +106,9 @@ def get_view( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> View: - pass + raise NotImplementedError( + "_CatalogBackend.get_view must be implemented by a concrete subclass." + ) @abstractmethod def get_procedure( @@ -107,7 +119,9 @@ def get_procedure( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> Procedure: - pass + raise NotImplementedError( + "_CatalogBackend.get_procedure must be implemented by a concrete subclass." + ) @abstractmethod def get_user_defined_function( @@ -118,11 +132,15 @@ def get_user_defined_function( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> UserDefinedFunction: - pass + raise NotImplementedError( + "_CatalogBackend.get_user_defined_function must be implemented by a concrete subclass." + ) @abstractmethod def database_exists(self, database: Union[str, Database]) -> bool: - pass + raise NotImplementedError( + "_CatalogBackend.database_exists must be implemented by a concrete subclass." + ) @abstractmethod def schema_exists( @@ -131,7 +149,9 @@ def schema_exists( *, database: Optional[Union[str, Database]] = None, ) -> bool: - pass + raise NotImplementedError( + "_CatalogBackend.schema_exists must be implemented by a concrete subclass." + ) @abstractmethod def table_exists( @@ -141,7 +161,9 @@ def table_exists( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> bool: - pass + raise NotImplementedError( + "_CatalogBackend.table_exists must be implemented by a concrete subclass." + ) @abstractmethod def view_exists( @@ -151,7 +173,9 @@ def view_exists( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> bool: - pass + raise NotImplementedError( + "_CatalogBackend.view_exists must be implemented by a concrete subclass." + ) @abstractmethod def procedure_exists( @@ -162,7 +186,9 @@ def procedure_exists( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> bool: - pass + raise NotImplementedError( + "_CatalogBackend.procedure_exists must be implemented by a concrete subclass." + ) @abstractmethod def user_defined_function_exists( @@ -173,7 +199,10 @@ def user_defined_function_exists( database: Optional[Union[str, Database]] = None, schema: Optional[Union[str, Schema]] = None, ) -> bool: - pass + raise NotImplementedError( + "_CatalogBackend.user_defined_function_exists must be implemented by a " + "concrete subclass." + ) class _SqlCatalogBackend(_CatalogBackend): From c78333c8f1f348c86149859eaf64ccd3acb56e53 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 28 Apr 2026 14:31:08 -0700 Subject: [PATCH 09/25] move fixture back to test_catalog --- tests/integ/catalog_integ_common.py | 170 ------------------------- tests/integ/conftest.py | 2 +- tests/integ/test_catalog.py | 174 ++++++++++++++++++++++++-- tests/integ/test_catalog_rest_mode.py | 2 +- tests/integ/test_catalog_sql_mode.py | 2 +- 5 files changed, 169 insertions(+), 181 deletions(-) delete mode 100644 tests/integ/catalog_integ_common.py diff --git a/tests/integ/catalog_integ_common.py b/tests/integ/catalog_integ_common.py deleted file mode 100644 index bb459ac8a6..0000000000 --- a/tests/integ/catalog_integ_common.py +++ /dev/null @@ -1,170 +0,0 @@ -# -# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. -# -"""Shared pytest fixtures for catalog integration tests (see ``test_catalog*.py``).""" - -import uuid - -import pytest - -from snowflake.snowpark.session import Session -from snowflake.snowpark.types import IntegerType - -CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP" - - -def get_temp_name(type: str) -> str: - return f"{CATALOG_TEMP_OBJECT_PREFIX}_{type}_{uuid.uuid4().hex[:6]}".upper() - - -def create_temp_db(session) -> str: - original_db = session.get_current_database() - temp_db = get_temp_name("DB") - session._run_query(f"create or replace database {temp_db}") - session.use_database(original_db) - return temp_db - - -@pytest.fixture(scope="module") -def temp_db1(session): - temp_db = create_temp_db(session) - yield temp_db - session._run_query(f"drop database if exists {temp_db}") - - -@pytest.fixture(scope="module") -def temp_db2(session): - temp_db = create_temp_db(session) - yield temp_db - session._run_query(f"drop database if exists {temp_db}") - - -def create_temp_schema(session, db: str) -> str: - original_db = session.get_current_database() - original_schema = session.get_current_schema() - temp_schema = get_temp_name("SCHEMA") - session._run_query(f"create or replace schema {db}.{temp_schema}") - - session.use_database(original_db) - session.use_schema(original_schema) - return temp_schema - - -@pytest.fixture(scope="module") -def temp_schema1(session, temp_db1): - temp_schema = create_temp_schema(session, temp_db1) - yield temp_schema - session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") - - -@pytest.fixture(scope="module") -def temp_schema2(session, temp_db1): - temp_schema = create_temp_schema(session, temp_db1) - yield temp_schema - session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") - - -def create_temp_table(session, db: str, schema: str) -> str: - temp_table = get_temp_name("TABLE") - session._run_query( - f"create or replace temp table {db}.{schema}.{temp_table} (a int, b string)" - ) - return temp_table - - -@pytest.fixture(scope="module") -def temp_table1(session, temp_db1, temp_schema1): - temp_table = create_temp_table(session, temp_db1, temp_schema1) - yield temp_table - session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") - - -@pytest.fixture(scope="module") -def temp_table2(session, temp_db1, temp_schema1): - temp_table = create_temp_table(session, temp_db1, temp_schema1) - yield temp_table - session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") - - -def create_temp_view(session, db: str, schema: str) -> str: - temp_schema = get_temp_name("VIEW") - session._run_query( - f"create or replace temp view {db}.{schema}.{temp_schema} as select 1 as a, '2' as b" - ) - return temp_schema - - -@pytest.fixture(scope="module") -def temp_view1(session, temp_db1, temp_schema1): - temp_view = create_temp_view(session, temp_db1, temp_schema1) - yield temp_view - session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") - - -@pytest.fixture(scope="module") -def temp_view2(session, temp_db1, temp_schema1): - temp_view = create_temp_view(session, temp_db1, temp_schema1) - yield temp_view - session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") - - -def create_temp_procedure(session: Session, db, schema) -> str: - temp_procedure = get_temp_name("PROCEDURE") - session.sproc.register( - lambda _, x: x + 1, - return_type=IntegerType(), - input_types=[IntegerType()], - name=f"{db}.{schema}.{temp_procedure}", - packages=["snowflake-snowpark-python"], - ) - return temp_procedure - - -@pytest.fixture(scope="module") -def temp_procedure1(session, temp_db1, temp_schema1): - temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) - yield temp_procedure - session._run_query( - f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" - ) - - -@pytest.fixture(scope="module") -def temp_procedure2(session, temp_db1, temp_schema1): - temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) - yield temp_procedure - session._run_query( - f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" - ) - - -def create_temp_udf(session: Session, db, schema) -> str: - temp_udf = get_temp_name("UDF") - session.udf.register( - lambda x: x + 1, - return_type=IntegerType(), - input_types=[IntegerType()], - name=f"{db}.{schema}.{temp_udf}", - ) - return temp_udf - - -@pytest.fixture(scope="module") -def temp_udf1(session, temp_db1, temp_schema1): - temp_udf = create_temp_udf(session, temp_db1, temp_schema1) - yield temp_udf - session._run_query( - f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" - ) - - -@pytest.fixture(scope="module") -def temp_udf2(session, temp_db1, temp_schema1): - temp_udf = create_temp_udf(session, temp_db1, temp_schema1) - yield temp_udf - session._run_query( - f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" - ) - - -DOES_NOT_EXIST_PATTERN = "does_not_exist_.*" diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index 15feb15fed..01edcf170b 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -31,7 +31,7 @@ RUNNING_ON_GH = os.getenv("GITHUB_ACTIONS") == "true" RUNNING_ON_JENKINS = "JENKINS_HOME" in os.environ -pytest_plugins = ("tests.integ.catalog_integ_common",) +pytest_plugins = ("tests.integ.test_catalog",) test_dir = os.path.dirname(__file__) test_data_dir = os.path.join(test_dir, "cassettes") diff --git a/tests/integ/test_catalog.py b/tests/integ/test_catalog.py index 115aabf36a..b0fef8e6c1 100644 --- a/tests/integ/test_catalog.py +++ b/tests/integ/test_catalog.py @@ -1,22 +1,180 @@ # # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # -"""Mode-agnostic catalog integration tests. +"""Catalog integration tests and shared fixtures. -Only tests whose call paths are identical between the SQL-based and REST-based -catalog backends live here. Backend-specific behavior is covered in -``test_catalog_sql_mode.py`` and ``test_catalog_rest_mode.py``. +Mode-agnostic tests (same behavior for SQL and REST catalog backends) live in +this module. Backend-specific tests are in ``test_catalog_sql_mode.py`` and +``test_catalog_rest_mode.py``, which reuse the fixtures defined here via +``pytest_plugins`` in ``conftest.py``. """ +import uuid from unittest.mock import patch + import pytest from snowflake.snowpark.catalog import Catalog -from tests.integ.catalog_integ_common import ( - CATALOG_TEMP_OBJECT_PREFIX, - DOES_NOT_EXIST_PATTERN, -) from snowflake.snowpark.context import _DEFAULT_ARTIFACT_REPOSITORY +from snowflake.snowpark.session import Session +from snowflake.snowpark.types import IntegerType + +CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP" +DOES_NOT_EXIST_PATTERN = "does_not_exist_.*" + + +def get_temp_name(type: str) -> str: + return f"{CATALOG_TEMP_OBJECT_PREFIX}_{type}_{uuid.uuid4().hex[:6]}".upper() + + +def create_temp_db(session) -> str: + original_db = session.get_current_database() + temp_db = get_temp_name("DB") + session._run_query(f"create or replace database {temp_db}") + session.use_database(original_db) + return temp_db + + +@pytest.fixture(scope="module") +def temp_db1(session): + temp_db = create_temp_db(session) + yield temp_db + session._run_query(f"drop database if exists {temp_db}") + + +@pytest.fixture(scope="module") +def temp_db2(session): + temp_db = create_temp_db(session) + yield temp_db + session._run_query(f"drop database if exists {temp_db}") + + +def create_temp_schema(session, db: str) -> str: + original_db = session.get_current_database() + original_schema = session.get_current_schema() + temp_schema = get_temp_name("SCHEMA") + session._run_query(f"create or replace schema {db}.{temp_schema}") + + session.use_database(original_db) + session.use_schema(original_schema) + return temp_schema + + +@pytest.fixture(scope="module") +def temp_schema1(session, temp_db1): + temp_schema = create_temp_schema(session, temp_db1) + yield temp_schema + session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") + + +@pytest.fixture(scope="module") +def temp_schema2(session, temp_db1): + temp_schema = create_temp_schema(session, temp_db1) + yield temp_schema + session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") + + +def create_temp_table(session, db: str, schema: str) -> str: + temp_table = get_temp_name("TABLE") + session._run_query( + f"create or replace temp table {db}.{schema}.{temp_table} (a int, b string)" + ) + return temp_table + + +@pytest.fixture(scope="module") +def temp_table1(session, temp_db1, temp_schema1): + temp_table = create_temp_table(session, temp_db1, temp_schema1) + yield temp_table + session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") + + +@pytest.fixture(scope="module") +def temp_table2(session, temp_db1, temp_schema1): + temp_table = create_temp_table(session, temp_db1, temp_schema1) + yield temp_table + session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") + + +def create_temp_view(session, db: str, schema: str) -> str: + temp_schema = get_temp_name("VIEW") + session._run_query( + f"create or replace temp view {db}.{schema}.{temp_schema} as select 1 as a, '2' as b" + ) + return temp_schema + + +@pytest.fixture(scope="module") +def temp_view1(session, temp_db1, temp_schema1): + temp_view = create_temp_view(session, temp_db1, temp_schema1) + yield temp_view + session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") + + +@pytest.fixture(scope="module") +def temp_view2(session, temp_db1, temp_schema1): + temp_view = create_temp_view(session, temp_db1, temp_schema1) + yield temp_view + session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") + + +def create_temp_procedure(session: Session, db, schema) -> str: + temp_procedure = get_temp_name("PROCEDURE") + session.sproc.register( + lambda _, x: x + 1, + return_type=IntegerType(), + input_types=[IntegerType()], + name=f"{db}.{schema}.{temp_procedure}", + packages=["snowflake-snowpark-python"], + ) + return temp_procedure + + +@pytest.fixture(scope="module") +def temp_procedure1(session, temp_db1, temp_schema1): + temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) + yield temp_procedure + session._run_query( + f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" + ) + + +@pytest.fixture(scope="module") +def temp_procedure2(session, temp_db1, temp_schema1): + temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) + yield temp_procedure + session._run_query( + f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" + ) + + +def create_temp_udf(session: Session, db, schema) -> str: + temp_udf = get_temp_name("UDF") + session.udf.register( + lambda x: x + 1, + return_type=IntegerType(), + input_types=[IntegerType()], + name=f"{db}.{schema}.{temp_udf}", + ) + return temp_udf + + +@pytest.fixture(scope="module") +def temp_udf1(session, temp_db1, temp_schema1): + temp_udf = create_temp_udf(session, temp_db1, temp_schema1) + yield temp_udf + session._run_query( + f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" + ) + + +@pytest.fixture(scope="module") +def temp_udf2(session, temp_db1, temp_schema1): + temp_udf = create_temp_udf(session, temp_db1, temp_schema1) + yield temp_udf + session._run_query( + f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" + ) pytestmark = [ diff --git a/tests/integ/test_catalog_rest_mode.py b/tests/integ/test_catalog_rest_mode.py index be4bba4a39..e7d33932ca 100644 --- a/tests/integ/test_catalog_rest_mode.py +++ b/tests/integ/test_catalog_rest_mode.py @@ -14,7 +14,7 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted from snowflake.snowpark.catalog import Catalog from snowflake.snowpark.types import IntegerType -from tests.integ.catalog_integ_common import ( +from tests.integ.test_catalog import ( CATALOG_TEMP_OBJECT_PREFIX, create_temp_db, create_temp_schema, diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py index ac97b435ce..eff8bd8d91 100644 --- a/tests/integ/test_catalog_sql_mode.py +++ b/tests/integ/test_catalog_sql_mode.py @@ -15,7 +15,7 @@ from snowflake.snowpark.catalog import Catalog from snowflake.snowpark.exceptions import NotFoundError from snowflake.snowpark.types import IntegerType -from tests.integ.catalog_integ_common import ( +from tests.integ.test_catalog import ( CATALOG_TEMP_OBJECT_PREFIX, create_temp_db, create_temp_schema, From 2afd73802e36984b0564540e6c5372f6286788e0 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 28 Apr 2026 19:29:58 -0700 Subject: [PATCH 10/25] fix test --- tests/conftest.py | 2 ++ tests/integ/conftest.py | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bfd9ab8f78..bca98bd3b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,8 @@ from snowflake.snowpark._internal.utils import warning_dict from .ast.conftest import default_unparser_path +pytest_plugins = ("tests.integ.test_catalog",) + logging.getLogger("snowflake.connector").setLevel(logging.ERROR) excluded_frontend_files = [ diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index 01edcf170b..0ac231a396 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -31,8 +31,6 @@ RUNNING_ON_GH = os.getenv("GITHUB_ACTIONS") == "true" RUNNING_ON_JENKINS = "JENKINS_HOME" in os.environ -pytest_plugins = ("tests.integ.test_catalog",) - test_dir = os.path.dirname(__file__) test_data_dir = os.path.join(test_dir, "cassettes") From 50d0c32320f2d760b6067c51d07cf33f03321a41 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 29 Apr 2026 10:10:40 -0700 Subject: [PATCH 11/25] add limit 10000 in sql base(scos only) --- src/snowflake/snowpark/catalog.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 4078ebc340..dadcfe31ac 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -42,6 +42,10 @@ if TYPE_CHECKING: from snowflake.snowpark.session import Session +# Cap for SHOW AS RESOURCE DATABASES / SCHEMAS in the SQL backend (SCOS; avoids +# oversized result sets when accounts have very many databases or schemas). +_SHOW_AS_RESOURCE_LIMIT = 10000 + class _CatalogBackend(ABC): """Internal catalog implementation selected by ``context._is_snowpark_connect_compatible_mode``.""" @@ -214,7 +218,9 @@ def list_databases( ) -> List[Database]: c = self._catalog like_str = f"LIKE '{like}'" if like else "" - df = c._session.sql(f"SHOW AS RESOURCE DATABASES {like_str}") + df = c._session.sql( + f"SHOW AS RESOURCE DATABASES {like_str} LIMIT {_SHOW_AS_RESOURCE_LIMIT}" + ) if pattern: c._initialize_regex_udf() assert c._python_regex_udf is not None # pyright @@ -239,7 +245,9 @@ def list_schemas( c = self._catalog db_name = c._parse_database(database) like_str = f"LIKE '{like}'" if like else "" - df = c._session.sql(f"SHOW AS RESOURCE SCHEMAS {like_str} IN {db_name}") + df = c._session.sql( + f"SHOW AS RESOURCE SCHEMAS {like_str} IN {db_name} LIMIT {_SHOW_AS_RESOURCE_LIMIT}" + ) if pattern: c._initialize_regex_udf() assert c._python_regex_udf is not None # pyright From 9e89ca4525bcd253ff59ff5b17e6838349596312 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 29 Apr 2026 11:36:19 -0700 Subject: [PATCH 12/25] remove wrong test --- tests/integ/test_catalog_rest_mode.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/integ/test_catalog_rest_mode.py b/tests/integ/test_catalog_rest_mode.py index e7d33932ca..7082f808b3 100644 --- a/tests/integ/test_catalog_rest_mode.py +++ b/tests/integ/test_catalog_rest_mode.py @@ -84,14 +84,6 @@ def test_get_database_missing_raises_core_not_found_rest_mode(session): catalog.get_database("NONEXISTENT_DB_XYZ_12345") -def test_get_table_does_not_resolve_view_rest_mode( - session, temp_db1, temp_schema1, temp_view1 -): - catalog: Catalog = session.catalog - with pytest.raises(CoreNotFoundError): - catalog.get_table(temp_view1, database=temp_db1, schema=temp_schema1) - - def test_get_view_rest_mode(session, temp_db1, temp_schema1, temp_view1): catalog: Catalog = session.catalog view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) @@ -100,13 +92,6 @@ def test_get_view_rest_mode(session, temp_db1, temp_schema1, temp_view1): assert view.schema_name == temp_schema1 -def test_table_exists_false_for_view_name_rest_mode( - session, temp_db1, temp_schema1, temp_view1 -): - catalog: Catalog = session.catalog - assert not catalog.table_exists(temp_view1, database=temp_db1, schema=temp_schema1) - - @pytest.mark.udf def test_get_procedure_rest_mode(session, temp_db1, temp_schema1, temp_procedure1): catalog: Catalog = session.catalog From ef16ee85c71164167a82639e08f3137b01c9deca Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 30 Apr 2026 16:45:56 -0700 Subject: [PATCH 13/25] address comments --- src/snowflake/snowpark/catalog.py | 144 +++++++++++++++++++++++--- tests/integ/test_catalog.py | 3 + tests/integ/test_catalog_rest_mode.py | 5 + 3 files changed, 137 insertions(+), 15 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index dadcfe31ac..af23658c92 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -208,6 +208,47 @@ def user_defined_function_exists( "concrete subclass." ) + @abstractmethod + def drop_database(self, database: Union[str, Database]) -> None: + raise NotImplementedError( + "_CatalogBackend.drop_database must be implemented by a concrete subclass." + ) + + @abstractmethod + def drop_schema( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> None: + raise NotImplementedError( + "_CatalogBackend.drop_schema must be implemented by a concrete subclass." + ) + + @abstractmethod + def drop_table( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + raise NotImplementedError( + "_CatalogBackend.drop_table must be implemented by a concrete subclass." + ) + + @abstractmethod + def drop_view( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + raise NotImplementedError( + "_CatalogBackend.drop_view must be implemented by a concrete subclass." + ) + class _SqlCatalogBackend(_CatalogBackend): def list_databases( @@ -499,6 +540,48 @@ def user_defined_function_exists( except NotFoundError: return False + def drop_database(self, database: Union[str, Database]) -> None: + c = self._catalog + db_name = c._parse_database(database) + c._session.sql(f"DROP DATABASE {db_name}").collect() + + def drop_schema( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, schema) + schema_name = c._parse_schema(schema) + c._session.sql(f"DROP SCHEMA {db_name}.{schema_name}").collect() + + def drop_table( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, table) + schema_name = c._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + c._session.sql(f"DROP TABLE {db_name}.{schema_name}.{table_name}").collect() + + def drop_view( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, view) + schema_name = c._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + c._session.sql(f"DROP VIEW {db_name}.{schema_name}.{view_name}").collect() + class _RestCatalogBackend(_CatalogBackend): def __init__(self, catalog: "Catalog") -> None: @@ -729,6 +812,48 @@ def user_defined_function_exists( except CoreNotFoundError: return False + def drop_database(self, database: Union[str, Database]) -> None: + c = self._catalog + db_name = c._parse_database(database) + self._root.databases[db_name].drop() + + def drop_schema( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, schema) + schema_name = c._parse_schema(schema) + self._root.databases[db_name].schemas[schema_name].drop() + + def drop_table( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, table) + schema_name = c._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + self._root.databases[db_name].schemas[schema_name].tables[table_name].drop() + + def drop_view( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + c = self._catalog + db_name = c._parse_database(database, view) + schema_name = c._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + self._root.databases[db_name].schemas[schema_name].views[view_name].drop() + class Catalog: """The Catalog class provides methods to interact with and manage the Snowflake objects. @@ -1238,8 +1363,7 @@ def drop_database(self, database: Union[str, Database]) -> None: Args: database: database name or ``Database`` object. """ - db_name = self._parse_database(database) - self._session.sql(f"DROP DATABASE {db_name}").collect() + return self._backend.drop_database(database) def drop_schema( self, @@ -1254,9 +1378,7 @@ def drop_schema( schema: schema name or ``Schema`` object. database: database name or ``Database`` object. Defaults to None. """ - db_name = self._parse_database(database, schema) - schema_name = self._parse_schema(schema) - self._session.sql(f"DROP SCHEMA {db_name}.{schema_name}").collect() + return self._backend.drop_schema(schema, database=database) def drop_table( self, @@ -1273,11 +1395,7 @@ def drop_table( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, table) - schema_name = self._parse_schema(schema, table) - table_name = table if isinstance(table, str) else table.name - - self._session.sql(f"DROP TABLE {db_name}.{schema_name}.{table_name}").collect() + return self._backend.drop_table(table, database=database, schema=schema) def drop_view( self, @@ -1294,11 +1412,7 @@ def drop_view( database: database name or ``Database`` object. Defaults to None. schema: schema name or ``Schema`` object. Defaults to None. """ - db_name = self._parse_database(database, view) - schema_name = self._parse_schema(schema, view) - view_name = view if isinstance(view, str) else view.name - - self._session.sql(f"DROP VIEW {db_name}.{schema_name}.{view_name}").collect() + return self._backend.drop_view(view, database=database, schema=schema) listDatabases = list_databases listSchemas = list_schemas diff --git a/tests/integ/test_catalog.py b/tests/integ/test_catalog.py index b0fef8e6c1..0827e132dd 100644 --- a/tests/integ/test_catalog.py +++ b/tests/integ/test_catalog.py @@ -54,6 +54,9 @@ def create_temp_schema(session, db: str) -> str: original_schema = session.get_current_schema() temp_schema = get_temp_name("SCHEMA") session._run_query(f"create or replace schema {db}.{temp_schema}") + session.sql( + f"ALTER SCHEMA SET DEFAULT_PYTHON_ARTIFACT_REPOSITORY = {_DEFAULT_ARTIFACT_REPOSITORY}" + ).collect() session.use_database(original_db) session.use_schema(original_schema) diff --git a/tests/integ/test_catalog_rest_mode.py b/tests/integ/test_catalog_rest_mode.py index 7082f808b3..11b0cc39ed 100644 --- a/tests/integ/test_catalog_rest_mode.py +++ b/tests/integ/test_catalog_rest_mode.py @@ -9,6 +9,7 @@ import pytest +from snowflake.core.exceptions import APIError from snowflake.core.exceptions import NotFoundError as CoreNotFoundError from snowflake.snowpark import context from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted @@ -28,6 +29,10 @@ reason="deepcopy is not supported and required by local testing", run=False, ), + pytest.mark.xfail( + raises=APIError, + reason="Failure due to warehouse overload", + ), ] From d9dad80fc00d6bef0ec210769b2ebe25028508f4 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 30 Apr 2026 16:47:48 -0700 Subject: [PATCH 14/25] restore comment --- src/snowflake/snowpark/catalog.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index af23658c92..35025d8397 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -1414,6 +1414,7 @@ def drop_view( """ return self._backend.drop_view(view, database=database, schema=schema) + # aliases listDatabases = list_databases listSchemas = list_schemas listTables = list_tables From 68650b11faed2808f4ddbc65fa36deac32baf622 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 1 May 2026 12:58:08 -0700 Subject: [PATCH 15/25] parameter protection --- src/snowflake/snowpark/catalog.py | 6 +++--- src/snowflake/snowpark/session.py | 5 ++++- tests/integ/test_catalog_sql_mode.py | 14 ++++++++++++++ tests/unit/test_session.py | 19 +++++++++++++++++++ 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 35025d8397..e92e546a99 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -48,7 +48,7 @@ class _CatalogBackend(ABC): - """Internal catalog implementation selected by ``context._is_snowpark_connect_compatible_mode``.""" + """Internal catalog implementation selected by compatibility mode and SQL base flag.""" def __init__(self, catalog: "Catalog") -> None: self._catalog = catalog @@ -861,10 +861,10 @@ class Catalog: views, functions, etc. """ - def __init__(self, session: "Session") -> None: + def __init__(self, session: "Session", *, _use_sql_base: bool = True) -> None: self._session = session self._python_regex_udf = None - if context._is_snowpark_connect_compatible_mode: + if context._is_snowpark_connect_compatible_mode and _use_sql_base: self._backend: _CatalogBackend = _SqlCatalogBackend(self) else: self._backend = _RestCatalogBackend(self) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 5803cc8329..2ec57288e6 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -628,6 +628,9 @@ def __init__( """ self.version = get_version() self._session_stage = None + self._use_sql_base = self._conn._get_client_side_session_parameter( + "SNOWPARK_CONNECT_CATALOG_USE_SQL_BASE", True + ) if isinstance(conn, MockServerConnection): self._udf_registration = MockUDFRegistration(self) @@ -961,7 +964,7 @@ def catalog(self): external_feature_name="Session.catalog", raise_error=NotImplementedError, ) - self._catalog = Catalog(self) + self._catalog = Catalog(self, _use_sql_base=self._use_sql_base) return self._catalog def close(self) -> None: diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py index eff8bd8d91..e253b7396a 100644 --- a/tests/integ/test_catalog_sql_mode.py +++ b/tests/integ/test_catalog_sql_mode.py @@ -9,6 +9,7 @@ import pytest +from snowflake.core.exceptions import NotFoundError as CoreNotFoundError from snowflake.core.view import View from snowflake.snowpark import context from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted @@ -85,6 +86,19 @@ def test_get_database_missing_raises_snowpark_not_found_sql_mode(session): catalog.get_database("NONEXISTENT_DB_XYZ_12345") +def test_compat_mode_with_sql_base_disabled_uses_rest_backend(session): + original_use_sql_base = session._use_sql_base + try: + session._use_sql_base = False + session._catalog = None + catalog: Catalog = session.catalog + with pytest.raises(CoreNotFoundError): + catalog.get_database("NONEXISTENT_DB_XYZ_12345") + finally: + session._use_sql_base = original_use_sql_base + session._catalog = None + + def test_get_table_resolves_view_sql_mode(session, temp_db1, temp_schema1, temp_view1): catalog: Catalog = session.catalog obj = catalog.get_table(temp_view1, database=temp_db1, schema=temp_schema1) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0349618659..fdcba3abac 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -117,6 +117,25 @@ def test_used_scoped_temp_object(): assert Session(fake_connection)._use_scoped_temp_objects is False +@pytest.mark.parametrize( + "parameter_value", + [True, False], +) +def test_session_use_sql_base_from_session_parameter(parameter_value): + fake_connection = mock.create_autospec(ServerConnection) + fake_connection._conn = mock.Mock() + fake_connection._thread_safe_session_enabled = True + fake_connection._get_client_side_session_parameter = mock.Mock( + side_effect=lambda name, default: parameter_value + if name == "SNOWPARK_CONNECT_CATALOG_USE_SQL_BASE" + else default + ) + fake_connection._conn._session_parameters = {} + + session = Session(fake_connection) + assert session._use_sql_base is parameter_value + + def test_close_exception(): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() From bb2ef8e459a1947545cc5c454373806e329cc71d Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 1 May 2026 15:49:04 -0700 Subject: [PATCH 16/25] parameter protection --- src/snowflake/snowpark/session.py | 20 +++++++++++++++----- tests/unit/test_session.py | 19 ++++++++++--------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 2ec57288e6..ca2b15ee9f 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -254,6 +254,7 @@ _session_management_lock = RLock() _active_sessions: Set["Session"] = set() +_USE_SQL_BASE_OPTION_KEY = "_use_sql_base" _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING = ( "PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS" ) @@ -463,6 +464,13 @@ def __init__(self) -> None: self._app_name = None self._format_json = None + @staticmethod + def _connection_options( + options: Dict[str, Union[int, str]] + ) -> Dict[str, Union[int, str]]: + # Internal Session-only options must not be forwarded to connector connect(**kwargs). + return {k: v for k, v in options.items() if k != _USE_SQL_BASE_OPTION_KEY} + def _remove_config(self, key: str) -> "Session.SessionBuilder": """Only used in test.""" self._options.pop(key, None) @@ -569,8 +577,11 @@ def _create_internal( # Set paramstyle to qmark by default to be consistent with previous behavior if "paramstyle" not in self._options: self._options["paramstyle"] = "qmark" + connection_options = self._connection_options(self._options) new_session = Session( - ServerConnection({}, conn) if conn else ServerConnection(self._options), + ServerConnection({}, conn) + if conn + else ServerConnection(connection_options), self._options, ) @@ -628,9 +639,8 @@ def __init__( """ self.version = get_version() self._session_stage = None - self._use_sql_base = self._conn._get_client_side_session_parameter( - "SNOWPARK_CONNECT_CATALOG_USE_SQL_BASE", True - ) + options = options or {} + self._use_sql_base = options.pop(_USE_SQL_BASE_OPTION_KEY, True) if isinstance(conn, MockServerConnection): self._udf_registration = MockUDFRegistration(self) @@ -851,7 +861,7 @@ def __init__( _PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION ) ) - self._conf = self.RuntimeConfig(self, options or {}) + self._conf = self.RuntimeConfig(self, options) self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) self._sp_profiler = StoredProcedureProfiler(session=self) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index fdcba3abac..4bf067979e 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -118,22 +118,23 @@ def test_used_scoped_temp_object(): @pytest.mark.parametrize( - "parameter_value", - [True, False], + "option_value, expected", + [(True, True), (False, False)], ) -def test_session_use_sql_base_from_session_parameter(parameter_value): +def test_session_use_sql_base_from_options(option_value, expected): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() fake_connection._thread_safe_session_enabled = True - fake_connection._get_client_side_session_parameter = mock.Mock( - side_effect=lambda name, default: parameter_value - if name == "SNOWPARK_CONNECT_CATALOG_USE_SQL_BASE" - else default + fake_connection._get_client_side_session_parameter = ( + lambda x, y: ServerConnection._get_client_side_session_parameter( + fake_connection, x, y + ) ) fake_connection._conn._session_parameters = {} - session = Session(fake_connection) - assert session._use_sql_base is parameter_value + session = Session(fake_connection, {"_use_sql_base": option_value}) + assert session._use_sql_base is expected + assert session.conf.get("_use_sql_base") is None def test_close_exception(): From e0097c9030075d0bb07cde109ccde65fae56b8c5 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 1 May 2026 15:59:33 -0700 Subject: [PATCH 17/25] add test --- tests/integ/test_catalog_sql_mode.py | 13 +++++++++++ tests/unit/test_session.py | 34 ++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py index e253b7396a..9a6ed7e10d 100644 --- a/tests/integ/test_catalog_sql_mode.py +++ b/tests/integ/test_catalog_sql_mode.py @@ -92,6 +92,7 @@ def test_compat_mode_with_sql_base_disabled_uses_rest_backend(session): session._use_sql_base = False session._catalog = None catalog: Catalog = session.catalog + assert type(catalog._backend).__name__ == "_RestCatalogBackend" with pytest.raises(CoreNotFoundError): catalog.get_database("NONEXISTENT_DB_XYZ_12345") finally: @@ -99,6 +100,18 @@ def test_compat_mode_with_sql_base_disabled_uses_rest_backend(session): session._catalog = None +def test_compat_mode_with_sql_base_enabled_uses_sql_backend(session): + original_use_sql_base = session._use_sql_base + try: + session._use_sql_base = True + session._catalog = None + catalog: Catalog = session.catalog + assert type(catalog._backend).__name__ == "_SqlCatalogBackend" + finally: + session._use_sql_base = original_use_sql_base + session._catalog = None + + def test_get_table_resolves_view_sql_mode(session, temp_db1, temp_schema1, temp_view1): catalog: Catalog = session.catalog obj = catalog.get_table(temp_view1, database=temp_db1, schema=temp_schema1) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 4bf067979e..9377d1ad92 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -119,7 +119,7 @@ def test_used_scoped_temp_object(): @pytest.mark.parametrize( "option_value, expected", - [(True, True), (False, False)], + [(None, True), (True, True), (False, False)], ) def test_session_use_sql_base_from_options(option_value, expected): fake_connection = mock.create_autospec(ServerConnection) @@ -132,11 +132,41 @@ def test_session_use_sql_base_from_options(option_value, expected): ) fake_connection._conn._session_parameters = {} - session = Session(fake_connection, {"_use_sql_base": option_value}) + options = {} if option_value is None else {"_use_sql_base": option_value} + session = Session(fake_connection, options) assert session._use_sql_base is expected assert session.conf.get("_use_sql_base") is None +@pytest.mark.parametrize( + "option_value, expected_backend_name", + [(True, "_SqlCatalogBackend"), (False, "_RestCatalogBackend")], +) +def test_catalog_backend_selection_from_use_sql_base_option( + option_value, expected_backend_name +): + import snowflake.snowpark.context as ctx + + fake_connection = mock.create_autospec(ServerConnection) + fake_connection._conn = mock.Mock() + fake_connection._thread_safe_session_enabled = True + fake_connection._get_client_side_session_parameter = ( + lambda x, y: ServerConnection._get_client_side_session_parameter( + fake_connection, x, y + ) + ) + fake_connection._conn._session_parameters = {} + fake_connection.get_session_id.return_value = "fake_session_id" + + original_compat = ctx._is_snowpark_connect_compatible_mode + try: + ctx._is_snowpark_connect_compatible_mode = True + session = Session(fake_connection, {"_use_sql_base": option_value}) + assert type(session.catalog._backend).__name__ == expected_backend_name + finally: + ctx._is_snowpark_connect_compatible_mode = original_compat + + def test_close_exception(): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() From 1472367c899093e1c103f791f426a2c968df8217 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 4 May 2026 14:26:41 -0700 Subject: [PATCH 18/25] fix test --- tests/integ/test_session.py | 49 ++++++++++++++++++++++++++++++++++++ tests/unit/test_session.py | 50 ------------------------------------- 2 files changed, 49 insertions(+), 50 deletions(-) diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 89cf277646..6efe0a8411 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -96,6 +96,55 @@ def test_runtime_config(db_parameters): session.close() +@pytest.mark.parametrize( + "option_value, expected", + [(None, True), (True, True), (False, False)], +) +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="Requires real Snowflake connection", +) +@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") +def test_session_use_sql_base_from_options(db_parameters, option_value, expected): + configs = dict(db_parameters) + if option_value is not None: + configs["_use_sql_base"] = option_value + session = Session.builder.configs(configs).create() + try: + assert session._use_sql_base is expected + assert session.conf.get("_use_sql_base") is None + finally: + session.close() + + +@pytest.mark.parametrize( + "option_value, expected_backend_name", + [(True, "_SqlCatalogBackend"), (False, "_RestCatalogBackend")], +) +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="Requires real Snowflake connection for Catalog REST backend", +) +@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") +def test_catalog_backend_selection_from_use_sql_base_option( + db_parameters, option_value, expected_backend_name +): + import snowflake.snowpark.context as ctx + + original_compat = ctx._is_snowpark_connect_compatible_mode + session = None + try: + ctx._is_snowpark_connect_compatible_mode = True + session = Session.builder.configs( + {**db_parameters, "_use_sql_base": option_value} + ).create() + assert type(session.catalog._backend).__name__ == expected_backend_name + finally: + if session is not None: + session.close() + ctx._is_snowpark_connect_compatible_mode = original_compat + + @pytest.mark.xfail( "config.getoption('local_testing_mode', default=False)", reason="SQL query not supported", diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 9377d1ad92..0349618659 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -117,56 +117,6 @@ def test_used_scoped_temp_object(): assert Session(fake_connection)._use_scoped_temp_objects is False -@pytest.mark.parametrize( - "option_value, expected", - [(None, True), (True, True), (False, False)], -) -def test_session_use_sql_base_from_options(option_value, expected): - fake_connection = mock.create_autospec(ServerConnection) - fake_connection._conn = mock.Mock() - fake_connection._thread_safe_session_enabled = True - fake_connection._get_client_side_session_parameter = ( - lambda x, y: ServerConnection._get_client_side_session_parameter( - fake_connection, x, y - ) - ) - fake_connection._conn._session_parameters = {} - - options = {} if option_value is None else {"_use_sql_base": option_value} - session = Session(fake_connection, options) - assert session._use_sql_base is expected - assert session.conf.get("_use_sql_base") is None - - -@pytest.mark.parametrize( - "option_value, expected_backend_name", - [(True, "_SqlCatalogBackend"), (False, "_RestCatalogBackend")], -) -def test_catalog_backend_selection_from_use_sql_base_option( - option_value, expected_backend_name -): - import snowflake.snowpark.context as ctx - - fake_connection = mock.create_autospec(ServerConnection) - fake_connection._conn = mock.Mock() - fake_connection._thread_safe_session_enabled = True - fake_connection._get_client_side_session_parameter = ( - lambda x, y: ServerConnection._get_client_side_session_parameter( - fake_connection, x, y - ) - ) - fake_connection._conn._session_parameters = {} - fake_connection.get_session_id.return_value = "fake_session_id" - - original_compat = ctx._is_snowpark_connect_compatible_mode - try: - ctx._is_snowpark_connect_compatible_mode = True - session = Session(fake_connection, {"_use_sql_base": option_value}) - assert type(session.catalog._backend).__name__ == expected_backend_name - finally: - ctx._is_snowpark_connect_compatible_mode = original_compat - - def test_close_exception(): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() From 8a9bc660c7f6e47e36626ae245f63c9ecd1a8d4c Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 4 May 2026 15:49:46 -0700 Subject: [PATCH 19/25] rewrite parameter protection --- src/snowflake/snowpark/session.py | 15 ++------------- tests/integ/test_session.py | 29 +++++++++++------------------ 2 files changed, 13 insertions(+), 31 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index ca2b15ee9f..02d0950d40 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -254,7 +254,6 @@ _session_management_lock = RLock() _active_sessions: Set["Session"] = set() -_USE_SQL_BASE_OPTION_KEY = "_use_sql_base" _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING = ( "PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS" ) @@ -464,13 +463,6 @@ def __init__(self) -> None: self._app_name = None self._format_json = None - @staticmethod - def _connection_options( - options: Dict[str, Union[int, str]] - ) -> Dict[str, Union[int, str]]: - # Internal Session-only options must not be forwarded to connector connect(**kwargs). - return {k: v for k, v in options.items() if k != _USE_SQL_BASE_OPTION_KEY} - def _remove_config(self, key: str) -> "Session.SessionBuilder": """Only used in test.""" self._options.pop(key, None) @@ -577,11 +569,8 @@ def _create_internal( # Set paramstyle to qmark by default to be consistent with previous behavior if "paramstyle" not in self._options: self._options["paramstyle"] = "qmark" - connection_options = self._connection_options(self._options) new_session = Session( - ServerConnection({}, conn) - if conn - else ServerConnection(connection_options), + ServerConnection({}, conn) if conn else ServerConnection(self._options), self._options, ) @@ -640,7 +629,7 @@ def __init__( self.version = get_version() self._session_stage = None options = options or {} - self._use_sql_base = options.pop(_USE_SQL_BASE_OPTION_KEY, True) + self._use_sql_base = True if isinstance(conn, MockServerConnection): self._udf_registration = MockUDFRegistration(self) diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 6efe0a8411..625fb69392 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -96,29 +96,24 @@ def test_runtime_config(db_parameters): session.close() -@pytest.mark.parametrize( - "option_value, expected", - [(None, True), (True, True), (False, False)], -) @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="Requires real Snowflake connection", ) @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") -def test_session_use_sql_base_from_options(db_parameters, option_value, expected): - configs = dict(db_parameters) - if option_value is not None: - configs["_use_sql_base"] = option_value - session = Session.builder.configs(configs).create() +def test_session_use_sql_base_default_and_override(db_parameters): + session = Session.builder.configs(db_parameters).create() try: - assert session._use_sql_base is expected + assert session._use_sql_base is True assert session.conf.get("_use_sql_base") is None + session._use_sql_base = False + assert session._use_sql_base is False finally: session.close() @pytest.mark.parametrize( - "option_value, expected_backend_name", + "use_sql_base, expected_backend_name", [(True, "_SqlCatalogBackend"), (False, "_RestCatalogBackend")], ) @pytest.mark.skipif( @@ -127,22 +122,20 @@ def test_session_use_sql_base_from_options(db_parameters, option_value, expected ) @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") def test_catalog_backend_selection_from_use_sql_base_option( - db_parameters, option_value, expected_backend_name + db_parameters, use_sql_base, expected_backend_name ): import snowflake.snowpark.context as ctx original_compat = ctx._is_snowpark_connect_compatible_mode - session = None + session = Session.builder.configs(db_parameters).create() try: ctx._is_snowpark_connect_compatible_mode = True - session = Session.builder.configs( - {**db_parameters, "_use_sql_base": option_value} - ).create() + session._use_sql_base = use_sql_base + session._catalog = None assert type(session.catalog._backend).__name__ == expected_backend_name finally: - if session is not None: - session.close() ctx._is_snowpark_connect_compatible_mode = original_compat + session.close() @pytest.mark.xfail( From 40ace2866f037cb1815b07d266e0c68f9923d676 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 4 May 2026 15:51:19 -0700 Subject: [PATCH 20/25] revert session change --- src/snowflake/snowpark/session.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 02d0950d40..debe552e39 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -628,7 +628,6 @@ def __init__( """ self.version = get_version() self._session_stage = None - options = options or {} self._use_sql_base = True if isinstance(conn, MockServerConnection): @@ -850,7 +849,7 @@ def __init__( _PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION ) ) - self._conf = self.RuntimeConfig(self, options) + self._conf = self.RuntimeConfig(self, options or {}) self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) self._sp_profiler = StoredProcedureProfiler(session=self) From e8cbe794dbc3074f594d0d75eed52ed383b58621 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 4 May 2026 16:05:18 -0700 Subject: [PATCH 21/25] rename to _use_sql_base_catalog --- src/snowflake/snowpark/catalog.py | 6 ++++-- src/snowflake/snowpark/session.py | 6 ++++-- tests/integ/test_catalog_sql_mode.py | 12 ++++++------ tests/integ/test_session.py | 18 +++++++++--------- 4 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index e92e546a99..9dcce2a9e1 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -861,10 +861,12 @@ class Catalog: views, functions, etc. """ - def __init__(self, session: "Session", *, _use_sql_base: bool = True) -> None: + def __init__( + self, session: "Session", *, _use_sql_base_catalog: bool = True + ) -> None: self._session = session self._python_regex_udf = None - if context._is_snowpark_connect_compatible_mode and _use_sql_base: + if context._is_snowpark_connect_compatible_mode and _use_sql_base_catalog: self._backend: _CatalogBackend = _SqlCatalogBackend(self) else: self._backend = _RestCatalogBackend(self) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index debe552e39..4b355a8cdd 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -628,7 +628,7 @@ def __init__( """ self.version = get_version() self._session_stage = None - self._use_sql_base = True + self._use_sql_base_catalog = True if isinstance(conn, MockServerConnection): self._udf_registration = MockUDFRegistration(self) @@ -962,7 +962,9 @@ def catalog(self): external_feature_name="Session.catalog", raise_error=NotImplementedError, ) - self._catalog = Catalog(self, _use_sql_base=self._use_sql_base) + self._catalog = Catalog( + self, _use_sql_base_catalog=self._use_sql_base_catalog + ) return self._catalog def close(self) -> None: diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py index 9a6ed7e10d..957b94995a 100644 --- a/tests/integ/test_catalog_sql_mode.py +++ b/tests/integ/test_catalog_sql_mode.py @@ -87,28 +87,28 @@ def test_get_database_missing_raises_snowpark_not_found_sql_mode(session): def test_compat_mode_with_sql_base_disabled_uses_rest_backend(session): - original_use_sql_base = session._use_sql_base + original_use_sql_base_catalog = session._use_sql_base_catalog try: - session._use_sql_base = False + session._use_sql_base_catalog = False session._catalog = None catalog: Catalog = session.catalog assert type(catalog._backend).__name__ == "_RestCatalogBackend" with pytest.raises(CoreNotFoundError): catalog.get_database("NONEXISTENT_DB_XYZ_12345") finally: - session._use_sql_base = original_use_sql_base + session._use_sql_base_catalog = original_use_sql_base_catalog session._catalog = None def test_compat_mode_with_sql_base_enabled_uses_sql_backend(session): - original_use_sql_base = session._use_sql_base + original_use_sql_base_catalog = session._use_sql_base_catalog try: - session._use_sql_base = True + session._use_sql_base_catalog = True session._catalog = None catalog: Catalog = session.catalog assert type(catalog._backend).__name__ == "_SqlCatalogBackend" finally: - session._use_sql_base = original_use_sql_base + session._use_sql_base_catalog = original_use_sql_base_catalog session._catalog = None diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 625fb69392..2e035b2ad5 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -101,19 +101,19 @@ def test_runtime_config(db_parameters): reason="Requires real Snowflake connection", ) @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") -def test_session_use_sql_base_default_and_override(db_parameters): +def test_session_use_sql_base_catalog_default_and_override(db_parameters): session = Session.builder.configs(db_parameters).create() try: - assert session._use_sql_base is True - assert session.conf.get("_use_sql_base") is None - session._use_sql_base = False - assert session._use_sql_base is False + assert session._use_sql_base_catalog is True + assert session.conf.get("_use_sql_base_catalog") is None + session._use_sql_base_catalog = False + assert session._use_sql_base_catalog is False finally: session.close() @pytest.mark.parametrize( - "use_sql_base, expected_backend_name", + "use_sql_base_catalog, expected_backend_name", [(True, "_SqlCatalogBackend"), (False, "_RestCatalogBackend")], ) @pytest.mark.skipif( @@ -121,8 +121,8 @@ def test_session_use_sql_base_default_and_override(db_parameters): reason="Requires real Snowflake connection for Catalog REST backend", ) @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") -def test_catalog_backend_selection_from_use_sql_base_option( - db_parameters, use_sql_base, expected_backend_name +def test_catalog_backend_selection_from_use_sql_base_catalog_option( + db_parameters, use_sql_base_catalog, expected_backend_name ): import snowflake.snowpark.context as ctx @@ -130,7 +130,7 @@ def test_catalog_backend_selection_from_use_sql_base_option( session = Session.builder.configs(db_parameters).create() try: ctx._is_snowpark_connect_compatible_mode = True - session._use_sql_base = use_sql_base + session._use_sql_base_catalog = use_sql_base_catalog session._catalog = None assert type(session.catalog._backend).__name__ == expected_backend_name finally: From ff8a005dc45e3eb8c0330a2f0cd1158fda45f65a Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 5 May 2026 10:49:25 -0700 Subject: [PATCH 22/25] add test --- src/snowflake/snowpark/catalog.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 9dcce2a9e1..6827243c12 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -339,10 +339,10 @@ def get_table( like_arg = unquote_if_quoted(table_name) tables = c.list_tables(database=db_name, schema=schema_name, like=like_arg) views: List[View] = [] - if not tables: - views = c.list_views(database=db_name, schema=schema_name, like=like_arg) if tables: return tables[0] + if not tables: + views = c.list_views(database=db_name, schema=schema_name, like=like_arg) if views: return views[0] raise NotFoundError( @@ -867,7 +867,7 @@ def __init__( self._session = session self._python_regex_udf = None if context._is_snowpark_connect_compatible_mode and _use_sql_base_catalog: - self._backend: _CatalogBackend = _SqlCatalogBackend(self) + self._backend = _SqlCatalogBackend(self) else: self._backend = _RestCatalogBackend(self) From 743620bc70ecd88c76fab709c34ab7af9678d5d4 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 6 May 2026 14:34:45 -0700 Subject: [PATCH 23/25] address comments --- src/snowflake/snowpark/catalog.py | 9 +++++---- src/snowflake/snowpark/context.py | 5 +++++ src/snowflake/snowpark/session.py | 5 +---- tests/integ/test_catalog_sql_mode.py | 12 ++++-------- tests/integ/test_session.py | 26 ++++---------------------- 5 files changed, 19 insertions(+), 38 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 6827243c12..36bab2331a 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -861,12 +861,13 @@ class Catalog: views, functions, etc. """ - def __init__( - self, session: "Session", *, _use_sql_base_catalog: bool = True - ) -> None: + def __init__(self, session: "Session") -> None: self._session = session self._python_regex_udf = None - if context._is_snowpark_connect_compatible_mode and _use_sql_base_catalog: + if ( + context._is_snowpark_connect_compatible_mode + and context._use_sql_base_catalog + ): self._backend = _SqlCatalogBackend(self) else: self._backend = _RestCatalogBackend(self) diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index a111839050..c402871889 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -32,6 +32,11 @@ # This is an internal-only global flag, used to determine whether the api code which will be executed is compatible with snowflake.snowpark_connect _is_snowpark_connect_compatible_mode = False + +# Default backend selector for the Snowpark Catalog when running in +# Snowpark-Connect / SCOS compatible mode. True -> SQL-based backend, +# False -> legacy snowflake.core REST backend. Read live by Catalog.__init__. +_use_sql_base_catalog = True # Internal-only global flag that enables improved SQL simplifier query flattening # for filter, sort, select, and distinct. When True (default), the branch # improvements are active regardless of _is_snowpark_connect_compatible_mode. diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 4b355a8cdd..5803cc8329 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -628,7 +628,6 @@ def __init__( """ self.version = get_version() self._session_stage = None - self._use_sql_base_catalog = True if isinstance(conn, MockServerConnection): self._udf_registration = MockUDFRegistration(self) @@ -962,9 +961,7 @@ def catalog(self): external_feature_name="Session.catalog", raise_error=NotImplementedError, ) - self._catalog = Catalog( - self, _use_sql_base_catalog=self._use_sql_base_catalog - ) + self._catalog = Catalog(self) return self._catalog def close(self) -> None: diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py index 957b94995a..6eb0978c5c 100644 --- a/tests/integ/test_catalog_sql_mode.py +++ b/tests/integ/test_catalog_sql_mode.py @@ -86,29 +86,25 @@ def test_get_database_missing_raises_snowpark_not_found_sql_mode(session): catalog.get_database("NONEXISTENT_DB_XYZ_12345") -def test_compat_mode_with_sql_base_disabled_uses_rest_backend(session): - original_use_sql_base_catalog = session._use_sql_base_catalog +def test_compat_mode_with_sql_base_disabled_uses_rest_backend(session, monkeypatch): try: - session._use_sql_base_catalog = False + monkeypatch.setattr(context, "_use_sql_base_catalog", False) session._catalog = None catalog: Catalog = session.catalog assert type(catalog._backend).__name__ == "_RestCatalogBackend" with pytest.raises(CoreNotFoundError): catalog.get_database("NONEXISTENT_DB_XYZ_12345") finally: - session._use_sql_base_catalog = original_use_sql_base_catalog session._catalog = None -def test_compat_mode_with_sql_base_enabled_uses_sql_backend(session): - original_use_sql_base_catalog = session._use_sql_base_catalog +def test_catalog_with_sql_base_enabled_uses_sql_backend(session, monkeypatch): try: - session._use_sql_base_catalog = True + monkeypatch.setattr(context, "_use_sql_base_catalog", True) session._catalog = None catalog: Catalog = session.catalog assert type(catalog._backend).__name__ == "_SqlCatalogBackend" finally: - session._use_sql_base_catalog = original_use_sql_base_catalog session._catalog = None diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 2e035b2ad5..f2f16f4ad5 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -96,22 +96,6 @@ def test_runtime_config(db_parameters): session.close() -@pytest.mark.skipif( - "config.getoption('local_testing_mode', default=False)", - reason="Requires real Snowflake connection", -) -@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") -def test_session_use_sql_base_catalog_default_and_override(db_parameters): - session = Session.builder.configs(db_parameters).create() - try: - assert session._use_sql_base_catalog is True - assert session.conf.get("_use_sql_base_catalog") is None - session._use_sql_base_catalog = False - assert session._use_sql_base_catalog is False - finally: - session.close() - - @pytest.mark.parametrize( "use_sql_base_catalog, expected_backend_name", [(True, "_SqlCatalogBackend"), (False, "_RestCatalogBackend")], @@ -121,20 +105,18 @@ def test_session_use_sql_base_catalog_default_and_override(db_parameters): reason="Requires real Snowflake connection for Catalog REST backend", ) @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") -def test_catalog_backend_selection_from_use_sql_base_catalog_option( - db_parameters, use_sql_base_catalog, expected_backend_name +def test_catalog_backend_selection_from_context_use_sql_base_catalog( + db_parameters, monkeypatch, use_sql_base_catalog, expected_backend_name ): import snowflake.snowpark.context as ctx - original_compat = ctx._is_snowpark_connect_compatible_mode session = Session.builder.configs(db_parameters).create() try: - ctx._is_snowpark_connect_compatible_mode = True - session._use_sql_base_catalog = use_sql_base_catalog + monkeypatch.setattr(ctx, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(ctx, "_use_sql_base_catalog", use_sql_base_catalog) session._catalog = None assert type(session.catalog._backend).__name__ == expected_backend_name finally: - ctx._is_snowpark_connect_compatible_mode = original_compat session.close() From 1e3af122d5589da35061b4dcb8bb4b6c7097ff5e Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 6 May 2026 14:46:45 -0700 Subject: [PATCH 24/25] address comments --- src/snowflake/snowpark/catalog.py | 34 ++++++------------------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 36bab2331a..710f050086 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -759,21 +759,10 @@ def procedure_exists( schema: Optional[Union[str, Schema]] = None, ) -> bool: c = self._catalog + db_name = c._parse_database(database, procedure) + schema_name = c._parse_schema(schema, procedure) + procedure_id = c._parse_function_or_procedure(procedure, arg_types) try: - if isinstance(procedure, Procedure): - if arg_types is not None or database is not None or schema is not None: - raise ArgumentError( - "When provided procedure is a Procedure class no other arguments can be provided" - ) - database = procedure.database_name - schema = procedure.schema_name - arg_types = [ - type_string_to_type_object(a.datatype) for a in procedure.arguments - ] - procedure = procedure.name - db_name = c._parse_database(database, procedure) - schema_name = c._parse_schema(schema, procedure) - procedure_id = c._parse_function_or_procedure(procedure, arg_types) self._root.databases[db_name].schemas[schema_name].procedures[ procedure_id ].fetch() @@ -790,21 +779,10 @@ def user_defined_function_exists( schema: Optional[Union[str, Schema]] = None, ) -> bool: c = self._catalog + db_name = c._parse_database(database, udf) + schema_name = c._parse_schema(schema, udf) + function_id = c._parse_function_or_procedure(udf, arg_types) try: - if isinstance(udf, UserDefinedFunction): - if arg_types is not None or database is not None or schema is not None: - raise ArgumentError( - "When provided udf is a UserDefinedFunction class no other arguments can be provided" - ) - database = udf.database_name - schema = udf.schema_name - arg_types = [ - type_string_to_type_object(a.datatype) for a in udf.arguments - ] - udf = udf.name - db_name = c._parse_database(database, udf) - schema_name = c._parse_schema(schema, udf) - function_id = c._parse_function_or_procedure(udf, arg_types) self._root.databases[db_name].schemas[schema_name].user_defined_functions[ function_id ].fetch() From 2ecc8054528e0b21aa7a9b7ee64d314b73577f97 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 6 May 2026 15:03:19 -0700 Subject: [PATCH 25/25] address comments --- src/snowflake/snowpark/catalog.py | 26 +++++++++++++------------- src/snowflake/snowpark/exceptions.py | 4 ++-- tests/integ/test_catalog_sql_mode.py | 4 ++-- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 710f050086..d36f0342cd 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -14,7 +14,7 @@ from snowflake.snowpark import context from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted -from snowflake.snowpark.exceptions import SnowparkSQLException, NotFoundError +from snowflake.snowpark.exceptions import SnowparkSQLException, _NotFoundError try: from snowflake.core import Root # type: ignore @@ -307,7 +307,7 @@ def get_database(self, database: str) -> Database: try: return self.list_databases(like=unquote_if_quoted(database))[0] except IndexError: - raise NotFoundError(f"Database with name {database} could not be found") + raise _NotFoundError(f"Database with name {database} could not be found") def get_schema( self, schema: str, *, database: Optional[Union[str, Database]] = None @@ -322,7 +322,7 @@ def get_schema( IndexError, SnowparkSQLException, ): - raise NotFoundError( + raise _NotFoundError( f"Schema with name {schema} could not be found in database '{db_name}'" ) @@ -345,7 +345,7 @@ def get_table( views = c.list_views(database=db_name, schema=schema_name, like=like_arg) if views: return views[0] - raise NotFoundError( + raise _NotFoundError( f"Table with name {table_name} could not be found in schema '{db_name}.{schema_name}'" ) @@ -366,7 +366,7 @@ def get_view( like=unquote_if_quoted(view_name), )[0] except IndexError: - raise NotFoundError( + raise _NotFoundError( f"View with name {view_name} could not be found in schema '{db_name}.{schema_name}'" ) @@ -392,7 +392,7 @@ def get_procedure( IndexError, SnowparkSQLException, ): - raise NotFoundError( + raise _NotFoundError( f"Procedure with name {procedure_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" ) @@ -418,7 +418,7 @@ def get_user_defined_function( IndexError, SnowparkSQLException, ): - raise NotFoundError( + raise _NotFoundError( f"Function with name {udf_name} and arguments {arg_types} could not be found in schema '{db_name}.{schema_name}'" ) @@ -428,7 +428,7 @@ def database_exists(self, database: Union[str, Database]) -> bool: try: self.get_database(db_name) return True - except NotFoundError: + except _NotFoundError: return False def schema_exists( @@ -443,7 +443,7 @@ def schema_exists( try: self.get_schema(schema=schema_name, database=db_name) return True - except NotFoundError: + except _NotFoundError: return False def table_exists( @@ -460,7 +460,7 @@ def table_exists( try: self.get_table(table_name=table_name, database=db_name, schema=schema_name) return True - except NotFoundError: + except _NotFoundError: return False def view_exists( @@ -477,7 +477,7 @@ def view_exists( try: self.get_view(view_name=view_name, database=db_name, schema=schema_name) return True - except NotFoundError: + except _NotFoundError: return False def procedure_exists( @@ -507,7 +507,7 @@ def procedure_exists( schema=schema, ) return True - except NotFoundError: + except _NotFoundError: return False def user_defined_function_exists( @@ -537,7 +537,7 @@ def user_defined_function_exists( schema=schema, ) return True - except NotFoundError: + except _NotFoundError: return False def drop_database(self, database: Union[str, Database]) -> None: diff --git a/src/snowflake/snowpark/exceptions.py b/src/snowflake/snowpark/exceptions.py index d31fe178a6..5ae5ace724 100644 --- a/src/snowflake/snowpark/exceptions.py +++ b/src/snowflake/snowpark/exceptions.py @@ -285,7 +285,7 @@ class SnowparkInvalidObjectNameException(SnowparkGeneralException): pass -class NotFoundError(SnowparkClientException): - """Raised when we encounter an object is not found.""" +class _NotFoundError(SnowparkClientException): + """Internal exception raised when a Snowpark catalog object is not found.""" pass diff --git a/tests/integ/test_catalog_sql_mode.py b/tests/integ/test_catalog_sql_mode.py index 6eb0978c5c..9299bdd6ab 100644 --- a/tests/integ/test_catalog_sql_mode.py +++ b/tests/integ/test_catalog_sql_mode.py @@ -14,7 +14,7 @@ from snowflake.snowpark import context from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted from snowflake.snowpark.catalog import Catalog -from snowflake.snowpark.exceptions import NotFoundError +from snowflake.snowpark.exceptions import _NotFoundError from snowflake.snowpark.types import IntegerType from tests.integ.test_catalog import ( CATALOG_TEMP_OBJECT_PREFIX, @@ -82,7 +82,7 @@ def test_get_db_schema_sql_mode(session): def test_get_database_missing_raises_snowpark_not_found_sql_mode(session): catalog: Catalog = session.catalog - with pytest.raises(NotFoundError, match="could not be found"): + with pytest.raises(_NotFoundError, match="could not be found"): catalog.get_database("NONEXISTENT_DB_XYZ_12345")