Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions app/features/registry/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,18 +601,30 @@ async def _find_duplicate(
data_window_end: Data window end date.

Returns:
Existing run or None.
The most recent matching run, or None.

Note:
Under ``registry_duplicate_policy="detect"`` (the default), duplicate
runs are intentionally created, so multiple non-archived rows may
share the same config hash. This query therefore orders by
``created_at`` and takes the first row rather than asserting a single
match — ``scalar_one_or_none()`` would raise ``MultipleResultsFound``.
"""
stmt = select(ModelRun).where(
(ModelRun.config_hash == config_hash)
& (ModelRun.store_id == store_id)
& (ModelRun.product_id == product_id)
& (ModelRun.data_window_start == data_window_start)
& (ModelRun.data_window_end == data_window_end)
& (ModelRun.status != RunStatusORM.ARCHIVED.value)
stmt = (
select(ModelRun)
.where(
(ModelRun.config_hash == config_hash)
& (ModelRun.store_id == store_id)
& (ModelRun.product_id == product_id)
& (ModelRun.data_window_start == data_window_start)
& (ModelRun.data_window_end == data_window_end)
& (ModelRun.status != RunStatusORM.ARCHIVED.value)
)
.order_by(ModelRun.created_at.desc())
.limit(1)
)
result = await db.execute(stmt)
return result.scalar_one_or_none()
return result.scalars().first()

def _model_to_response(self, model_run: ModelRun) -> RunResponse:
"""Convert ORM model to response schema.
Expand Down
24 changes: 24 additions & 0 deletions app/features/registry/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,30 @@ async def test_create_run_invalid_date_order(self, client: AsyncClient) -> None:
)
assert response.status_code == 422

async def test_create_run_repeated_duplicate_does_not_500(self, client: AsyncClient) -> None:
"""Repeated identical runs must not 500 (regression for #146).

Under the default ``registry_duplicate_policy="detect"`` duplicate runs
are created intentionally, so multiple non-archived rows can share one
config hash. ``_find_duplicate`` previously used ``scalar_one_or_none()``,
which raised ``MultipleResultsFound`` once two duplicates existed — the
third POST returned ``HTTP 500 Database Error``.
"""
payload = {
"model_type": "test-dup-regression",
"model_config": {"strategy": "last_value"},
"data_window_start": "2024-01-01",
"data_window_end": "2024-03-31",
"store_id": 1,
"product_id": 1,
}

# Three identical creates: 1st has no prior match, 2nd has one,
# 3rd would hit the MultipleResultsFound trap before the fix.
for _ in range(3):
response = await client.post("/registry/runs", json=payload)
assert response.status_code == 201, response.text


class TestListRunsEndpoint:
"""Tests for GET /registry/runs endpoint."""
Expand Down