Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/opencmo/tools/geo_ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,19 @@ def _select_providers(platform_names: list[str] | None) -> tuple[list[GeoProvide
case_map = {p.name.lower(): p for p in enabled}
selected: list[GeoProvider] = []
unknown: list[str] = []
seen: set[str] = set()
for name in platform_names:
prov = enabled_by_name.get(name) or case_map.get(name.lower())
if not isinstance(name, str):
unknown.append(str(name))
continue
normalized = name.strip()
key = normalized.lower()
if not normalized or key in seen:
continue
seen.add(key)
prov = enabled_by_name.get(normalized) or case_map.get(key)
if prov is None:
unknown.append(name)
unknown.append(normalized)
else:
selected.append(prov)
return selected, unknown
Expand Down
19 changes: 17 additions & 2 deletions src/opencmo/web/routers/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,23 @@ async def api_v1_geo_ask(project_id: int, request: Request):
return JSONResponse({"error": "query too long (max 500 chars)"}, status_code=400)

platforms = body.get("platforms")
if platforms is not None and not isinstance(platforms, list):
return JSONResponse({"error": "platforms must be a list of strings"}, status_code=400)
if platforms is not None:
if not isinstance(platforms, list):
return JSONResponse({"error": "platforms must be a list of strings"}, status_code=400)
if len(platforms) > 20:
return JSONResponse({"error": "platforms too long (max 20 entries)"}, status_code=400)
if any(not isinstance(name, str) or not name.strip() for name in platforms):
return JSONResponse({"error": "platforms must be a list of non-empty strings"}, status_code=400)
# Collapse duplicate provider names (case-insensitive) to prevent fan-out amplification.
deduped: list[str] = []
seen: set[str] = set()
for name in platforms:
key = name.strip().lower()
if key in seen:
continue
seen.add(key)
deduped.append(name.strip())
platforms = deduped

from dataclasses import asdict

Expand Down
50 changes: 50 additions & 0 deletions tests/test_geo_ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,20 @@ def test_select_providers_reports_unknown(stub_registry):
assert unknown == ["Nope"]




def test_select_providers_deduplicates_case_insensitive(stub_registry):
    """Names differing only by case or surrounding whitespace select one provider."""
    provider = _StubProvider("StubA", result=_make_result("StubA"))
    stub_registry.append(provider)

    # Three spellings of the same provider name.
    requested = ["StubA", "stuba", " StubA "]
    chosen, unrecognized = _select_providers(requested)

    chosen_names = [p.name for p in chosen]
    assert chosen_names == ["StubA"]
    assert unrecognized == []


def test_select_providers_non_string_is_unknown(stub_registry):
    """A non-string entry is reported as unknown using its string form."""
    provider = _StubProvider("StubA", result=_make_result("StubA"))
    stub_registry.append(provider)

    # Mix one valid name with an integer entry.
    requested = ["StubA", 123]
    chosen, unrecognized = _select_providers(requested)

    chosen_names = [p.name for p in chosen]
    assert chosen_names == ["StubA"]
    assert unrecognized == ["123"]
def test_list_available_platforms_reports_status(stub_registry):
stub_registry.append(_StubProvider("StubA", result=_make_result("StubA")))
items = list_available_platforms()
Expand Down Expand Up @@ -318,3 +332,39 @@ def test_geo_ask_endpoint_platforms_must_be_list(client):
json={"query": "hi", "platforms": "not-a-list"},
)
assert resp.status_code == 400


def test_geo_ask_endpoint_platforms_max_entries(client):
    """A request listing 21 platforms is rejected with HTTP 400 and a "max 20" error."""
    project_id = _seed_project()
    # One entry more than the limit named in the expected error message.
    payload = {"query": "hi", "platforms": ["Perplexity"] * 21}

    response = client.post(f"/api/v1/projects/{project_id}/geo/ask", json=payload)

    assert response.status_code == 400
    error_message = response.json()["error"]
    assert "max 20" in error_message


def test_geo_ask_endpoint_platforms_must_be_non_empty_strings(client):
    """A platforms list containing an empty string and a non-string yields HTTP 400."""
    project_id = _seed_project()
    # One valid name, one empty string, one integer.
    payload = {"query": "hi", "platforms": ["Perplexity", "", 123]}

    response = client.post(f"/api/v1/projects/{project_id}/geo/ask", json=payload)

    assert response.status_code == 400
    error_message = response.json()["error"]
    assert "non-empty strings" in error_message


def test_geo_ask_endpoint_platforms_deduped_before_call(client):
    """Duplicate platform spellings reach ``ask_platforms`` as a single name."""
    project_id = _seed_project()
    canned_response = GeoAskResponse(query="hi", results=[], total_duration_ms=1, query_lang="en")
    ask_stub = AsyncMock(return_value=canned_response)
    # Same provider spelled three ways: exact, lower-case, and whitespace-padded.
    payload = {"query": "hi", "platforms": ["Perplexity", "perplexity", " Perplexity "]}

    # Substitute the stub for ask_platforms for the duration of the request.
    with patch("opencmo.tools.geo_ask.ask_platforms", new=ask_stub):
        response = client.post(f"/api/v1/projects/{project_id}/geo/ask", json=payload)

    assert response.status_code == 200
    assert ask_stub.await_count == 1
    forwarded_names = ask_stub.await_args.kwargs["platform_names"]
    assert forwarded_names == ["Perplexity"]
Loading