diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml index e2eed829e..1f909f6f7 100644 --- a/integrations/langchain/pyproject.toml +++ b/integrations/langchain/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "databricks-langchain" -version = "0.10.0.dev0" +version = "0.10.0" description = "Support for Databricks AI support in LangChain" authors = [ { name="Databricks", email="agent-feedback@databricks.com" }, diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index 354982a5d..fbe5176c9 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -23,6 +23,7 @@ def __init__( self, *, instance_name: str, + schema: str = "public", workspace_client: WorkspaceClient | None = None, **pool_kwargs: object, ) -> None: @@ -35,6 +36,7 @@ def __init__( self._lakebase: LakebasePool = LakebasePool( instance_name=instance_name, + schema=schema, workspace_client=workspace_client, **dict(pool_kwargs), ) diff --git a/integrations/langchain/src/databricks_langchain/store.py b/integrations/langchain/src/databricks_langchain/store.py index 1a481f3b8..464d7243e 100644 --- a/integrations/langchain/src/databricks_langchain/store.py +++ b/integrations/langchain/src/databricks_langchain/store.py @@ -32,6 +32,7 @@ def __init__( self, *, instance_name: str, + schema: str = "public", workspace_client: Optional[WorkspaceClient] = None, **pool_kwargs: Any, ) -> None: @@ -43,6 +44,7 @@ def __init__( self._lakebase: LakebasePool = LakebasePool( instance_name=instance_name, + schema=schema, workspace_client=workspace_client, **pool_kwargs, ) diff --git a/integrations/langchain/tests/unit_tests/test_checkpoint.py b/integrations/langchain/tests/unit_tests/test_checkpoint.py index 65f410a2d..69ce046e7 100644 --- a/integrations/langchain/tests/unit_tests/test_checkpoint.py +++ b/integrations/langchain/tests/unit_tests/test_checkpoint.py @@ -17,15 +17,18 @@ class TestConnectionPool: def __init__(self, connection_value="conn"): self.connection_value = connection_value self.conninfo = "" + self.conn_kwargs = {} def __call__( self, *, conninfo, connection_class=None, - **kwargs, + kwargs=None, + **pool_kwargs, ): self.conninfo = conninfo + self.conn_kwargs = kwargs or {} return self def connection(self): @@ -66,3 +69,46 @@ def test_checkpoint_saver_configures_lakebase(monkeypatch): with saver._lakebase.connection() as conn: assert conn == "lake-conn" + + +def test_checkpoint_saver_uses_default_public_schema(monkeypatch): + """Verify that the default schema 'public' is used when no schema is specified.""" + test_pool = TestConnectionPool(connection_value="lake-conn") + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = MagicMock() + workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token") + workspace.database.get_database_instance.return_value.read_write_dns = "db-host" + workspace.current_service_principal.me.side_effect = RuntimeError("no sp") + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + saver = CheckpointSaver( + instance_name="lakebase-instance", + workspace_client=workspace, + ) + + # Check that the options contain the default public schema + options = test_pool.conn_kwargs.get("options", "") + assert options == "-c search_path=public,public" + + +def test_checkpoint_saver_uses_custom_schema(monkeypatch): + """Verify that a custom schema is passed through to LakebasePool.""" + test_pool = TestConnectionPool(connection_value="lake-conn") + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = MagicMock() + workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token") + workspace.database.get_database_instance.return_value.read_write_dns = "db-host" + workspace.current_service_principal.me.side_effect = RuntimeError("no sp") + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + saver = CheckpointSaver( + instance_name="lakebase-instance", + schema="my_custom_schema", + workspace_client=workspace, + ) + + # Check that the options contain the custom schema + options = test_pool.conn_kwargs.get("options", "") + assert options == "-c search_path=my_custom_schema,public" diff --git a/integrations/langchain/tests/unit_tests/test_store.py b/integrations/langchain/tests/unit_tests/test_store.py index 8b9c009e4..ca5341f15 100644 --- a/integrations/langchain/tests/unit_tests/test_store.py +++ b/integrations/langchain/tests/unit_tests/test_store.py @@ -17,15 +17,18 @@ class TestConnectionPool: def __init__(self, connection_value="conn"): self.connection_value = connection_value self.conninfo = "" + self.conn_kwargs = {} def __call__( self, *, conninfo, connection_class=None, - **kwargs, + kwargs=None, + **pool_kwargs, ): self.conninfo = conninfo + self.conn_kwargs = kwargs or {} return self def connection(self): @@ -70,3 +73,48 @@ def test_databricks_store_configures_lakebase(monkeypatch): with store._lakebase.connection() as conn: assert conn == mock_conn + + +def test_databricks_store_uses_default_public_schema(monkeypatch): + """Verify that the default schema 'public' is used when no schema is specified.""" + mock_conn = MagicMock() + test_pool = TestConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = MagicMock() + workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token") + workspace.database.get_database_instance.return_value.read_write_dns = "db-host" + workspace.current_service_principal.me.side_effect = RuntimeError("no sp") + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + DatabricksStore( + instance_name="lakebase-instance", + workspace_client=workspace, + ) + + # Check that the options contain the default public schema + options = test_pool.conn_kwargs.get("options", "") + assert options == "-c search_path=public,public" + + +def test_databricks_store_uses_custom_schema(monkeypatch): + """Verify that a custom schema is passed through to LakebasePool.""" + mock_conn = MagicMock() + test_pool = TestConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = MagicMock() + workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token") + workspace.database.get_database_instance.return_value.read_write_dns = "db-host" + workspace.current_service_principal.me.side_effect = RuntimeError("no sp") + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + DatabricksStore( + instance_name="lakebase-instance", + schema="my_custom_schema", + workspace_client=workspace, + ) + + # Check that the options contain the custom schema + options = test_pool.conn_kwargs.get("options", "") + assert options == "-c search_path=my_custom_schema,public" diff --git a/pyproject.toml b/pyproject.toml index ba3e1faec..521439359 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "databricks-ai-bridge" -version = "0.10.0.dev0" +version = "0.10.0" description = "Official Python library for Databricks AI support" authors = [ { name="Databricks", email="agent-feedback@databricks.com" }, diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index 11a608866..0969d360b 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -57,6 +57,7 @@ def __init__( self, *, instance_name: str, + schema: str = "public", workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, **pool_kwargs: object, @@ -110,6 +111,7 @@ def connect(cls, conninfo: str = "", **kwargs): default_kwargs: dict[str, object] = { "autocommit": True, "row_factory": dict_row, + "options": f"-c search_path={schema},public", "keepalives": 1, "keepalives_idle": 30, "keepalives_interval": 10, diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index 0bed9490a..3fd50e88f 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -42,10 +42,12 @@ def __init__( *, conninfo, connection_class, - **kwargs, + kwargs=None, + **pool_kwargs, ): self.conninfo = conninfo self.connection_class = connection_class + self.conn_kwargs = kwargs or {} return TestConnectionPool @@ -140,6 +142,41 @@ def test_lakebase_pool_falls_back_to_user_when_service_principal_missing(monkeyp assert "user=test@databricks.com" in pool.pool.conninfo +def test_lakebase_pool_uses_default_public_schema(monkeypatch): + """Verify that the default schema 'public' is used when no schema is specified.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_workspace() + + pool = LakebasePool( + instance_name="lake-instance", + workspace_client=workspace, + ) + + # Check that the options contain the default public schema + options = pool.pool.conn_kwargs.get("options", "") + assert options == "-c search_path=public,public" + + +def test_lakebase_pool_uses_custom_schema(monkeypatch): + """Verify that a custom schema is used when specified.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_workspace() + + pool = LakebasePool( + instance_name="lake-instance", + schema="my_custom_schema", + workspace_client=workspace, + ) + + # Check that the options contain the custom schema + options = pool.pool.conn_kwargs.get("options", "") + assert options == "-c search_path=my_custom_schema,public" + + def test_lakebase_pool_refreshes_token_after_cache_expiry(monkeypatch): """Verify that a new token is minted when the cache duration expires.""" TestConnectionPool = _make_connection_pool_class()