Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/langchain/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "databricks-langchain"
version = "0.10.0.dev0"
version = "0.10.0"
description = "Support for Databricks AI support in LangChain"
authors = [
{ name="Databricks", email="agent-feedback@databricks.com" },
Expand Down
2 changes: 2 additions & 0 deletions integrations/langchain/src/databricks_langchain/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
self,
*,
instance_name: str,
schema: str = "public",
workspace_client: WorkspaceClient | None = None,
**pool_kwargs: object,
) -> None:
Expand All @@ -35,6 +36,7 @@ def __init__(

self._lakebase: LakebasePool = LakebasePool(
instance_name=instance_name,
schema=schema,
workspace_client=workspace_client,
**dict(pool_kwargs),
)
Expand Down
2 changes: 2 additions & 0 deletions integrations/langchain/src/databricks_langchain/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
self,
*,
instance_name: str,
schema: str = "public",
workspace_client: Optional[WorkspaceClient] = None,
**pool_kwargs: Any,
) -> None:
Expand All @@ -43,6 +44,7 @@ def __init__(

self._lakebase: LakebasePool = LakebasePool(
instance_name=instance_name,
schema=schema,
workspace_client=workspace_client,
**pool_kwargs,
)
Expand Down
48 changes: 47 additions & 1 deletion integrations/langchain/tests/unit_tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@ class TestConnectionPool:
def __init__(self, connection_value="conn"):
self.connection_value = connection_value
self.conninfo = ""
self.conn_kwargs = {}

def __call__(
self,
*,
conninfo,
connection_class=None,
**kwargs,
kwargs=None,
**pool_kwargs,
):
self.conninfo = conninfo
self.conn_kwargs = kwargs or {}
return self

def connection(self):
Expand Down Expand Up @@ -66,3 +69,46 @@ def test_checkpoint_saver_configures_lakebase(monkeypatch):

with saver._lakebase.connection() as conn:
assert conn == "lake-conn"


def test_checkpoint_saver_uses_default_public_schema(monkeypatch):
"""Verify that the default schema 'public' is used when no schema is specified."""
test_pool = TestConnectionPool(connection_value="lake-conn")
monkeypatch.setattr(lakebase, "ConnectionPool", test_pool)

workspace = MagicMock()
workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token")
workspace.database.get_database_instance.return_value.read_write_dns = "db-host"
workspace.current_service_principal.me.side_effect = RuntimeError("no sp")
workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com")

saver = CheckpointSaver(
instance_name="lakebase-instance",
workspace_client=workspace,
)

# Check that the options contain the default public schema
options = test_pool.conn_kwargs.get("options", "")
assert options == "-c search_path=public,public"


def test_checkpoint_saver_uses_custom_schema(monkeypatch):
"""Verify that a custom schema is passed through to LakebasePool."""
test_pool = TestConnectionPool(connection_value="lake-conn")
monkeypatch.setattr(lakebase, "ConnectionPool", test_pool)

workspace = MagicMock()
workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token")
workspace.database.get_database_instance.return_value.read_write_dns = "db-host"
workspace.current_service_principal.me.side_effect = RuntimeError("no sp")
workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com")

saver = CheckpointSaver(
instance_name="lakebase-instance",
schema="my_custom_schema",
workspace_client=workspace,
)

# Check that the options contain the custom schema
options = test_pool.conn_kwargs.get("options", "")
assert options == "-c search_path=my_custom_schema,public"
50 changes: 49 additions & 1 deletion integrations/langchain/tests/unit_tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@ class TestConnectionPool:
def __init__(self, connection_value="conn"):
self.connection_value = connection_value
self.conninfo = ""
self.conn_kwargs = {}

def __call__(
self,
*,
conninfo,
connection_class=None,
**kwargs,
kwargs=None,
**pool_kwargs,
):
self.conninfo = conninfo
self.conn_kwargs = kwargs or {}
return self

def connection(self):
Expand Down Expand Up @@ -70,3 +73,48 @@ def test_databricks_store_configures_lakebase(monkeypatch):

with store._lakebase.connection() as conn:
assert conn == mock_conn


def test_databricks_store_uses_default_public_schema(monkeypatch):
"""Verify that the default schema 'public' is used when no schema is specified."""
mock_conn = MagicMock()
test_pool = TestConnectionPool(connection_value=mock_conn)
monkeypatch.setattr(lakebase, "ConnectionPool", test_pool)

workspace = MagicMock()
workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token")
workspace.database.get_database_instance.return_value.read_write_dns = "db-host"
workspace.current_service_principal.me.side_effect = RuntimeError("no sp")
workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com")

DatabricksStore(
instance_name="lakebase-instance",
workspace_client=workspace,
)

# Check that the options contain the default public schema
options = test_pool.conn_kwargs.get("options", "")
assert options == "-c search_path=public,public"


def test_databricks_store_uses_custom_schema(monkeypatch):
"""Verify that a custom schema is passed through to LakebasePool."""
mock_conn = MagicMock()
test_pool = TestConnectionPool(connection_value=mock_conn)
monkeypatch.setattr(lakebase, "ConnectionPool", test_pool)

workspace = MagicMock()
workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token")
workspace.database.get_database_instance.return_value.read_write_dns = "db-host"
workspace.current_service_principal.me.side_effect = RuntimeError("no sp")
workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com")

DatabricksStore(
instance_name="lakebase-instance",
schema="my_custom_schema",
workspace_client=workspace,
)

# Check that the options contain the custom schema
options = test_pool.conn_kwargs.get("options", "")
assert options == "-c search_path=my_custom_schema,public"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "databricks-ai-bridge"
version = "0.10.0.dev0"
version = "0.10.0"
description = "Official Python library for Databricks AI support"
authors = [
{ name="Databricks", email="agent-feedback@databricks.com" },
Expand Down
2 changes: 2 additions & 0 deletions src/databricks_ai_bridge/lakebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
self,
*,
instance_name: str,
schema: str = "public",
workspace_client: WorkspaceClient | None = None,
token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS,
**pool_kwargs: object,
Expand Down Expand Up @@ -110,6 +111,7 @@ def connect(cls, conninfo: str = "", **kwargs):
default_kwargs: dict[str, object] = {
"autocommit": True,
"row_factory": dict_row,
"options": f"-c search_path={schema},public",
"keepalives": 1,
"keepalives_idle": 30,
"keepalives_interval": 10,
Expand Down
39 changes: 38 additions & 1 deletion tests/databricks_ai_bridge/test_lakebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def __init__(
*,
conninfo,
connection_class,
**kwargs,
kwargs=None,
**pool_kwargs,
):
self.conninfo = conninfo
self.connection_class = connection_class
self.conn_kwargs = kwargs or {}

return TestConnectionPool

Expand Down Expand Up @@ -140,6 +142,41 @@ def test_lakebase_pool_falls_back_to_user_when_service_principal_missing(monkeyp
assert "user=test@databricks.com" in pool.pool.conninfo


def test_lakebase_pool_uses_default_public_schema(monkeypatch):
"""Verify that the default schema 'public' is used when no schema is specified."""
TestConnectionPool = _make_connection_pool_class()
monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool)

workspace = _make_workspace()

pool = LakebasePool(
instance_name="lake-instance",
workspace_client=workspace,
)

# Check that the options contain the default public schema
options = pool.pool.conn_kwargs.get("options", "")
assert options == "-c search_path=public,public"


def test_lakebase_pool_uses_custom_schema(monkeypatch):
"""Verify that a custom schema is used when specified."""
TestConnectionPool = _make_connection_pool_class()
monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool)

workspace = _make_workspace()

pool = LakebasePool(
instance_name="lake-instance",
schema="my_custom_schema",
workspace_client=workspace,
)

# Check that the options contain the custom schema
options = pool.pool.conn_kwargs.get("options", "")
assert options == "-c search_path=my_custom_schema,public"


def test_lakebase_pool_refreshes_token_after_cache_expiry(monkeypatch):
"""Verify that a new token is minted when the cache duration expires."""
TestConnectionPool = _make_connection_pool_class()
Expand Down