Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/databricks_ai_bridge/long_running/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,33 @@ class Response(Base):
messages = relationship("Message", back_populates="response", cascade="all, delete-orphan")


class ConversationAlias(Base):
    """Forwarding record from a client-visible ``conversation_id`` to the
    rotated id currently in use for it.

    The primary key is the stable id the client keeps sending
    (``base_conversation_id``); ``current_conversation_id`` is the id the
    bridge substitutes for it. Rotated ids are derived from the base id
    (e.g. ``{base}::r-<response-id-prefix>-<attempt>``) rather than from the
    previous rotated id, so repeated crash-resumes do not make ids grow.

    There is one row per base id: writing a new rotation overwrites the
    existing row instead of adding another. A base id with no row has never
    been rotated.
    """

    __tablename__ = "conversation_aliases"
    __table_args__ = {"schema": AGENT_DB_SCHEMA}

    # Stable id supplied by the client; one row per logical conversation.
    base_conversation_id: Mapped[str] = mapped_column(Text, primary_key=True)
    # Id that requests carrying the base id should be dispatched under.
    current_conversation_id: Mapped[str] = mapped_column(Text, nullable=False)
    # Defaults to the database clock on insert and is refreshed on ORM updates.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )


class Message(Base):
"""Stream events and output items for a response.

Expand Down
48 changes: 47 additions & 1 deletion src/databricks_ai_bridge/long_running/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from sqlalchemy.sql import bindparam, text

from databricks_ai_bridge.long_running.db import session_scope
from databricks_ai_bridge.long_running.models import AGENT_DB_SCHEMA, Message, Response
from databricks_ai_bridge.long_running.models import (
AGENT_DB_SCHEMA,
ConversationAlias,
Message,
Response,
)


async def create_response(
Expand Down Expand Up @@ -250,3 +255,44 @@ async def get_response(response_id: str) -> ResponseInfo | None:
json.loads(row.original_request) if row.original_request else None,
)
return None


async def resolve_conversation_alias(base_conversation_id: str) -> str:
    """Look up the rotated id recorded for ``base_conversation_id``.

    This runs for incoming requests, so it is a single primary-key lookup.
    A conversation that has never been rotated has no alias row; in that
    case the id passed in is returned as-is.

    Args:
        base_conversation_id: The conversation id as sent by the client.

    Returns:
        The stored ``current_conversation_id`` when an alias row exists,
        otherwise ``base_conversation_id`` unchanged.
    """
    alias_lookup = select(ConversationAlias.current_conversation_id).where(
        ConversationAlias.base_conversation_id == base_conversation_id
    )
    async with session_scope() as session:
        rotated_id = (await session.execute(alias_lookup)).scalar_one_or_none()
        if rotated_id is None:
            # No row yet: first contact for this conversation.
            return base_conversation_id
        return rotated_id


async def upsert_conversation_alias(
    base_conversation_id: str, current_conversation_id: str
) -> None:
    """Store ``base -> current``, replacing any mapping already recorded.

    Uses ``INSERT ... ON CONFLICT DO UPDATE`` keyed on the base id, so a
    conversation keeps exactly one alias row however often it is rotated.

    Args:
        base_conversation_id: The stable, client-visible conversation id.
        current_conversation_id: The rotated id that should now be used.
    """
    bind_values = {"base": base_conversation_id, "current": current_conversation_id}
    # The SQL text is kept flush-left so the statement string is unchanged.
    upsert_sql = text(
        f"""
INSERT INTO {AGENT_DB_SCHEMA}.conversation_aliases
(base_conversation_id, current_conversation_id, updated_at)
VALUES (:base, :current, now())
ON CONFLICT (base_conversation_id) DO UPDATE
SET current_conversation_id = EXCLUDED.current_conversation_id,
updated_at = EXCLUDED.updated_at
"""
    )
    async with session_scope() as session:
        await session.execute(upsert_sql, bind_values)
        await session.commit()
92 changes: 81 additions & 11 deletions src/databricks_ai_bridge/long_running/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@
get_messages,
get_response,
heartbeat_response,
resolve_conversation_alias,
update_response_status,
update_response_trace_id,
upsert_conversation_alias,
)
from databricks_ai_bridge.long_running.settings import LongRunningSettings
from databricks_ai_bridge.utils.annotations import experimental
Expand Down Expand Up @@ -197,7 +199,7 @@ def _rotate_conversation_id(
new_attempt_number: int,
response_id: str,
) -> dict[str, Any]:
"""Rotate the conversation anchor to a per-attempt value.
"""Rotate the conversation anchor to a per-rotation-unique value.

After a crash, attempt N+1 should see a FRESH checkpointer / session so it
doesn't inherit mid-turn state that the SDK can't repair cleanly (most
Expand All @@ -208,9 +210,12 @@ def _rotate_conversation_id(
2. context.conversation_id (fallback)
3. auto-generated (last resort)

We drop (1), pick the current base anchor, and write ``{base}::attempt-N``
into (2). The handler then resolves to a fresh key for this attempt while
still being deterministic across retries of the same attempt.
We drop (1), pick the current base anchor, and write
``{base}::r-{response_id_short}-{attempt}`` into (2). Including
``response_id`` (truncated for readability) is what keeps multi-turn
rotations collision-free: turn 2's attempt-2 must not share a session
with turn 1's attempt-2 — both would otherwise mint
``{base}::attempt-2`` and re-poison the just-rotated session.

The LLM sees full turn history via ``original_request.input``, which was
captured at the initial POST — before any attempt ran, so it's clean by
Expand All @@ -231,9 +236,12 @@ def _rotate_conversation_id(
custom_inputs.pop("session_id", None)
request_dict["custom_inputs"] = custom_inputs

# response_id format is ``resp_<24hex>``; first 8 chars of the hex are
# plenty for collision-avoidance within one bridge deployment.
rid_short = response_id.removeprefix("resp_")[:8] or response_id
ctx = request_dict.get("context") or {}
ctx = dict(ctx)
rotated = f"{base_anchor}::attempt-{new_attempt_number}"
rotated = f"{base_anchor}::r-{rid_short}-{new_attempt_number}"
ctx["conversation_id"] = rotated
request_dict["context"] = ctx
logger.info(
Expand All @@ -246,6 +254,23 @@ def _rotate_conversation_id(
return request_dict


def _attach_bridge_logger() -> None:
    """Surface ``databricks_ai_bridge`` INFO logs in app stdout.

    uvicorn's logging config drops INFO from non-uvicorn loggers, so the
    bridge attaches its own stdout handler to the ``databricks_ai_bridge``
    logger. The call is idempotent: once a console stream handler is
    present on that logger, later calls add nothing.

    Set ``DATABRICKS_AI_BRIDGE_LOG_QUIET`` to ``1``, ``true`` or ``yes``
    (case-insensitive) to opt out entirely; in that case neither the level
    nor the handlers are touched.
    """
    quiet_flag = os.environ.get("DATABRICKS_AI_BRIDGE_LOG_QUIET", "").lower()
    if quiet_flag in ("1", "true", "yes"):
        return

    bridge_logger = logging.getLogger("databricks_ai_bridge")
    # Lower the threshold to INFO when it is unset or stricter than INFO;
    # a more verbose level (e.g. DEBUG) is left alone.
    if bridge_logger.level == logging.NOTSET or bridge_logger.level > logging.INFO:
        bridge_logger.setLevel(logging.INFO)

    # ``logging.FileHandler`` subclasses ``StreamHandler``. A plain isinstance
    # check would mistake a file-only setup for "console already attached"
    # and never surface logs on stdout, so file handlers are excluded here.
    has_console_handler = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in bridge_logger.handlers
    )
    if has_console_handler:
        return

    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
    bridge_logger.addHandler(handler)


@experimental
class LongRunningAgentServer(AgentServer):
"""AgentServer subclass adding background mode, retrieve endpoints, and
Expand Down Expand Up @@ -328,6 +353,7 @@ def __init__(
f"LongRunningAgentServer only supports '{self._SUPPORTED_AGENT_TYPE}', "
f"got '{agent_type}'"
)
_attach_bridge_logger()
self._settings = LongRunningSettings(
task_timeout_seconds=task_timeout_seconds,
poll_interval_seconds=poll_interval_seconds,
Expand Down Expand Up @@ -517,6 +543,20 @@ async def _handle_background_request(
# when tests pass a plain dict directly.
dump = getattr(request_data, "model_dump", None)
request_dict = dump() if callable(dump) else dict(request_data)

# Forward-resolve the client-visible base conversation_id to its
# current rotated form so the SDK lands on the post-rotation
# session. The client never has to track rotation itself; it always
# sends the original base id. ``original_request`` keeps the BASE so
# later rotations anchor off it (and ids don't grow unboundedly
# across multiple crashes). The dispatched copy uses the resolved
# form so this turn's SDK session is the rotated one.
base_conv_id = (request_dict.get("context") or {}).get("conversation_id")
if base_conv_id:
current_conv_id = await resolve_conversation_alias(base_conv_id)
else:
current_conv_id = None

# Store the FULL request (untrimmed) as `original_request` so resume can
# recover the entire prior-turn history. Per-template handlers are
# responsible for deduping their own UI-echoed input against the SDK's
Expand All @@ -527,7 +567,22 @@ async def _handle_background_request(
durable=True,
original_request=request_dict,
)
durable_request = self.validator.validate_and_convert_request(request_dict)

if current_conv_id and current_conv_id != base_conv_id:
dispatch_dict = copy.deepcopy(request_dict)
ctx = dispatch_dict.get("context") or {}
ctx = dict(ctx)
ctx["conversation_id"] = current_conv_id
dispatch_dict["context"] = ctx
logger.info(
"[durable] resolved alias on POST response_id=%s base=%s current=%s",
response_id,
base_conv_id,
current_conv_id,
)
else:
dispatch_dict = request_dict
durable_request = self.validator.validate_and_convert_request(dispatch_dict)

logger.info(
"Background response created response_id=%s stream=%s pod=%s",
Expand Down Expand Up @@ -1113,13 +1168,28 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None:
new_attempt - 1,
response_id,
)
# Rotate ANCHORED off the base id stored in original_request — never
# off the prior rotated form — so multi-crash chains stay flat
# (chat-123 → chat-123::r-<rid>-3, never chat-123::r-<rid>-2::r-<rid>-3).
base_conv_id = (resp.original_request.get("context") or {}).get("conversation_id")
resume_dict = _rotate_conversation_id(resume_dict, new_attempt, response_id)
resume_request = self.validator.validate_and_convert_request(resume_dict)
# Surface the rotated conversation_id in the sentinel so clients that
# cache `chat_id → conversation_id` can pick up the rotation and use
# the rotated session on subsequent turns. Without this the next turn
# lands on the original (orphan-poisoned) session.
rotated_conv_id = (resume_dict.get("context") or {}).get("conversation_id")
# Persist the alias so future requests for ``base_conv_id`` resolve
# forward to the rotated form on every pod, surviving chatbot restarts
# and multi-pod chatbot deployments. Without this, the client would
# need to remember the rotation itself.
if base_conv_id and rotated_conv_id and rotated_conv_id != base_conv_id:
await upsert_conversation_alias(base_conv_id, rotated_conv_id)
logger.info(
"[durable] persisted alias response_id=%s base=%s current=%s",
response_id,
base_conv_id,
rotated_conv_id,
)
resume_request = self.validator.validate_and_convert_request(resume_dict)
# Keep emitting the response.resumed sentinel for visibility / debug
# / test assertions; clients no longer need to act on it for
# cross-turn alias tracking — the bridge handles that server-side.
await append_message(
response_id,
next_seq,
Expand Down
43 changes: 43 additions & 0 deletions tests/databricks_ai_bridge/test_long_running_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
create_response,
get_messages,
get_response,
resolve_conversation_alias,
update_response_status,
update_response_trace_id,
upsert_conversation_alias,
)


Expand Down Expand Up @@ -206,6 +208,47 @@ async def test_get_response_not_found(mock_session):
assert result is None


# ---------------------------------------------------------------------------
# Conversation alias tests (Option B: server-side rotation resolution)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_resolve_conversation_alias_returns_current_when_present(mock_session):
    """When the lookup yields a stored value, that value is returned."""
    rotated_id = "chat-123::r-abcdef01-2"
    mock_session.execute.return_value = MagicMock(
        **{"scalar_one_or_none.return_value": rotated_id}
    )

    resolved = await resolve_conversation_alias("chat-123")

    assert resolved == "chat-123::r-abcdef01-2"


@pytest.mark.asyncio
async def test_resolve_conversation_alias_passthrough_when_absent(mock_session):
    """With no alias row (the usual case for a fresh conversation) the
    input id comes back unchanged, so the caller dispatches with the id the
    client sent."""
    mock_session.execute.return_value = MagicMock(
        **{"scalar_one_or_none.return_value": None}
    )

    resolved = await resolve_conversation_alias("chat-fresh")

    assert resolved == "chat-fresh"


@pytest.mark.asyncio
async def test_upsert_conversation_alias_executes_and_commits(mock_session):
    """The upsert issues one statement, commits once, and binds both ids."""
    await upsert_conversation_alias("chat-123", "chat-123::r-abcdef01-2")

    mock_session.execute.assert_awaited_once()
    mock_session.commit.assert_awaited_once()

    positional = mock_session.execute.await_args.args
    # The statement must be INSERT ... ON CONFLICT DO UPDATE so that a second
    # crash on the same conversation overwrites the row rather than adding one.
    rendered_sql = str(positional[0]).upper()
    assert "INSERT" in rendered_sql
    assert "ON CONFLICT" in rendered_sql
    assert positional[1] == {"base": "chat-123", "current": "chat-123::r-abcdef01-2"}


# ---------------------------------------------------------------------------
# init_db / dispose_db / session_scope tests
# ---------------------------------------------------------------------------
Expand Down
Loading
Loading