diff --git a/docker-compose.yml b/docker-compose.yml index 3e60e5b..e131d34 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ services: opencmo: build: . ports: - - "8080:8080" + - "127.0.0.1:8080:8080" volumes: - opencmo_data:/data env_file: diff --git a/src/opencmo/llm.py b/src/opencmo/llm.py index fb7496f..30e2d56 100644 --- a/src/opencmo/llm.py +++ b/src/opencmo/llm.py @@ -63,6 +63,16 @@ "OPENAI_BASE_URL", "OPENCMO_MODEL_DEFAULT", }) +_ACCOUNT_SCOPED_SECRET_KEYS = frozenset({ + "REDDIT_CLIENT_ID", + "REDDIT_CLIENT_SECRET", + "REDDIT_USERNAME", + "REDDIT_PASSWORD", + "TWITTER_API_KEY", + "TWITTER_API_SECRET", + "TWITTER_ACCESS_TOKEN", + "TWITTER_ACCESS_SECRET", +}) # --------------------------------------------------------------------------- # ContextVar — per-request key isolation (asyncio Task-local) @@ -178,7 +188,12 @@ def get_key(name: str, default: str | None = None) -> str | None: if val: return val - # 3. os.environ + # 3. Sensitive account-scoped secrets must fail closed to avoid + # cross-account credential bleed through process-global env. + if name in _ACCOUNT_SCOPED_SECRET_KEYS: + return default + + # 4. os.environ val = os.environ.get(name) if val: return val @@ -222,13 +237,17 @@ async def get_key_async(name: str, default: str | None = None) -> str | None: except Exception: pass - # 4. For core router defaults, prefer env/.env over persisted DB settings. + # 4. Sensitive account-scoped secrets must not fall through to system/env. + if name in _ACCOUNT_SCOPED_SECRET_KEYS: + return default + + # 5. For core router defaults, prefer env/.env over persisted DB settings. if name in _ENV_PRIORITY_KEYS: val = os.environ.get(name) if val: return val - # 5. System fallback (admin account → legacy settings table) + # 6. System fallback (admin account → legacy settings table) try: from opencmo import storage val = await storage.get_system_setting(name) @@ -237,7 +256,7 @@ async def get_key_async(name: str, default: str | None = None) -> str | None: except Exception: pass # DB may not be initialized yet - # 6. os.environ + # 7. os.environ val = os.environ.get(name) if val: return val diff --git a/src/opencmo/storage/accounts.py b/src/opencmo/storage/accounts.py index a370cbe..3454466 100644 --- a/src/opencmo/storage/accounts.py +++ b/src/opencmo/storage/accounts.py @@ -177,6 +177,7 @@ async def get_user_account(user_id: int) -> dict | None: async def create_user_with_account(email: str, password: str, name: str = "") -> tuple[dict, dict]: normalized = (email or "").strip().lower() + admin_email = os.environ.get("OPENCMO_ADMIN_EMAIL", "hello@aidcmo.com").strip().lower() if not is_valid_email(normalized): raise ValueError("invalid_email") if len(password or "") < MIN_PASSWORD_LENGTH: @@ -189,6 +190,8 @@ async def create_user_with_account(email: str, password: str, name: str = "") -> try: existing = await db.execute("SELECT id, password_hash FROM users WHERE email = ?", (normalized,)) row = await existing.fetchone() + if row and row[1] == "!unusable" and normalized == admin_email: + raise ValueError("email_exists") if row and row[1] != "!unusable": raise ValueError("email_exists") @@ -206,7 +209,7 @@ async def create_user_with_account(email: str, password: str, name: str = "") -> normalized, hash_password(password), name.strip(), - "admin" if normalized == os.environ.get("OPENCMO_ADMIN_EMAIL", "hello@aidcmo.com").strip().lower() else "user", + "user", ), ) user_id = int(cursor.lastrowid) diff --git a/src/opencmo/tools/geo_ask.py b/src/opencmo/tools/geo_ask.py index 8a459ba..3cb4796 100644 --- a/src/opencmo/tools/geo_ask.py +++ b/src/opencmo/tools/geo_ask.py @@ -50,10 +50,19 @@ def _select_providers(platform_names: list[str] | None) -> tuple[list[GeoProvide case_map = {p.name.lower(): p for p in enabled} selected: list[GeoProvider] = [] unknown: list[str] = [] + seen: set[str] = set() for name in platform_names: - prov = enabled_by_name.get(name) or case_map.get(name.lower()) + if not isinstance(name, str): + unknown.append(str(name)) + continue + normalized = name.strip() + key = normalized.lower() + if not normalized or key in seen: + continue + seen.add(key) + prov = enabled_by_name.get(normalized) or case_map.get(key) if prov is None: - unknown.append(name) + unknown.append(normalized) else: selected.append(prov) return selected, unknown diff --git a/src/opencmo/web/app.py b/src/opencmo/web/app.py index 1508c5b..a9f1d80 100644 --- a/src/opencmo/web/app.py +++ b/src/opencmo/web/app.py @@ -1558,11 +1558,7 @@ async def api_v1_auth_verify_email(payload: _AuthVerifyEmailRequest, request: Re return JSONResponse({"ok": False, "error": "user_not_found"}, status_code=404) if await storage.is_user_verified(payload.user_id): - # Idempotent: already verified -> just sign them in. - account = await storage.get_user_account(payload.user_id) - if account is None or account["status"] != "active": - return JSONResponse({"ok": False, "error": "account_unavailable"}, status_code=403) - return await _json_with_session(request, user, account) + return JSONResponse({"ok": False, "error": "already_verified"}, status_code=400) result = await storage.consume_verification_code(payload.user_id, payload.code, purpose="signup") if not result.get("ok"): diff --git a/src/opencmo/web/routers/projects.py b/src/opencmo/web/routers/projects.py index 069cd82..1d9681c 100644 --- a/src/opencmo/web/routers/projects.py +++ b/src/opencmo/web/routers/projects.py @@ -400,8 +400,23 @@ async def api_v1_geo_ask(project_id: int, request: Request): return JSONResponse({"error": "query too long (max 500 chars)"}, status_code=400) platforms = body.get("platforms") - if platforms is not None and not isinstance(platforms, list): - return JSONResponse({"error": "platforms must be a list of strings"}, status_code=400) + if platforms is not None: + if not isinstance(platforms, list): + return JSONResponse({"error": "platforms must be a list of strings"}, status_code=400) + if len(platforms) > 20: + return JSONResponse({"error": "platforms too long (max 20 entries)"}, status_code=400) + if any(not isinstance(name, str) or not name.strip() for name in platforms): + return JSONResponse({"error": "platforms must be a list of non-empty strings"}, status_code=400) + # Collapse duplicate provider names (case-insensitive) to prevent fan-out amplification. + deduped: list[str] = [] + seen: set[str] = set() + for name in platforms: + key = name.strip().lower() + if key in seen: + continue + seen.add(key) + deduped.append(name.strip()) + platforms = deduped from dataclasses import asdict diff --git a/tests/test_email_verification.py b/tests/test_email_verification.py index e072fd4..b922d99 100644 --- a/tests/test_email_verification.py +++ b/tests/test_email_verification.py @@ -205,6 +205,31 @@ def test_verified_user_login_succeeds(verification_db): assert login.json()["authenticated"] is True +def test_verify_email_already_verified_requires_login_flow(verification_db): + with TestClient(app) as client: + signup = _signup(client, "already@example.test") + user_id = signup["user_id"] + code = _last_code_for("already@example.test") + + first = client.post( + "/api/v1/auth/verify-email", + json={"user_id": user_id, "code": code}, + ) + assert first.status_code == 200, first.text + assert first.json()["authenticated"] is True + + client.post("/api/v1/auth/logout") + client.cookies.clear() + + second = client.post( + "/api/v1/auth/verify-email", + json={"user_id": user_id, "code": "000000"}, + ) + assert second.status_code == 400, second.text + assert second.json()["error"] == "already_verified" + assert "opencmo_session" not in client.cookies + + def test_existing_legacy_users_remain_verified_after_backfill(tmp_path, monkeypatch): """A user created before this feature should keep working — backfill runs once at ensure_db() startup.""" diff --git a/tests/test_geo_ask.py b/tests/test_geo_ask.py index 1cd3665..b74392b 100644 --- a/tests/test_geo_ask.py +++ b/tests/test_geo_ask.py @@ -201,6 +201,20 @@ def test_select_providers_reports_unknown(stub_registry): assert unknown == ["Nope"] + + +def test_select_providers_deduplicates_case_insensitive(stub_registry): + stub_registry.append(_StubProvider("StubA", result=_make_result("StubA"))) + selected, unknown = _select_providers(["StubA", "stuba", " StubA "]) + assert [p.name for p in selected] == ["StubA"] + assert unknown == [] + + +def test_select_providers_non_string_is_unknown(stub_registry): + stub_registry.append(_StubProvider("StubA", result=_make_result("StubA"))) + selected, unknown = _select_providers(["StubA", 123]) + assert [p.name for p in selected] == ["StubA"] + assert unknown == ["123"] def test_list_available_platforms_reports_status(stub_registry): stub_registry.append(_StubProvider("StubA", result=_make_result("StubA"))) items = list_available_platforms() @@ -318,3 +332,39 @@ def test_geo_ask_endpoint_platforms_must_be_list(client): json={"query": "hi", "platforms": "not-a-list"}, ) assert resp.status_code == 400 + + +def test_geo_ask_endpoint_platforms_max_entries(client): + pid = _seed_project() + resp = client.post( + f"/api/v1/projects/{pid}/geo/ask", + json={"query": "hi", "platforms": ["Perplexity"] * 21}, + ) + assert resp.status_code == 400 + assert "max 20" in resp.json()["error"] + + +def test_geo_ask_endpoint_platforms_must_be_non_empty_strings(client): + pid = _seed_project() + resp = client.post( + f"/api/v1/projects/{pid}/geo/ask", + json={"query": "hi", "platforms": ["Perplexity", "", 123]}, + ) + assert resp.status_code == 400 + assert "non-empty strings" in resp.json()["error"] + + +def test_geo_ask_endpoint_platforms_deduped_before_call(client): + pid = _seed_project() + fake_response = GeoAskResponse(query="hi", results=[], total_duration_ms=1, query_lang="en") + with patch( + "opencmo.tools.geo_ask.ask_platforms", + new=AsyncMock(return_value=fake_response), + ) as ask_mock: + resp = client.post( + f"/api/v1/projects/{pid}/geo/ask", + json={"query": "hi", "platforms": ["Perplexity", "perplexity", " Perplexity "]}, + ) + assert resp.status_code == 200 + assert ask_mock.await_count == 1 + assert ask_mock.await_args.kwargs["platform_names"] == ["Perplexity"] diff --git a/tests/test_publishers.py b/tests/test_publishers.py index f8ac9e0..50a47f3 100644 --- a/tests/test_publishers.py +++ b/tests/test_publishers.py @@ -34,57 +34,68 @@ async def test_publish_reddit_no_consent(): @pytest.mark.asyncio async def test_publish_reddit_success(monkeypatch): """Real publish with mocked praw.""" - monkeypatch.setenv("REDDIT_CLIENT_ID", "test_id") - monkeypatch.setenv("REDDIT_CLIENT_SECRET", "test_secret") - monkeypatch.setenv("REDDIT_USERNAME", "test_user") - monkeypatch.setenv("REDDIT_PASSWORD", "test_pass") - - import opencmo.tools.publishers as pub + from opencmo import llm + + token = llm.set_request_keys({ + "REDDIT_CLIENT_ID": "test_id", + "REDDIT_CLIENT_SECRET": "test_secret", + "REDDIT_USERNAME": "test_user", + "REDDIT_PASSWORD": "test_pass", + }) + try: + import opencmo.tools.publishers as pub - mock_submission = MagicMock() - mock_submission.permalink = "/r/test/comments/abc123" - mock_submission.id = "abc123" + mock_submission = MagicMock() + mock_submission.permalink = "/r/test/comments/abc123" + mock_submission.id = "abc123" - mock_sub = MagicMock() - mock_sub.submit = MagicMock(return_value=mock_submission) + mock_sub = MagicMock() + mock_sub.submit = MagicMock(return_value=mock_submission) - mock_reddit = MagicMock() - mock_reddit.subreddit = MagicMock(return_value=mock_sub) + mock_reddit = MagicMock() + mock_reddit.subreddit = MagicMock(return_value=mock_sub) - # Create a fake praw module - mock_praw = MagicMock() - mock_praw.Reddit = MagicMock(return_value=mock_reddit) + mock_praw = MagicMock() + mock_praw.Reddit = MagicMock(return_value=mock_reddit) - pub._HAS_PRAW = True - monkeypatch.setattr(pub, "praw", mock_praw, raising=False) + pub._HAS_PRAW = True + monkeypatch.setattr(pub, "praw", mock_praw, raising=False) - result = await pub.publish_reddit_post_impl("test", "Title", "Body", dry_run=False) + result = await pub.publish_reddit_post_impl("test", "Title", "Body", dry_run=False) - assert result["ok"] - assert not result["dry_run"] - assert "reddit.com" in result["url"] + assert result["ok"] + assert not result["dry_run"] + assert "reddit.com" in result["url"] + finally: + llm.reset_request_keys(token) @pytest.mark.asyncio async def test_publish_reddit_error(monkeypatch): """Reddit API error returns error dict, doesn't raise.""" - monkeypatch.setenv("REDDIT_CLIENT_ID", "id") - monkeypatch.setenv("REDDIT_CLIENT_SECRET", "secret") - monkeypatch.setenv("REDDIT_USERNAME", "user") - monkeypatch.setenv("REDDIT_PASSWORD", "pass") - - import opencmo.tools.publishers as pub + from opencmo import llm + + token = llm.set_request_keys({ + "REDDIT_CLIENT_ID": "id", + "REDDIT_CLIENT_SECRET": "secret", + "REDDIT_USERNAME": "user", + "REDDIT_PASSWORD": "pass", + }) + try: + import opencmo.tools.publishers as pub - mock_praw = MagicMock() - mock_praw.Reddit = MagicMock(side_effect=Exception("Auth failed")) + mock_praw = MagicMock() + mock_praw.Reddit = MagicMock(side_effect=Exception("Auth failed")) - pub._HAS_PRAW = True - monkeypatch.setattr(pub, "praw", mock_praw, raising=False) + pub._HAS_PRAW = True + monkeypatch.setattr(pub, "praw", mock_praw, raising=False) - result = await pub.publish_reddit_post_impl("test", "T", "B", dry_run=False) + result = await pub.publish_reddit_post_impl("test", "T", "B", dry_run=False) - assert not result["ok"] - assert "Auth failed" in result["error"] + assert not result["ok"] + assert "Auth failed" in result["error"] + finally: + llm.reset_request_keys(token) @pytest.mark.asyncio @@ -132,29 +143,35 @@ async def test_publish_tweet_too_long(): @pytest.mark.asyncio async def test_publish_tweet_success(monkeypatch): """Real publish with mocked tweepy.""" - monkeypatch.setenv("TWITTER_API_KEY", "key") - monkeypatch.setenv("TWITTER_API_SECRET", "secret") - monkeypatch.setenv("TWITTER_ACCESS_TOKEN", "token") - monkeypatch.setenv("TWITTER_ACCESS_SECRET", "secret") - - import opencmo.tools.publishers as pub + from opencmo import llm + + token = llm.set_request_keys({ + "TWITTER_API_KEY": "key", + "TWITTER_API_SECRET": "secret", + "TWITTER_ACCESS_TOKEN": "token", + "TWITTER_ACCESS_SECRET": "secret", + }) + try: + import opencmo.tools.publishers as pub - mock_response = MagicMock() - mock_response.data = {"id": "12345"} + mock_response = MagicMock() + mock_response.data = {"id": "12345"} - mock_client = MagicMock() - mock_client.create_tweet = MagicMock(return_value=mock_response) + mock_client = MagicMock() + mock_client.create_tweet = MagicMock(return_value=mock_response) - mock_tweepy = MagicMock() - mock_tweepy.Client = MagicMock(return_value=mock_client) + mock_tweepy = MagicMock() + mock_tweepy.Client = MagicMock(return_value=mock_client) - pub._HAS_TWEEPY = True - monkeypatch.setattr(pub, "tweepy", mock_tweepy, raising=False) + pub._HAS_TWEEPY = True + monkeypatch.setattr(pub, "tweepy", mock_tweepy, raising=False) - result = await pub.publish_tweet_impl("Hello!", dry_run=False) + result = await pub.publish_tweet_impl("Hello!", dry_run=False) - assert result["ok"] - assert result["tweet_id"] == "12345" + assert result["ok"] + assert result["tweet_id"] == "12345" + finally: + llm.reset_request_keys(token) @pytest.mark.asyncio diff --git a/tests/test_settings_multitenant.py b/tests/test_settings_multitenant.py index 71bee9a..79848da 100644 --- a/tests/test_settings_multitenant.py +++ b/tests/test_settings_multitenant.py @@ -175,6 +175,23 @@ def test_get_key_async_reads_per_account_via_db_when_snapshot_empty(isolated_db) assert value == "alice_db_only" +def test_publish_credentials_do_not_fallback_to_env_or_system(isolated_db, monkeypatch): + """Tenant-missing publish creds must not resolve from global/system fallbacks.""" + admin_id = asyncio.run(storage.get_admin_account_id()) + _, tenant_account = _seed_account(email="tenant@example.test", name="Tenant") + asyncio.run(storage.set_account_setting(admin_id, "REDDIT_CLIENT_ID", "admin-cid")) + monkeypatch.setenv("REDDIT_CLIENT_ID", "env-cid") + + acct_token = llm.set_current_account_id(tenant_account) + snap_token = llm.set_current_account_settings({}) + try: + assert asyncio.run(llm.get_key_async("REDDIT_CLIENT_ID")) is None + assert llm.get_key("REDDIT_CLIENT_ID") is None + finally: + llm.reset_current_account_settings(snap_token) + llm.reset_current_account_id(acct_token) + + # --------------------------------------------------------------------------- # System fallback / legacy table # --------------------------------------------------------------------------- diff --git a/tests/test_trial_platform.py b/tests/test_trial_platform.py index 77721e0..537cc47 100644 --- a/tests/test_trial_platform.py +++ b/tests/test_trial_platform.py @@ -79,6 +79,47 @@ def _signup(client: TestClient, email: str, password: str = "password123") -> di return payload +def _seed_admin(client: TestClient, email: str = "admin@example.test", password: str = "password123") -> dict: + """Activate the bootstrapped admin user and log in. + + Admin bootstrap inserts a row with password_hash='!unusable'; self-service + signup can no longer claim that row (see PR #22), so tests must set the + password directly and mark the admin verified before logging in. + """ + from opencmo.storage._db import get_db + from opencmo.storage.accounts import hash_password + + async def _activate() -> int: + db = await get_db() + try: + cursor = await db.execute( + "SELECT id FROM users WHERE email = ?", + (email.lower(),), + ) + row = await cursor.fetchone() + assert row is not None, f"admin row not bootstrapped for {email}" + user_id = int(row[0]) + await db.execute( + "UPDATE users SET password_hash = ?, status = 'active' WHERE id = ?", + (hash_password(password), user_id), + ) + await db.commit() + return user_id + finally: + await db.close() + + user_id = asyncio.run(_activate()) + asyncio.run(storage.mark_user_verified(user_id)) + login = client.post( + "/api/v1/auth/login", + json={"email": email, "password": password}, + ) + assert login.status_code == 200, login.text + payload = login.json() + assert payload["authenticated"] is True + return payload + + def test_signup_login_me_and_logout(trial_db): with TestClient(app) as client: signup = _signup(client, "user@example.test") @@ -286,7 +327,7 @@ def test_legacy_project_global_unique_is_reconciled(tmp_path, monkeypatch): def test_admin_summary_requires_admin(trial_db): with TestClient(app) as client: - _signup(client, "admin@example.test") + _seed_admin(client) admin_cookie = client.cookies.get("opencmo_session") _signup(client, "normal@example.test") user_cookie = client.cookies.get("opencmo_session") @@ -305,7 +346,7 @@ def test_admin_summary_requires_admin(trial_db): def test_admin_account_actions_update_and_disable_access(trial_db): with TestClient(app) as client: - _signup(client, "admin@example.test") + _seed_admin(client) admin_cookie = client.cookies.get("opencmo_session") user_payload = _signup(client, "managed@example.test") user_cookie = client.cookies.get("opencmo_session")