From 5f1752425449a0baf28bda7e3e0a27f7cfff7c33 Mon Sep 17 00:00:00 2001
From: JingWen Fan <106414602+study8677@users.noreply.github.com>
Date: Wed, 20 May 2026 17:57:23 +0800
Subject: [PATCH] fix: validate and dedupe geo ask platforms

---
 src/opencmo/tools/geo_ask.py        | 13 ++++++--
 src/opencmo/web/routers/projects.py | 19 +++++++++--
 tests/test_geo_ask.py               | 50 +++++++++++++++++++++++++++++
 3 files changed, 78 insertions(+), 4 deletions(-)

diff --git a/src/opencmo/tools/geo_ask.py b/src/opencmo/tools/geo_ask.py
index 8a459ba..3cb4796 100644
--- a/src/opencmo/tools/geo_ask.py
+++ b/src/opencmo/tools/geo_ask.py
@@ -50,10 +50,19 @@ def _select_providers(platform_names: list[str] | None) -> tuple[list[GeoProvide
     case_map = {p.name.lower(): p for p in enabled}
     selected: list[GeoProvider] = []
     unknown: list[str] = []
+    seen: set[str] = set()
     for name in platform_names:
-        prov = enabled_by_name.get(name) or case_map.get(name.lower())
+        if not isinstance(name, str):
+            unknown.append(str(name))
+            continue
+        normalized = name.strip()
+        key = normalized.lower()
+        if not normalized or key in seen:
+            continue
+        seen.add(key)
+        prov = enabled_by_name.get(normalized) or case_map.get(key)
         if prov is None:
-            unknown.append(name)
+            unknown.append(normalized)
         else:
             selected.append(prov)
     return selected, unknown
diff --git a/src/opencmo/web/routers/projects.py b/src/opencmo/web/routers/projects.py
index 069cd82..1d9681c 100644
--- a/src/opencmo/web/routers/projects.py
+++ b/src/opencmo/web/routers/projects.py
@@ -400,8 +400,23 @@ async def api_v1_geo_ask(project_id: int, request: Request):
         return JSONResponse({"error": "query too long (max 500 chars)"}, status_code=400)
 
     platforms = body.get("platforms")
-    if platforms is not None and not isinstance(platforms, list):
-        return JSONResponse({"error": "platforms must be a list of strings"}, status_code=400)
+    if platforms is not None:
+        if not isinstance(platforms, list):
+            return JSONResponse({"error": "platforms must be a list of strings"}, status_code=400)
+        if len(platforms) > 20:
+            return JSONResponse({"error": "platforms too long (max 20 entries)"}, status_code=400)
+        if any(not isinstance(name, str) or not name.strip() for name in platforms):
+            return JSONResponse({"error": "platforms must be a list of non-empty strings"}, status_code=400)
+        # Collapse duplicate provider names (case-insensitive) to prevent fan-out amplification.
+        deduped: list[str] = []
+        seen: set[str] = set()
+        for name in platforms:
+            key = name.strip().lower()
+            if key in seen:
+                continue
+            seen.add(key)
+            deduped.append(name.strip())
+        platforms = deduped
 
     from dataclasses import asdict
 
diff --git a/tests/test_geo_ask.py b/tests/test_geo_ask.py
index 1cd3665..b74392b 100644
--- a/tests/test_geo_ask.py
+++ b/tests/test_geo_ask.py
@@ -201,6 +201,20 @@ def test_select_providers_reports_unknown(stub_registry):
     assert unknown == ["Nope"]
+
+
+def test_select_providers_deduplicates_case_insensitive(stub_registry):
+    stub_registry.append(_StubProvider("StubA", result=_make_result("StubA")))
+    selected, unknown = _select_providers(["StubA", "stuba", " StubA "])
+    assert [p.name for p in selected] == ["StubA"]
+    assert unknown == []
+
+
+def test_select_providers_non_string_is_unknown(stub_registry):
+    stub_registry.append(_StubProvider("StubA", result=_make_result("StubA")))
+    selected, unknown = _select_providers(["StubA", 123])
+    assert [p.name for p in selected] == ["StubA"]
+    assert unknown == ["123"]
 
 
 def test_list_available_platforms_reports_status(stub_registry):
     stub_registry.append(_StubProvider("StubA", result=_make_result("StubA")))
     items = list_available_platforms()
@@ -318,3 +332,39 @@ def test_geo_ask_endpoint_platforms_must_be_list(client):
         json={"query": "hi", "platforms": "not-a-list"},
     )
     assert resp.status_code == 400
+
+
+def test_geo_ask_endpoint_platforms_max_entries(client):
+    pid = _seed_project()
+    resp = client.post(
+        f"/api/v1/projects/{pid}/geo/ask",
+        json={"query": "hi", "platforms": ["Perplexity"] * 21},
+    )
+    assert resp.status_code == 400
+    assert "max 20" in resp.json()["error"]
+
+
+def test_geo_ask_endpoint_platforms_must_be_non_empty_strings(client):
+    pid = _seed_project()
+    resp = client.post(
+        f"/api/v1/projects/{pid}/geo/ask",
+        json={"query": "hi", "platforms": ["Perplexity", "", 123]},
+    )
+    assert resp.status_code == 400
+    assert "non-empty strings" in resp.json()["error"]
+
+
+def test_geo_ask_endpoint_platforms_deduped_before_call(client):
+    pid = _seed_project()
+    fake_response = GeoAskResponse(query="hi", results=[], total_duration_ms=1, query_lang="en")
+    with patch(
+        "opencmo.tools.geo_ask.ask_platforms",
+        new=AsyncMock(return_value=fake_response),
+    ) as ask_mock:
+        resp = client.post(
+            f"/api/v1/projects/{pid}/geo/ask",
+            json={"query": "hi", "platforms": ["Perplexity", "perplexity", " Perplexity "]},
+        )
+    assert resp.status_code == 200
+    assert ask_mock.await_count == 1
+    assert ask_mock.await_args.kwargs["platform_names"] == ["Perplexity"]