diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..0277ee8 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,19 @@ +.git +.github +.claude +.venv +.env +.env.* +__pycache__ +*.pyc +*.pyo +certs/ +data/ +logs/ +docs/ +scripts/ +tests/ +*.log +.DS_Store +.vscode +VIDEO_DESCRIPTION_FIX.md diff --git a/.github/workflows/container-release.yml b/.github/workflows/container-release.yml new file mode 100644 index 0000000..bf8cd9d --- /dev/null +++ b/.github/workflows/container-release.yml @@ -0,0 +1,51 @@ +name: Container Release + +on: + pull_request: + branches: [main] + push: + branches: [main] + tags: ["v*"] + +permissions: + contents: read + packages: write + +jobs: + build-and-push: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to ghcr.io + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ghcr.io/killrvideo/kv-be-python-fastapi-dataapi-table + tags: | + type=sha,prefix=sha- + type=raw,value=latest,enable={{is_default_branch}} + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=semver,pattern={{major}} + + - name: Build and push + uses: docker/build-push-action@v6 + with: + context: . 
+ push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.gitignore b/.gitignore index 0d70896..ecada93 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ .env .python-version -.DS_Store +.DS_Store # Python __pycache__/ @@ -15,6 +15,11 @@ __pycache__/ *.log +# Data files (source of truth is in killrvideo-data project) +data/ +logs/ +dsbulk.conf + certs/ .vscode/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..ebe1c9c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,29 @@ +FROM python:3.12-slim AS builder + +WORKDIR /app + +RUN pip install --no-cache-dir poetry && \ + poetry config virtualenvs.in-project true + +COPY pyproject.toml poetry.lock ./ +RUN poetry install --only main --no-root --no-interaction + +COPY README.md ./ +COPY app/ app/ +RUN poetry install --only main --no-interaction + +# --------------------------------------------------------------------------- + +FROM python:3.12-slim + +WORKDIR /app + +COPY --from=builder /app/.venv /app/.venv +COPY --from=builder /app/app /app/app + +ENV PATH="/app/.venv/bin:$PATH" \ + PYTHONUNBUFFERED=1 + +EXPOSE 8000 + +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/app/api/v1/endpoints/user_activity.py b/app/api/v1/endpoints/user_activity.py new file mode 100644 index 0000000..fc14104 --- /dev/null +++ b/app/api/v1/endpoints/user_activity.py @@ -0,0 +1,51 @@ +"""API endpoint for querying user activity timelines.""" + +from __future__ import annotations + +from typing import Annotated, Literal, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, Query + +from app.api.v1.dependencies import PaginationParams +from app.models.common import PaginatedResponse, Pagination +from app.models.user_activity import UserActivityResponse +from app.services import user_activity_service + +router = 
APIRouter(tags=["User Activity"]) + + +@router.get( + "/users/{user_id_path}/activity", + response_model=PaginatedResponse[UserActivityResponse], + summary="Get user activity timeline", +) +async def get_user_activity( + user_id_path: UUID, + pagination: Annotated[PaginationParams, Depends()], + activity_type: Optional[Literal["view", "comment", "rate"]] = Query( + None, description="Filter by activity type (view, comment, rate)" + ), +): + """Return a paginated timeline of a user's activity over the last 30 days.""" + + activities, total = await user_activity_service.list_user_activity( + userid=user_id_path, + page=pagination.page, + page_size=pagination.pageSize, + activity_type=activity_type, + ) + + total_pages = (total + pagination.pageSize - 1) // pagination.pageSize + + response_items = [UserActivityResponse.model_validate(a) for a in activities] + + return PaginatedResponse[UserActivityResponse]( + data=response_items, + pagination=Pagination( + currentPage=pagination.page, + pageSize=pagination.pageSize, + totalItems=total, + totalPages=total_pages, + ), + ) diff --git a/app/api/v1/endpoints/video_catalog.py b/app/api/v1/endpoints/video_catalog.py index c620681..6fee967 100644 --- a/app/api/v1/endpoints/video_catalog.py +++ b/app/api/v1/endpoints/video_catalog.py @@ -256,7 +256,10 @@ async def record_view( ) # READY – record the view - await video_service.record_video_view(video_id_path) + await video_service.record_video_view( + video_id_path, + viewer_user_id=current_user.userid if current_user else None, + ) return Response(status_code=status.HTTP_204_NO_CONTENT) diff --git a/app/main.py b/app/main.py index e734e90..8315a26 100644 --- a/app/main.py +++ b/app/main.py @@ -28,6 +28,7 @@ reco_internal, flags, moderation, + user_activity, ) # -------------------------------------------------------------- @@ -67,6 +68,7 @@ api_router_v1.include_router(reco_internal.router) api_router_v1.include_router(flags.router) 
api_router_v1.include_router(moderation.router) +api_router_v1.include_router(user_activity.router) app.include_router(api_router_v1) diff --git a/app/models/user_activity.py b/app/models/user_activity.py new file mode 100644 index 0000000..290cacb --- /dev/null +++ b/app/models/user_activity.py @@ -0,0 +1,46 @@ +"""Pydantic models for user activity tracking.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Literal, Optional +from uuid import UUID + +from pydantic import BaseModel, Field, ConfigDict + +from app.models.common import UserID + +ACTIVITY_TYPES = Literal["view", "comment", "rate"] + + +class UserActivity(BaseModel): + """Internal representation of a user activity row. + + Field names match DB column names (snake_case) exactly — no aliases needed. + """ + + model_config = ConfigDict(populate_by_name=True, from_attributes=True) + + userid: UserID + day: str + activity_type: ACTIVITY_TYPES + activity_id: UUID + activity_timestamp: datetime + + +class UserActivityResponse(BaseModel): + """API response representation for a single user activity item.""" + + model_config = ConfigDict(populate_by_name=True, from_attributes=True) + + userId: UserID = Field(..., validation_alias="userid") + activityType: str = Field(..., validation_alias="activity_type") + activityId: UUID = Field(..., validation_alias="activity_id") + activityTimestamp: datetime = Field(..., validation_alias="activity_timestamp") + + +__all__ = [ + "ACTIVITY_TYPES", + "UserActivity", + "UserActivityResponse", +] diff --git a/app/services/comment_service.py b/app/services/comment_service.py index 1323473..74619b0 100644 --- a/app/services/comment_service.py +++ b/app/services/comment_service.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging from typing import Optional, List, Tuple from uuid import UUID, uuid1 @@ -15,10 +16,13 @@ from app.external_services.sentiment_mock import MockSentimentAnalyzer import inspect # local 
import to avoid new dependency from app.utils.db_helpers import safe_count +from app.services.user_activity_service import record_user_activity # testing mocks from unittest.mock import AsyncMock, MagicMock +logger = logging.getLogger(__name__) + COMMENTS_BY_VIDEO_TABLE_NAME = "comments" COMMENTS_BY_USER_TABLE_NAME = "comments_by_user" @@ -90,6 +94,18 @@ async def add_comment_to_video( await comments_by_video_table.insert_one(document=comment_doc) await comments_by_user_table.insert_one(document=comment_doc) + # Track in user_activity (never fail the comment operation) + try: + await record_user_activity( + userid=current_user.userid, + activity_type="comment", + activity_id=comment_id, + ) + except Exception: + logger.warning( + "user_activity insert failed for comment; ignoring", exc_info=True + ) + return new_comment diff --git a/app/services/rating_service.py b/app/services/rating_service.py index 9ec244c..f69c59e 100644 --- a/app/services/rating_service.py +++ b/app/services/rating_service.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging from datetime import datetime, timezone from typing import Optional, List, Dict, Any from uuid import UUID @@ -18,53 +19,67 @@ from app.models.video import VideoID, VideoStatusEnum from app.models.user import User from app.services import video_service +from app.services.user_activity_service import record_user_activity from astrapy.exceptions.data_api_exceptions import DataAPIResponseException +logger = logging.getLogger(__name__) + RATINGS_TABLE_NAME = video_service.VIDEO_RATINGS_TABLE_NAME # "video_ratings_by_user" RATINGS_SUMMARY_TABLE_NAME = video_service.VIDEO_RATINGS_SUMMARY_TABLE_NAME async def _update_video_aggregate_rating( video_id: VideoID, - ratings_db_table: AstraDBCollection, - videos_db_table: AstraDBCollection, + new_rating: int, + old_rating: int | None = None, + summary_db_table: AstraDBCollection | None = None, ) -> None: - """Recalculate average and total ratings count for the given 
video.""" + """Increment counters on the video_ratings summary table. - cursor = ratings_db_table.find( - filter={"videoid": str(video_id)}, projection={"rating": 1} - ) - docs: List[Dict[str, Any]] = ( - await cursor.to_list() if hasattr(cursor, "to_list") else cursor - ) + * **New rating** (old_rating is None): increment rating_counter by 1 and + rating_total by new_rating. + * **Updated rating** (old_rating provided): increment rating_total by + (new_rating - old_rating) only — counter stays the same. + """ + + if summary_db_table is None: + summary_db_table = await get_table(RATINGS_SUMMARY_TABLE_NAME) - if docs: - values = [int(d["rating"]) for d in docs if "rating" in d] - total = len(values) - average = sum(values) / total if total else None + vid_str = str(video_id) + + if old_rating is None: + inc_doc: Dict[str, Any] = {"rating_counter": 1, "rating_total": new_rating} else: - total = 0 - average = None + delta = new_rating - old_rating + inc_doc = {"rating_total": delta} try: - await videos_db_table.update_one( - filter={"videoid": str(video_id)}, - update={ - "$set": { - "averageRating": average, - "totalRatingsCount": total, - "updatedAt": datetime.now(timezone.utc), - } - }, + await summary_db_table.update_one( + filter={"videoid": vid_str}, + update={"$inc": inc_doc}, + upsert=True, ) except DataAPIResponseException as exc: - # If the videos table schema does not include these columns (common - # when running against the default KillrVideo schema) Astra will - # reject the update with UNKNOWN_TABLE_COLUMNS. That is not fatal – - # the API can still compute aggregates on-the-fly. 
- if "UNKNOWN_TABLE_COLUMNS" not in str(exc): + if "Update operation not supported" in str( + exc + ) or "unsupported operations" in str(exc): + existing = await summary_db_table.find_one( + filter={"videoid": vid_str} + ) + counter = int(existing.get("rating_counter", 0)) if existing else 0 + total = int(existing.get("rating_total", 0)) if existing else 0 + if old_rating is None: + counter += 1 + total += new_rating + else: + total += new_rating - old_rating + await summary_db_table.update_one( + filter={"videoid": vid_str}, + update={"$set": {"rating_counter": counter, "rating_total": total}}, + upsert=True, + ) + else: raise - # Otherwise silently ignore so the rating operation succeeds. async def rate_video( @@ -112,6 +127,16 @@ async def rate_video( createdAt=created_at, updatedAt=now, ) + # Track in user_activity (never fail the rating operation) + try: + await record_user_activity( + userid=current_user.userid, + activity_type="rate", + ) + except Exception: + logger.debug( + "user_activity insert failed for rate; ignoring", exc_info=True + ) else: rating_obj = Rating( videoId=video_id, @@ -127,10 +152,23 @@ async def rate_video( "rating_date": now, } await db_table.insert_one(document=insert_doc) - - # update aggregate + # Track in user_activity (never fail the rating operation) + try: + await record_user_activity( + userid=current_user.userid, + activity_type="rate", + ) + except Exception: + logger.debug( + "user_activity insert failed for rate; ignoring", exc_info=True + ) + + # update aggregate counters on the summary table + old_rating_value: int | None = None + if existing_doc: + old_rating_value = int(existing_doc["rating"]) await _update_video_aggregate_rating( - video_id, db_table, await get_table(video_service.VIDEOS_TABLE_NAME) + video_id, new_rating=request.rating, old_rating=old_rating_value ) return rating_obj @@ -144,19 +182,35 @@ async def get_video_ratings_summary( video_id: VideoID, current_user_id: UUID | None = None, ratings_db_table: 
Optional[AstraDBCollection] = None, + summary_db_table: Optional[AstraDBCollection] = None, ) -> AggregateRatingResponse: """Return aggregated rating info for a video and optionally the caller's rating.""" - # Fetch video to access pre-computed aggregates + # 404 check – make sure the video exists target_video = await video_service.get_video_by_id(video_id) if target_video is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Video not found" ) - avg = target_video.averageRating - total = target_video.totalRatingsCount + # Read counters from the video_ratings summary table + if summary_db_table is None: + summary_db_table = await get_table(RATINGS_SUMMARY_TABLE_NAME) + + summary_doc = await summary_db_table.find_one( + filter={"videoid": str(video_id)} + ) + + if summary_doc: + rating_counter = int(summary_doc.get("rating_counter", 0)) + rating_total = int(summary_doc.get("rating_total", 0)) + avg = round(rating_total / rating_counter, 2) if rating_counter > 0 else None + total = rating_counter + else: + avg = None + total = 0 + # Look up current user's individual rating user_rating_value: RatingValue | None = None if current_user_id is not None: if ratings_db_table is None: diff --git a/app/services/user_activity_service.py b/app/services/user_activity_service.py new file mode 100644 index 0000000..d45a13a --- /dev/null +++ b/app/services/user_activity_service.py @@ -0,0 +1,165 @@ +"""Service layer for tracking user activity (views, comments, ratings).""" + +from __future__ import annotations + +import asyncio +import logging +from datetime import datetime, timezone, timedelta +from typing import Optional, List, Tuple +from uuid import UUID, uuid1 + +from app.db.astra_client import get_table, AstraDBCollection +from app.models.user_activity import UserActivity, ACTIVITY_TYPES + +USER_ACTIVITY_TABLE_NAME = "user_activity" + +ANONYMOUS_USER_ID = UUID("00000000-0000-0000-0000-000000000000") + +# Hard cap on total rows scanned across all 30 
partitions to prevent OOM. +MAX_ACTIVITY_ROWS = 1000 + +logger = logging.getLogger(__name__) + + +async def record_user_activity( + userid: UUID, + activity_type: ACTIVITY_TYPES, + activity_id: Optional[UUID] = None, + db_table: Optional[AstraDBCollection] = None, +) -> None: + """Insert a single user activity row. + + Parameters + ---------- + userid: + The user who performed the action (use ANONYMOUS_USER_ID for unauthenticated). + activity_type: + One of 'view', 'comment', 'rate'. + activity_id: + Optional time-based UUID linking back to the activity. Auto-generated if not provided. + db_table: + Optional pre-fetched table reference (for testing). + """ + + if db_table is None: + db_table = await get_table(USER_ACTIVITY_TABLE_NAME) + + if activity_id is None: + activity_id = uuid1() + + now_utc = datetime.now(timezone.utc) + day_partition = now_utc.strftime("%Y-%m-%d") + + try: + await db_table.insert_one( + { + "userid": str(userid), + "day": day_partition, + "activity_type": activity_type, + "activity_id": str(activity_id), + "activity_timestamp": now_utc.isoformat(), + } + ) + except Exception: + logger.warning( + "Failed to record user activity for userid=%s activity_type=%s; skipping.", + userid, + activity_type, + exc_info=True, + ) + + +async def _fetch_day_rows( + db_table: AstraDBCollection, + userid: UUID, + day_key: str, + activity_type: Optional[str], + limit: int, +) -> List[dict]: + """Fetch activity rows for a single day partition. + + Returns an empty list on error so one bad partition does not abort the + entire read. 
+ """ + try: + query_filter: dict = {"userid": str(userid), "day": day_key} + if activity_type: + query_filter["activity_type"] = activity_type + + cursor = db_table.find(filter=query_filter, limit=limit) + + if hasattr(cursor, "to_list"): + return await cursor.to_list() + return cursor # type: ignore[return-value] + except Exception: + logger.warning( + "Failed to fetch user activity for day=%s userid=%s; skipping partition.", + day_key, + userid, + exc_info=True, + ) + return [] + + +async def list_user_activity( + userid: UUID, + page: int, + page_size: int, + activity_type: Optional[str] = None, + db_table: Optional[AstraDBCollection] = None, +) -> Tuple[List[UserActivity], int]: + """Query user activity across the last 30 days of partitions. + + Queries all 30 day-partitions concurrently via asyncio.gather and applies a + hard cap of MAX_ACTIVITY_ROWS total rows to prevent unbounded memory usage. + + Returns + ------- + Tuple[List[UserActivity], int] + A page of activity items and the total count. + """ + + if db_table is None: + db_table = await get_table(USER_ACTIVITY_TABLE_NAME) + + today = datetime.now(timezone.utc).date() + start_date = today - timedelta(days=29) + + partition_keys: List[str] = [ + (start_date + timedelta(days=delta)).strftime("%Y-%m-%d") + for delta in range(30) + ] + + # Run all 30 partition queries concurrently; divide the total cap evenly + # across all partitions so that 30 x per_day_limit stays bounded at + # ~MAX_ACTIVITY_ROWS rather than 30 x MAX_ACTIVITY_ROWS. 
+ per_day_limit = max(1, MAX_ACTIVITY_ROWS // 30) + + results: List[List[dict]] = await asyncio.gather( + *[ + _fetch_day_rows(db_table, userid, day_key, activity_type, per_day_limit) + for day_key in partition_keys + ] + ) + + all_rows: List[dict] = [] + for day_rows in results: + all_rows.extend(day_rows) + if len(all_rows) >= MAX_ACTIVITY_ROWS: + all_rows = all_rows[:MAX_ACTIVITY_ROWS] + break + + # Sort by activity_timestamp descending (newest first) + all_rows.sort( + key=lambda r: r.get("activity_timestamp", ""), + reverse=True, + ) + + total = len(all_rows) + + # Paginate + skip = (page - 1) * page_size + page_rows = all_rows[skip : skip + page_size] + + activities = [UserActivity.model_validate(r) for r in page_rows] + return activities, total diff --git a/app/services/video_service.py b/app/services/video_service.py index 6093e0d..a9e537c 100644 --- a/app/services/video_service.py +++ b/app/services/video_service.py @@ -88,10 +88,6 @@ logging.getLogger().getEffectiveLevel(), ) -# Flag to track if we've logged the view tracking limitation warning -_logged_views_disabled = False - - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -383,84 +379,60 @@ async def update_video_details( async def record_video_view( video_id: VideoID, + viewer_user_id: Optional[UUID] = None, db_table: Optional[AstraDBCollection] = None, ) -> None: """Increment the view counter stored directly in the *videos* table. - NOTE: View tracking is currently disabled. The 'views' column exists in the - CQL schema but is not yet exposed via the Astra DB Table API. This function - will gracefully no-op until API support is added. - - The dedicated ``video_playback_stats`` counter table is no longer updated – - we instead mutate the new ``views`` bigint column in the primary table so - the entire workflow remains Data-API-only. 
+ The Astra DB Table API does not support ``$inc``, so we use a read-modify-write + cycle: fetch the current ``views`` value and write back with ``$set``. """ if db_table is None: db_table = await get_table(VIDEOS_TABLE_NAME) + # The Astra DB Table API does not support $inc. Use read-modify-write with $set. + current = ( + await db_table.find_one(filter={"videoid": _uuid_for_db(video_id, db_table)}) + or {} + ) + new_count = int(current.get("views", 0)) + 1 + await db_table.update_one( + filter={"videoid": _uuid_for_db(video_id, db_table)}, + update={"$set": {"views": new_count}}, + upsert=True, + ) + + from app.services.user_activity_service import ( + record_user_activity, + ANONYMOUS_USER_ID, + ) + + # Log individual view event in the time-series activity table (non-critical) try: - # Fast path – $inc is accepted on normal bigint columns - await db_table.update_one( - filter={"videoid": _uuid_for_db(video_id, db_table)}, - update={"$inc": {"views": 1}}, - upsert=True, + activity_table = await get_table(VIDEO_ACTIVITY_TABLE_NAME) + now_utc = datetime.now(timezone.utc) + day_partition = now_utc.strftime("%Y-%m-%d") # Cassandra date literal format + + await activity_table.insert_one( + { + "videoid": _uuid_for_db(video_id, db_table), + "day": day_partition, + "watch_time": str(uuid1()), # time-based UUID for clustering order + } ) - except DataAPIResponseException as exc: - global _logged_views_disabled - error_str = str(exc) - - # Check if this is the known Table API limitation where the 'views' column - # exists in CQL schema but isn't exposed via the Table API yet - if ( - "UNKNOWN_TABLE_COLUMNS" in error_str - or "UNSUPPORTED_UPDATE_OPERATIONS" in error_str - ): - # Log warning once per process lifecycle to avoid log spam - if not _logged_views_disabled: - logger.warning( - "View tracking is currently disabled. The 'views' column exists in " - "the CQL schema (docs/schema-astra.cql:95) but is not yet exposed " - "via the Astra DB Table API. 
Views will not be tracked until API " - "support is added. Error codes: UNKNOWN_TABLE_COLUMNS / " - "UNSUPPORTED_UPDATE_OPERATIONS_FOR_TABLE" - ) - _logged_views_disabled = True - return # Gracefully no-op without breaking the API contract - - # Some deployments (Astra *tables*) currently reject $inc on bigint – - # fall back to a manual read-modify-write cycle. - if ( - "Update operation not supported" in error_str - or "unsupported operations" in error_str - ): - current = ( - await db_table.find_one( - filter={"videoid": _uuid_for_db(video_id, db_table)} - ) - or {} - ) - new_count = int(current.get("views", 0)) + 1 - await db_table.update_one( - filter={"videoid": _uuid_for_db(video_id, db_table)}, - update={"$set": {"views": new_count}}, - upsert=True, - ) - else: - raise + except Exception: + logger.warning("video_activity insert failed for view; ignoring", exc_info=True) - # Log individual view event in the time-series activity table (unchanged) - activity_table = await get_table(VIDEO_ACTIVITY_TABLE_NAME) - now_utc = datetime.now(timezone.utc) - day_partition = now_utc.strftime("%Y-%m-%d") # Cassandra date literal format - - await activity_table.insert_one( - { - "videoid": _uuid_for_db(video_id, db_table), - "day": day_partition, - "watch_time": str(uuid1()), # time-based UUID for clustering order - } - ) + # Track in user_activity (never fail the view operation) + try: + effective_user_id = viewer_user_id if viewer_user_id else ANONYMOUS_USER_ID + await record_user_activity( + userid=effective_user_id, + activity_type="view", + ) + except Exception: + logger.warning("user_activity insert failed for view; ignoring", exc_info=True) async def list_videos_with_query( @@ -507,19 +479,23 @@ async def list_videos_with_query( span.set_attribute("page", page) span.set_attribute("page_size", page_size) - cursor = db_table.find( - filter=query_filter, skip=skip, limit=page_size, sort=sort_options - ) + from app.utils.db_helpers import safe_count, 
suppress_astrapy_warnings - docs: List[Dict[str, Any]] = [] - if hasattr(cursor, "to_list"): - docs = await cursor.to_list() - else: # Stub collection path - docs = cursor # type: ignore[assignment] + with suppress_astrapy_warnings( + "ZERO_FILTER_OPERATIONS", + "IN_MEMORY_SORTING", + ): + cursor = db_table.find( + filter=query_filter, skip=skip, limit=page_size, sort=sort_options + ) - # Use helper that gracefully degrades on tables - from app.utils.db_helpers import safe_count + docs: List[Dict[str, Any]] = [] + if hasattr(cursor, "to_list"): + docs = await cursor.to_list() + else: # Stub collection path + docs = cursor # type: ignore[assignment] + # Use helper that gracefully degrades on tables total_items = await safe_count( db_table, query_filter=query_filter, diff --git a/app/utils/db_helpers.py b/app/utils/db_helpers.py index dda9b09..391aa1b 100644 --- a/app/utils/db_helpers.py +++ b/app/utils/db_helpers.py @@ -13,11 +13,33 @@ object. """ -from typing import Any, Dict +import logging +from contextlib import contextmanager +from typing import Any, Dict, Iterator from astrapy.exceptions.data_api_exceptions import DataAPIResponseException # type: ignore -__all__ = ["safe_count"] +__all__ = ["safe_count", "suppress_astrapy_warnings"] + +_ASTRAPY_LOGGER = logging.getLogger("astrapy.utils.api_commander") + + +class _SuppressAstrapyWarnings(logging.Filter): + """Drop WARNING records whose message contains any of the given substrings. + + astrapy emits WARNINGs for certain operations that we handle or expect + (e.g. ``UNSUPPORTED_TABLE_COMMAND`` on tables, ``ZERO_FILTER_OPERATIONS`` + for unfiltered queries). This filter suppresses only the specified codes + so legitimate warnings still surface. 
+ """ + + def __init__(self, substrings: frozenset[str]) -> None: + super().__init__() + self._substrings = substrings + + def filter(self, record: logging.LogRecord) -> bool: + msg = record.getMessage() + return not any(s in msg for s in self._substrings) async def safe_count( @@ -33,12 +55,34 @@ async def safe_count( an exception. The same applies to stub collections used in unit-tests. """ + _filter = _SuppressAstrapyWarnings(frozenset({"UNSUPPORTED_TABLE_COMMAND"})) + _ASTRAPY_LOGGER.addFilter(_filter) try: return await db_table.count_documents(filter=query_filter, upper_bound=10**9) - except (TypeError, DataAPIResponseException) as exc: # pragma: no cover – fallback + except (TypeError, DataAPIResponseException) as exc: if isinstance( exc, DataAPIResponseException ) and "UNSUPPORTED_TABLE_COMMAND" not in str(exc): # An unexpected Data API error – surface to caller. raise return fallback_len + finally: + _ASTRAPY_LOGGER.removeFilter(_filter) + + +@contextmanager +def suppress_astrapy_warnings(*warning_codes: str) -> Iterator[None]: + """Temporarily suppress astrapy warnings matching any of *warning_codes*. + + Usage:: + + with suppress_astrapy_warnings("ZERO_FILTER_OPERATIONS", "IN_MEMORY_SORTING"): + cursor = db_table.find(...) + docs = await cursor.to_list() + """ + _filter = _SuppressAstrapyWarnings(frozenset(warning_codes)) + _ASTRAPY_LOGGER.addFilter(_filter) + try: + yield + finally: + _ASTRAPY_LOGGER.removeFilter(_filter) diff --git a/docs/killrvideo_openapi.yaml b/docs/killrvideo_openapi.yaml index 6fc07ed..519e725 100644 --- a/docs/killrvideo_openapi.yaml +++ b/docs/killrvideo_openapi.yaml @@ -1271,6 +1271,70 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /api/v1/users/{user_id_path}/activity: + get: + tags: + - User Activity + summary: Get user activity timeline + description: Return a paginated timeline of a user's activity over the last + 30 days. 
+ operationId: get_user_activity_api_v1_users__user_id_path__activity_get + parameters: + - name: user_id_path + in: path + required: true + schema: + type: string + format: uuid + title: User Id Path + - name: activity_type + in: query + required: false + schema: + anyOf: + - enum: + - view + - comment + - rate + type: string + - type: 'null' + description: Filter by activity type (view, comment, rate) + title: Activity Type + description: Filter by activity type (view, comment, rate) + - name: page + in: query + required: false + schema: + type: integer + minimum: 1 + description: Page number + default: 1 + title: Page + description: Page number + - name: pageSize + in: query + required: false + schema: + type: integer + maximum: 100 + minimum: 1 + description: Items per page + default: 10 + title: Pagesize + description: Items per page + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/PaginatedResponse_UserActivityResponse_' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /: get: summary: Health check @@ -1579,6 +1643,20 @@ components: - data - pagination title: PaginatedResponse[FlagResponse] + PaginatedResponse_UserActivityResponse_: + properties: + data: + items: + $ref: '#/components/schemas/UserActivityResponse' + type: array + title: Data + pagination: + $ref: '#/components/schemas/Pagination' + type: object + required: + - data + - pagination + title: PaginatedResponse[UserActivityResponse] PaginatedResponse_VideoSummary_: properties: data: @@ -1743,6 +1821,31 @@ components: - email - userId title: User + UserActivityResponse: + properties: + userId: + type: string + format: uuid + title: Userid + activityType: + type: string + title: Activitytype + activityId: + type: string + format: uuid + title: Activityid + activityTimestamp: + type: string + format: date-time + title: Activitytimestamp + 
type: object + required: + - userId + - activityType + - activityId + - activityTimestamp + title: UserActivityResponse + description: API response representation for a single user activity item. UserCreateRequest: properties: firstName: @@ -1873,7 +1976,7 @@ components: description: anyOf: - type: string - maxLength: 1000 + maxLength: 2000 - type: 'null' title: Description tags: @@ -2127,7 +2230,7 @@ components: description: anyOf: - type: string - maxLength: 1000 + maxLength: 2000 - type: 'null' title: Description tags: diff --git a/tests/api/v1/endpoints/test_user_activity.py b/tests/api/v1/endpoints/test_user_activity.py new file mode 100644 index 0000000..56de12e --- /dev/null +++ b/tests/api/v1/endpoints/test_user_activity.py @@ -0,0 +1,144 @@ +"""API-level tests for the user activity endpoint.""" + +import pytest +from httpx import AsyncClient +from fastapi import status +from uuid import uuid4, uuid1 +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +from app.main import app +from app.core.config import settings +from app.models.user_activity import UserActivity + + +@pytest.fixture +def sample_activities(): + userid = uuid4() + now = datetime.now(timezone.utc) + return [ + UserActivity( + userid=userid, + day=now.strftime("%Y-%m-%d"), + activity_type="view", + activity_id=uuid1(), + activity_timestamp=now, + ), + UserActivity( + userid=userid, + day=now.strftime("%Y-%m-%d"), + activity_type="comment", + activity_id=uuid1(), + activity_timestamp=now, + ), + ], userid + + +@pytest.mark.asyncio +async def test_get_user_activity_success(sample_activities): + activities, userid = sample_activities + + with patch( + "app.api.v1.endpoints.user_activity.user_activity_service.list_user_activity", + new_callable=AsyncMock, + ) as mock_list: + mock_list.return_value = (activities, 2) + + async with AsyncClient(app=app, base_url="http://test") as ac: + resp = await ac.get( + f"{settings.API_V1_STR}/users/{userid}/activity", + ) + + 
assert resp.status_code == status.HTTP_200_OK + body = resp.json() + assert len(body["data"]) == 2 + assert body["pagination"]["totalItems"] == 2 + mock_list.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_user_activity_empty(): + with patch( + "app.api.v1.endpoints.user_activity.user_activity_service.list_user_activity", + new_callable=AsyncMock, + ) as mock_list: + mock_list.return_value = ([], 0) + + async with AsyncClient(app=app, base_url="http://test") as ac: + resp = await ac.get( + f"{settings.API_V1_STR}/users/{uuid4()}/activity", + ) + + assert resp.status_code == status.HTTP_200_OK + body = resp.json() + assert body["data"] == [] + assert body["pagination"]["totalItems"] == 0 + assert body["pagination"]["totalPages"] == 0 + + +@pytest.mark.asyncio +async def test_get_user_activity_with_type_filter(sample_activities): + activities, userid = sample_activities + view_only = [a for a in activities if a.activity_type == "view"] + + with patch( + "app.api.v1.endpoints.user_activity.user_activity_service.list_user_activity", + new_callable=AsyncMock, + ) as mock_list: + mock_list.return_value = (view_only, 1) + + async with AsyncClient(app=app, base_url="http://test") as ac: + resp = await ac.get( + f"{settings.API_V1_STR}/users/{userid}/activity?activity_type=view", + ) + + assert resp.status_code == status.HTTP_200_OK + body = resp.json() + assert len(body["data"]) == 1 + item = body["data"][0] + assert item["activityType"] == "view" + # Verify the filter was passed through + call_kwargs = mock_list.call_args[1] + assert call_kwargs["activity_type"] == "view" + + +@pytest.mark.asyncio +async def test_get_user_activity_invalid_uuid(): + """Invalid UUID in path returns 422.""" + async with AsyncClient(app=app, base_url="http://test") as ac: + resp = await ac.get(f"{settings.API_V1_STR}/users/not-a-uuid/activity") + assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + +@pytest.mark.asyncio +async def 
test_get_user_activity_pagination(sample_activities): + activities, userid = sample_activities + + with patch( + "app.api.v1.endpoints.user_activity.user_activity_service.list_user_activity", + new_callable=AsyncMock, + ) as mock_list: + mock_list.return_value = ([activities[1]], 2) + + async with AsyncClient(app=app, base_url="http://test") as ac: + resp = await ac.get( + f"{settings.API_V1_STR}/users/{userid}/activity?page=2&pageSize=1", + ) + + assert resp.status_code == status.HTTP_200_OK + body = resp.json() + assert len(body["data"]) == 1 + assert body["pagination"]["currentPage"] == 2 + assert body["pagination"]["pageSize"] == 1 + assert body["pagination"]["totalItems"] == 2 + assert body["pagination"]["totalPages"] == 2 + + +@pytest.mark.asyncio +async def test_get_user_activity_invalid_type_returns_422(): + """An unrecognised activity_type value must produce a 422 validation error.""" + async with AsyncClient(app=app, base_url="http://test") as ac: + resp = await ac.get( + f"{settings.API_V1_STR}/users/{uuid4()}/activity?activity_type=invalid" + ) + assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY diff --git a/tests/conftest.py b/tests/conftest.py index da683bd..819e451 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,8 @@ # --------------------------------------------------------------------------- # Stub for `astrapy` when the real package is not installed (CI / unit tests) # --------------------------------------------------------------------------- -if "astrapy" not in sys.modules: # pragma: no cover +import importlib.util as _importlib_util +if _importlib_util.find_spec("astrapy") is None: # pragma: no cover astrapy_stub = types.ModuleType("astrapy") db_stub = types.ModuleType("astrapy.db") diff --git a/tests/services/test_comment_service.py b/tests/services/test_comment_service.py index 607a68c..af5d64c 100644 --- a/tests/services/test_comment_service.py +++ b/tests/services/test_comment_service.py @@ -48,6 +48,10 @@ 
async def test_add_comment_success(viewer_user: User, sample_video: Video): patch( "app.services.comment_service.get_table", new_callable=AsyncMock ) as mock_get_table, + patch( + "app.services.comment_service.record_user_activity", + new_callable=AsyncMock, + ), ): mock_get_vid.return_value = sample_video mock_table_video = AsyncMock() @@ -68,6 +72,82 @@ async def test_add_comment_success(viewer_user: User, sample_video: Video): assert comment.text == request.text +@pytest.mark.asyncio +async def test_add_comment_calls_record_user_activity(viewer_user: User, sample_video: Video): + """add_comment_to_video calls record_user_activity with correct args.""" + request = CommentCreateRequest(text="Great stuff!") + sample_video.status = VideoStatusEnum.READY + + with ( + patch( + "app.services.comment_service.video_service.get_video_by_id", + new_callable=AsyncMock, + ) as mock_get_vid, + patch( + "app.services.comment_service.get_table", new_callable=AsyncMock + ) as mock_get_table, + patch( + "app.services.comment_service.record_user_activity", + new_callable=AsyncMock, + ) as mock_record_activity, + ): + mock_get_vid.return_value = sample_video + mock_table_video = AsyncMock() + mock_table_user = AsyncMock() + mock_get_table.side_effect = [mock_table_video, mock_table_user] + + comment = await comment_service.add_comment_to_video( + video_id=sample_video.videoid, + request=request, + current_user=viewer_user, + ) + + # Verify record_user_activity was called with the correct contract arguments + mock_record_activity.assert_awaited_once_with( + userid=viewer_user.userid, + activity_type="comment", + activity_id=comment.commentid, + ) + + +@pytest.mark.asyncio +async def test_add_comment_user_activity_failure_does_not_break(viewer_user: User, sample_video: Video): + """If record_user_activity raises, the comment operation still succeeds.""" + request = CommentCreateRequest(text="Still works!") + sample_video.status = VideoStatusEnum.READY + + with ( + patch( + 
"app.services.comment_service.video_service.get_video_by_id", + new_callable=AsyncMock, + ) as mock_get_vid, + patch( + "app.services.comment_service.get_table", new_callable=AsyncMock + ) as mock_get_table, + patch( + "app.services.comment_service.record_user_activity", + new_callable=AsyncMock, + ) as mock_record_activity, + ): + mock_get_vid.return_value = sample_video + mock_table_video = AsyncMock() + mock_table_user = AsyncMock() + mock_get_table.side_effect = [mock_table_video, mock_table_user] + mock_record_activity.side_effect = Exception("activity service error") + + # Should NOT raise despite record_user_activity failure + comment = await comment_service.add_comment_to_video( + video_id=sample_video.videoid, + request=request, + current_user=viewer_user, + ) + + # Comment tables were still written + assert mock_table_video.insert_one.call_count == 1 + assert mock_table_user.insert_one.call_count == 1 + assert comment.text == request.text + + @pytest.mark.asyncio async def test_add_comment_video_not_ready(viewer_user: User, sample_video: Video): request = CommentCreateRequest(text="Hello") diff --git a/tests/services/test_rating_service.py b/tests/services/test_rating_service.py index 9740f5a..14fdd59 100644 --- a/tests/services/test_rating_service.py +++ b/tests/services/test_rating_service.py @@ -22,13 +22,9 @@ def viewer_user() -> User: ) -@pytest.mark.asyncio -async def test_rate_video_new(viewer_user: User): - video_id = uuid4() - req = RatingCreateOrUpdateRequest(rating=4) - - ready_video = Video( - videoid=video_id, +def _make_video(video_id=None, **kwargs): + defaults = dict( + videoid=video_id or uuid4(), userid=uuid4(), added_date=datetime.now(timezone.utc), name="Title", @@ -37,6 +33,19 @@ async def test_rate_video_new(viewer_user: User): status=VideoStatusEnum.READY, title="Title", ) + defaults.update(kwargs) + return Video(**defaults) + + +# --------------------------------------------------------------------------- +# rate_video – new rating 
+# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_rate_video_new(viewer_user: User): + video_id = uuid4() + req = RatingCreateOrUpdateRequest(rating=4) with ( patch( @@ -47,25 +56,27 @@ async def test_rate_video_new(viewer_user: User): "app.services.rating_service.get_table", new_callable=AsyncMock ) as mock_get_table, ): - mock_get_vid.return_value = ready_video + mock_get_vid.return_value = _make_video(video_id) ratings_tbl = AsyncMock() - videos_tbl = AsyncMock() - mock_get_table.side_effect = [ratings_tbl, videos_tbl] + summary_tbl = AsyncMock() + mock_get_table.return_value = summary_tbl ratings_tbl.find_one.return_value = None ratings_tbl.insert_one.return_value = {} - ratings_tbl.find = MagicMock(return_value=[]) - ratings_tbl.count_documents.return_value = 0 result = await rating_service.rate_video( video_id, req, viewer_user, db_table=ratings_tbl ) assert result.rating == 4 ratings_tbl.insert_one.assert_called_once() + # Should call $inc on summary table with counter=1, total=4 + summary_tbl.update_one.assert_awaited_once() + call_kwargs = summary_tbl.update_one.call_args.kwargs + assert call_kwargs["update"] == {"$inc": {"rating_counter": 1, "rating_total": 4}} # --------------------------------------------------------------------------- -# Existing rating update +# rate_video – update existing rating # --------------------------------------------------------------------------- @@ -74,23 +85,11 @@ async def test_rate_video_update(viewer_user: User): video_id = uuid4() req = RatingCreateOrUpdateRequest(rating=5) - ready_video = Video( - videoid=video_id, - userid=uuid4(), - added_date=datetime.now(timezone.utc), - name="Title", - location="http://a.b/c.mp4", - location_type=0, - status=VideoStatusEnum.READY, - title="Title", - ) - existing_doc = { "videoid": str(video_id), "userid": str(viewer_user.userid), "rating": 3, - "created_at": datetime.now(timezone.utc), - "updated_at": 
datetime.now(timezone.utc), + "rating_date": datetime.now(timezone.utc), } with ( @@ -101,15 +100,11 @@ async def test_rate_video_update(viewer_user: User): patch( "app.services.rating_service.get_table", new_callable=AsyncMock ) as mock_get_table, - patch( - "app.services.rating_service._update_video_aggregate_rating", - new_callable=AsyncMock, - ) as mock_update_agg, ): - mock_get_vid.return_value = ready_video + mock_get_vid.return_value = _make_video(video_id) ratings_tbl = AsyncMock() - videos_tbl = AsyncMock() - mock_get_table.side_effect = [ratings_tbl, videos_tbl] + summary_tbl = AsyncMock() + mock_get_table.return_value = summary_tbl ratings_tbl.find_one.return_value = existing_doc ratings_tbl.update_one.return_value = {} @@ -120,30 +115,99 @@ async def test_rate_video_update(viewer_user: User): ratings_tbl.update_one.assert_called_once() assert result.rating == req.rating - mock_update_agg.assert_awaited_once() + # Should call $inc on summary table with delta only (5 - 3 = 2) + summary_tbl.update_one.assert_awaited_once() + call_kwargs = summary_tbl.update_one.call_args.kwargs + assert call_kwargs["update"] == {"$inc": {"rating_total": 2}} # --------------------------------------------------------------------------- -# Summary fetch +# rate_video – user activity # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_get_video_ratings_summary_with_user(viewer_user: User): +async def test_rate_video_user_activity_failure_does_not_break(viewer_user: User): + """If user_activity insert fails, the rating still succeeds.""" video_id = uuid4() + req = RatingCreateOrUpdateRequest(rating=4) - video_obj = Video( - videoid=video_id, - userid=uuid4(), - added_date=datetime.now(timezone.utc), - name="Title", - location="http://a.b/c.mp4", - location_type=0, - status=VideoStatusEnum.READY, - title="Title", - averageRating=4.5, - totalRatingsCount=2, - ) + with ( + patch( + 
"app.services.rating_service.video_service.get_video_by_id", + new_callable=AsyncMock, + ) as mock_get_vid, + patch( + "app.services.rating_service.get_table", new_callable=AsyncMock + ) as mock_get_table, + patch( + "app.services.rating_service.record_user_activity", + new_callable=AsyncMock, + side_effect=Exception("DB error"), + ) as mock_record, + ): + mock_get_vid.return_value = _make_video(video_id) + ratings_tbl = AsyncMock() + summary_tbl = AsyncMock() + mock_get_table.return_value = summary_tbl + ratings_tbl.find_one.return_value = None + ratings_tbl.insert_one.return_value = {} + + result = await rating_service.rate_video(video_id, req, viewer_user, db_table=ratings_tbl) + assert result.rating == 4 + ratings_tbl.insert_one.assert_called_once() + mock_record.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_rate_video_new_calls_record_user_activity(viewer_user: User): + """New rating triggers record_user_activity with activity_type='rate'.""" + video_id = uuid4() + req = RatingCreateOrUpdateRequest(rating=4) + + with ( + patch( + "app.services.rating_service.video_service.get_video_by_id", + new_callable=AsyncMock, + ) as mock_get_vid, + patch( + "app.services.rating_service.get_table", new_callable=AsyncMock + ) as mock_get_table, + patch( + "app.services.user_activity_service.get_table", new_callable=AsyncMock + ) as mock_ua_get_table, + ): + mock_get_vid.return_value = _make_video(video_id) + ratings_tbl = AsyncMock() + summary_tbl = AsyncMock() + mock_get_table.return_value = summary_tbl + + ratings_tbl.find_one.return_value = None + ratings_tbl.insert_one.return_value = {} + + mock_ua_table = AsyncMock() + mock_ua_get_table.return_value = mock_ua_table + + await rating_service.rate_video(video_id, req, viewer_user, db_table=ratings_tbl) + + mock_ua_table.insert_one.assert_awaited_once() + doc = mock_ua_table.insert_one.call_args.args[0] if mock_ua_table.insert_one.call_args.args else mock_ua_table.insert_one.call_args.kwargs + assert 
doc["userid"] == str(viewer_user.userid) + assert doc["activity_type"] == "rate" + + +@pytest.mark.asyncio +async def test_rate_video_update_calls_record_user_activity(viewer_user: User): + """Updated rating also triggers record_user_activity.""" + video_id = uuid4() + req = RatingCreateOrUpdateRequest(rating=5) + + existing_doc = { + "videoid": str(video_id), + "userid": str(viewer_user.userid), + "rating": 3, + "rating_date": datetime.now(timezone.utc), + } with ( patch( @@ -153,16 +217,120 @@ async def test_get_video_ratings_summary_with_user(viewer_user: User): patch( "app.services.rating_service.get_table", new_callable=AsyncMock ) as mock_get_table, + patch( + "app.services.rating_service._update_video_aggregate_rating", + new_callable=AsyncMock, + ) as mock_update_agg, + patch( + "app.services.user_activity_service.get_table", new_callable=AsyncMock + ) as mock_ua_get_table, ): - mock_get_vid.return_value = video_obj + mock_get_vid.return_value = _make_video(video_id) ratings_tbl = AsyncMock() - mock_get_table.return_value = ratings_tbl - ratings_tbl.find_one.return_value = {"rating": 5} + mock_get_table.return_value = AsyncMock() + + ratings_tbl.find_one.return_value = existing_doc + ratings_tbl.update_one.return_value = {} + + mock_ua_table = AsyncMock() + mock_ua_get_table.return_value = mock_ua_table + + await rating_service.rate_video(video_id, req, viewer_user, db_table=ratings_tbl) + + mock_ua_table.insert_one.assert_awaited_once() + doc = mock_ua_table.insert_one.call_args.args[0] if mock_ua_table.insert_one.call_args.args else mock_ua_table.insert_one.call_args.kwargs + assert doc["activity_type"] == "rate" + + +# --------------------------------------------------------------------------- +# Summary fetch +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_video_ratings_summary_with_user(viewer_user: User): + video_id = uuid4() + + summary_tbl = AsyncMock() + 
summary_tbl.find_one.return_value = {"rating_counter": 2, "rating_total": 9} + + ratings_tbl = AsyncMock() + ratings_tbl.find_one.return_value = {"rating": 5} + + with patch( + "app.services.rating_service.video_service.get_video_by_id", + new_callable=AsyncMock, + ) as mock_get_vid: + mock_get_vid.return_value = _make_video(video_id) summary = await rating_service.get_video_ratings_summary( - video_id, current_user_id=viewer_user.userid, ratings_db_table=ratings_tbl + video_id, + current_user_id=viewer_user.userid, + ratings_db_table=ratings_tbl, + summary_db_table=summary_tbl, ) assert summary.averageRating == 4.5 assert summary.totalRatingsCount == 2 assert summary.currentUserRating == 5 + + +@pytest.mark.asyncio +async def test_get_video_ratings_summary_no_ratings(): + """Returns None/0 when no counter doc exists in the summary table.""" + video_id = uuid4() + + summary_tbl = AsyncMock() + summary_tbl.find_one.return_value = None + + with patch( + "app.services.rating_service.video_service.get_video_by_id", + new_callable=AsyncMock, + ) as mock_get_vid: + mock_get_vid.return_value = _make_video(video_id) + + summary = await rating_service.get_video_ratings_summary( + video_id, + summary_db_table=summary_tbl, + ) + + assert summary.averageRating is None + assert summary.totalRatingsCount == 0 + assert summary.currentUserRating is None + + +# --------------------------------------------------------------------------- +# _update_video_aggregate_rating – unit tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_update_aggregate_new_rating_counter_increment(): + """New rating: $inc counter by 1 and total by rating value.""" + video_id = uuid4() + summary_tbl = AsyncMock() + + await rating_service._update_video_aggregate_rating( + video_id, new_rating=4, old_rating=None, summary_db_table=summary_tbl + ) + + summary_tbl.update_one.assert_awaited_once() + call_kwargs = 
summary_tbl.update_one.call_args.kwargs + assert call_kwargs["update"] == {"$inc": {"rating_counter": 1, "rating_total": 4}} + assert call_kwargs["upsert"] is True + + +@pytest.mark.asyncio +async def test_update_aggregate_updated_rating_delta(): + """Updated rating: $inc total by delta only, counter unchanged.""" + video_id = uuid4() + summary_tbl = AsyncMock() + + await rating_service._update_video_aggregate_rating( + video_id, new_rating=5, old_rating=3, summary_db_table=summary_tbl + ) + + summary_tbl.update_one.assert_awaited_once() + call_kwargs = summary_tbl.update_one.call_args.kwargs + assert call_kwargs["update"] == {"$inc": {"rating_total": 2}} + assert call_kwargs["upsert"] is True diff --git a/tests/services/test_user_activity_service.py b/tests/services/test_user_activity_service.py new file mode 100644 index 0000000..651328c --- /dev/null +++ b/tests/services/test_user_activity_service.py @@ -0,0 +1,437 @@ +"""Unit tests for the user_activity_service module.""" + +import pytest +from unittest.mock import AsyncMock, patch +from uuid import uuid4, uuid1 +from datetime import datetime, timezone + +from app.services.user_activity_service import ( + record_user_activity, + list_user_activity, + ANONYMOUS_USER_ID, + USER_ACTIVITY_TABLE_NAME, + MAX_ACTIVITY_ROWS, +) +from app.models.user_activity import UserActivity + + +@pytest.mark.asyncio +async def test_record_user_activity_auto_activity_id(): + """record_user_activity generates a uuid1 when activity_id is not provided.""" + mock_table = AsyncMock() + userid = uuid4() + + await record_user_activity( + userid=userid, + activity_type="view", + db_table=mock_table, + ) + + mock_table.insert_one.assert_awaited_once() + doc = mock_table.insert_one.call_args[1].get( + "document", mock_table.insert_one.call_args[0][0] + if mock_table.insert_one.call_args[0] else mock_table.insert_one.call_args[1] + ) + # insert_one is called with a positional dict + insert_call = mock_table.insert_one.call_args + if 
insert_call.args:
+        doc = insert_call.args[0]
+    else:
+        doc = insert_call.kwargs
+
+    assert doc["userid"] == str(userid)
+    assert doc["activity_type"] == "view"
+    assert doc["activity_id"] is not None
+    assert doc["day"] == datetime.now(timezone.utc).strftime("%Y-%m-%d")
+
+
+@pytest.mark.asyncio
+async def test_record_user_activity_timestamp_is_isoformat_string():
+    """record_user_activity serializes activity_timestamp as an ISO 8601 string."""
+    mock_table = AsyncMock()
+
+    await record_user_activity(
+        userid=uuid4(),
+        activity_type="view",
+        db_table=mock_table,
+    )
+
+    insert_call = mock_table.insert_one.call_args
+    doc = insert_call.args[0] if insert_call.args else insert_call.kwargs
+    assert isinstance(doc["activity_timestamp"], str)
+    # Verify it is a valid ISO format string by parsing it back
+    parsed = datetime.fromisoformat(doc["activity_timestamp"])
+    assert parsed.tzinfo is not None
+
+
+@pytest.mark.asyncio
+async def test_record_user_activity_explicit_activity_id():
+    """record_user_activity uses the provided activity_id."""
+    mock_table = AsyncMock()
+    userid = uuid4()
+    explicit_id = uuid1()
+
+    await record_user_activity(
+        userid=userid,
+        activity_type="comment",
+        activity_id=explicit_id,
+        db_table=mock_table,
+    )
+
+    mock_table.insert_one.assert_awaited_once()
+    insert_call = mock_table.insert_one.call_args
+    doc = insert_call.args[0] if insert_call.args else insert_call.kwargs
+    assert doc["activity_id"] == str(explicit_id)
+    assert doc["activity_type"] == "comment"
+
+
+@pytest.mark.asyncio
+async def test_record_user_activity_anonymous():
+    """record_user_activity works with the anonymous sentinel UUID."""
+    mock_table = AsyncMock()
+
+    await record_user_activity(
+        userid=ANONYMOUS_USER_ID,
+        activity_type="view",
+        db_table=mock_table,
+    )
+
+    mock_table.insert_one.assert_awaited_once()
+    insert_call = mock_table.insert_one.call_args
+    doc = insert_call.args[0] if insert_call.args else insert_call.kwargs
+    assert doc["userid"] ==
str(ANONYMOUS_USER_ID) + + +@pytest.mark.asyncio +async def test_record_user_activity_fetches_table(): + """record_user_activity calls get_table when db_table is None.""" + mock_table = AsyncMock() + + with patch( + "app.services.user_activity_service.get_table", + new_callable=AsyncMock, + return_value=mock_table, + ) as mock_get_table: + await record_user_activity( + userid=uuid4(), + activity_type="rate", + ) + + mock_get_table.assert_awaited_once_with(USER_ACTIVITY_TABLE_NAME) + mock_table.insert_one.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_list_user_activity_fetches_table(): + """list_user_activity calls get_table when db_table is None.""" + def mock_find(filter=None, **kwargs): + cursor = AsyncMock() + cursor.to_list = AsyncMock(return_value=[]) + return cursor + + mock_table = AsyncMock() + mock_table.find = mock_find + + with patch( + "app.services.user_activity_service.get_table", + new_callable=AsyncMock, + return_value=mock_table, + ) as mock_get_table: + await list_user_activity(userid=uuid4(), page=1, page_size=10) + mock_get_table.assert_awaited_once_with(USER_ACTIVITY_TABLE_NAME) + + +@pytest.mark.asyncio +async def test_list_user_activity_returns_paginated_results(): + """list_user_activity returns a page of results and total count.""" + userid = uuid4() + now = datetime.now(timezone.utc) + today_str = now.strftime("%Y-%m-%d") + + rows = [ + { + "userid": str(userid), + "day": today_str, + "activity_type": "view", + "activity_id": str(uuid1()), + "activity_timestamp": now.isoformat(), + }, + { + "userid": str(userid), + "day": today_str, + "activity_type": "comment", + "activity_id": str(uuid1()), + "activity_timestamp": now.isoformat(), + }, + ] + + def mock_find(filter=None, **kwargs): + cursor = AsyncMock() + if filter and filter.get("day") == today_str: + cursor.to_list = AsyncMock(return_value=rows) + else: + cursor.to_list = AsyncMock(return_value=[]) + return cursor + + mock_table = AsyncMock() + mock_table.find = 
mock_find + + activities, total = await list_user_activity( + userid=userid, + page=1, + page_size=10, + db_table=mock_table, + ) + + assert total == 2 + assert len(activities) == 2 + assert all(isinstance(a, UserActivity) for a in activities) + + +@pytest.mark.asyncio +async def test_list_user_activity_with_type_filter(): + """list_user_activity filters by activity_type when provided.""" + userid = uuid4() + now = datetime.now(timezone.utc) + today_str = now.strftime("%Y-%m-%d") + + view_row = { + "userid": str(userid), + "day": today_str, + "activity_type": "view", + "activity_id": str(uuid1()), + "activity_timestamp": now.isoformat(), + } + + def mock_find(filter=None, **kwargs): + cursor = AsyncMock() + if filter and filter.get("day") == today_str and filter.get("activity_type") == "view": + cursor.to_list = AsyncMock(return_value=[view_row]) + else: + cursor.to_list = AsyncMock(return_value=[]) + return cursor + + mock_table = AsyncMock() + mock_table.find = mock_find + + activities, total = await list_user_activity( + userid=userid, + page=1, + page_size=10, + activity_type="view", + db_table=mock_table, + ) + + assert total == 1 + assert activities[0].activity_type == "view" + + +@pytest.mark.asyncio +async def test_list_user_activity_empty(): + """list_user_activity returns empty results for unknown users.""" + def mock_find(filter=None, **kwargs): + cursor = AsyncMock() + cursor.to_list = AsyncMock(return_value=[]) + return cursor + + mock_table = AsyncMock() + mock_table.find = mock_find + + activities, total = await list_user_activity( + userid=uuid4(), + page=1, + page_size=10, + db_table=mock_table, + ) + + assert total == 0 + assert activities == [] + + +@pytest.mark.asyncio +async def test_list_user_activity_pagination_page_2(): + """list_user_activity correctly returns page 2.""" + userid = uuid4() + now = datetime.now(timezone.utc) + today_str = now.strftime("%Y-%m-%d") + + rows = [ + { + "userid": str(userid), + "day": today_str, + 
"activity_type": "view", + "activity_id": str(uuid1()), + "activity_timestamp": now.isoformat(), + } + for _ in range(3) + ] + + def mock_find(filter=None, **kwargs): + cursor = AsyncMock() + if filter and filter.get("day") == today_str: + cursor.to_list = AsyncMock(return_value=rows) + else: + cursor.to_list = AsyncMock(return_value=[]) + return cursor + + mock_table = AsyncMock() + mock_table.find = mock_find + + activities, total = await list_user_activity( + userid=userid, + page=2, + page_size=2, + db_table=mock_table, + ) + + assert total == 3 + assert len(activities) == 1 # page 2 with page_size=2, only 1 left + + +@pytest.mark.asyncio +async def test_list_user_activity_error_in_partition_is_skipped(): + """list_user_activity skips a failing partition and returns remaining results.""" + userid = uuid4() + now = datetime.now(timezone.utc) + today_str = now.strftime("%Y-%m-%d") + + good_row = { + "userid": str(userid), + "day": today_str, + "activity_type": "view", + "activity_id": str(uuid1()), + "activity_timestamp": now.isoformat(), + } + + call_count = 0 + + def mock_find(filter=None, **kwargs): + nonlocal call_count + cursor = AsyncMock() + if filter and filter.get("day") == today_str: + # Today's partition returns good data + cursor.to_list = AsyncMock(return_value=[good_row]) + else: + # All other partitions raise an error + cursor.to_list = AsyncMock(side_effect=Exception("DB timeout")) + call_count += 1 + return cursor + + mock_table = AsyncMock() + mock_table.find = mock_find + + activities, total = await list_user_activity( + userid=userid, + page=1, + page_size=10, + db_table=mock_table, + ) + + # Only the good row should come through; erring partitions are skipped + assert total == 1 + assert len(activities) == 1 + assert activities[0].activity_type == "view" + + +@pytest.mark.asyncio +async def test_record_user_activity_db_failure_graceful(): + """record_user_activity does not raise on DB failure — it logs a warning instead.""" + mock_table = 
AsyncMock() + mock_table.insert_one.side_effect = Exception("DB connection lost") + + with patch("app.services.user_activity_service.logger") as mock_logger: + # Should NOT raise + await record_user_activity( + userid=uuid4(), + activity_type="view", + db_table=mock_table, + ) + + mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args[0][0] + assert "Failed to record user activity" in warning_msg + + +@pytest.mark.asyncio +async def test_record_user_activity_each_activity_type(): + """record_user_activity accepts all activity types: view, comment, rate.""" + for activity_type in ("view", "comment", "rate"): + mock_table = AsyncMock() + await record_user_activity( + userid=uuid4(), + activity_type=activity_type, + db_table=mock_table, + ) + mock_table.insert_one.assert_awaited_once() + insert_call = mock_table.insert_one.call_args + doc = insert_call.args[0] if insert_call.args else insert_call.kwargs + assert doc["activity_type"] == activity_type + + +@pytest.mark.asyncio +async def test_list_user_activity_per_day_limit_is_bounded(): + """list_user_activity passes a per-partition limit of MAX_ACTIVITY_ROWS//30, not MAX_ACTIVITY_ROWS. + + With 30 partitions the naive limit would allow up to 30 x MAX_ACTIVITY_ROWS rows + to be fetched before the post-gather trim. The bounded limit keeps the total + near MAX_ACTIVITY_ROWS. 
+ """ + captured_limits: list[int] = [] + + def mock_find(filter=None, limit=None, **kwargs): + if limit is not None: + captured_limits.append(limit) + cursor = AsyncMock() + cursor.to_list = AsyncMock(return_value=[]) + return cursor + + mock_table = AsyncMock() + mock_table.find = mock_find + + await list_user_activity(userid=uuid4(), page=1, page_size=10, db_table=mock_table) + + expected_limit = max(1, MAX_ACTIVITY_ROWS // 30) + assert len(captured_limits) == 30, "Expected find() to be called for all 30 partitions" + assert all( + lim == expected_limit for lim in captured_limits + ), f"All per-day limits should be {expected_limit}, got: {set(captured_limits)}" + # Confirm the per-day limit is strictly less than MAX_ACTIVITY_ROWS so that + # 30 partitions cannot return more than ~MAX_ACTIVITY_ROWS rows total. + assert expected_limit < MAX_ACTIVITY_ROWS + + +@pytest.mark.asyncio +async def test_list_user_activity_page_beyond_data(): + """Requesting a page beyond available data returns empty list with correct total.""" + userid = uuid4() + now = datetime.now(timezone.utc) + today_str = now.strftime("%Y-%m-%d") + + rows = [ + { + "userid": str(userid), + "day": today_str, + "activity_type": "view", + "activity_id": str(uuid1()), + "activity_timestamp": now, + } + for _ in range(3) + ] + + def mock_find(filter=None, **kwargs): + cursor = AsyncMock() + if filter and filter.get("day") == today_str: + cursor.to_list = AsyncMock(return_value=rows) + else: + cursor.to_list = AsyncMock(return_value=[]) + return cursor + + mock_table = AsyncMock() + mock_table.find = mock_find + + activities, total = await list_user_activity( + userid=userid, page=99, page_size=10, db_table=mock_table, + ) + + assert total == 3 + assert activities == [] diff --git a/tests/services/test_video_service.py b/tests/services/test_video_service.py index fc09db6..ff2fb80 100644 --- a/tests/services/test_video_service.py +++ b/tests/services/test_video_service.py @@ -227,32 +227,153 @@ async def 
test_update_video_details_no_changes(): @pytest.mark.asyncio async def test_record_video_view_success(): + """View counter uses read-modify-write ($set) — the Table API does not support $inc.""" vid = uuid4() - # Mock the playback stats table passed explicitly mock_stats_table = AsyncMock() - mock_stats_table.update_one.return_value = AsyncMock() - - # Mock the activity table returned via get_table + mock_stats_table.find_one.return_value = {"views": 5} mock_activity_table = AsyncMock() - with patch( - "app.services.video_service.get_table", new_callable=AsyncMock - ) as mock_get_table: - # First call inside record_video_view is for VIDEO_PLAYBACK_STATS_TABLE_NAME - # but we already pass mock_stats_table, so get_table will be used only once + with ( + patch( + "app.services.video_service.get_table", new_callable=AsyncMock + ) as mock_get_table, + patch( + "app.services.user_activity_service.record_user_activity", + new_callable=AsyncMock, + ) as mock_record_user_activity, + ): mock_get_table.return_value = mock_activity_table - await video_service.record_video_view(vid, mock_stats_table) + await video_service.record_video_view(vid, db_table=mock_stats_table) - # Validate stats table increment + # Must read current value first, then write the incremented value + mock_stats_table.find_one.assert_awaited_once() mock_stats_table.update_one.assert_called_once_with( - filter={"videoid": vid}, update={"$inc": {"views": 1}}, upsert=True + filter={"videoid": vid}, update={"$set": {"views": 6}}, upsert=True ) # Validate activity table log mock_activity_table.insert_one.assert_called_once() + # Validate user activity was tracked (regression guard) + mock_record_user_activity.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_record_video_view_no_existing_row(): + """When the video has no existing views row, count starts at 1.""" + vid = uuid4() + + mock_stats_table = AsyncMock() + mock_stats_table.find_one.return_value = None # no row yet + mock_activity_table = 
AsyncMock() + + with ( + patch("app.services.video_service.get_table", new_callable=AsyncMock) as mock_get_table, + patch("app.services.user_activity_service.record_user_activity", new_callable=AsyncMock), + ): + mock_get_table.return_value = mock_activity_table + await video_service.record_video_view(vid, db_table=mock_stats_table) + + mock_stats_table.update_one.assert_called_once_with( + filter={"videoid": vid}, update={"$set": {"views": 1}}, upsert=True + ) + + +@pytest.mark.asyncio +async def test_record_video_view_authenticated_user_activity(): + """Authenticated view calls record_user_activity with real user ID.""" + vid = uuid4() + viewer_id = uuid4() + mock_stats_table = AsyncMock() + mock_stats_table.find_one.return_value = None + mock_activity_table = AsyncMock() + + with ( + patch( + "app.services.video_service.get_table", new_callable=AsyncMock + ) as mock_get_table, + patch( + "app.services.user_activity_service.record_user_activity", + new_callable=AsyncMock, + ) as mock_record_user_activity, + ): + mock_get_table.return_value = mock_activity_table + + await video_service.record_video_view( + vid, viewer_user_id=viewer_id, db_table=mock_stats_table + ) + + # record_user_activity should be called with the real user ID; no activity_id + # because view events auto-generate a uuid1() to satisfy the TimeUUID column + mock_record_user_activity.assert_awaited_once_with( + userid=viewer_id, + activity_type="view", + ) + + +@pytest.mark.asyncio +async def test_record_video_view_anonymous_user_activity(): + """Anonymous view calls record_user_activity with nil UUID sentinel.""" + from app.services.user_activity_service import ANONYMOUS_USER_ID + + vid = uuid4() + mock_stats_table = AsyncMock() + mock_stats_table.find_one.return_value = None + mock_activity_table = AsyncMock() + + with ( + patch( + "app.services.video_service.get_table", new_callable=AsyncMock + ) as mock_get_table, + patch( + "app.services.user_activity_service.record_user_activity", + 
new_callable=AsyncMock, + ) as mock_record_user_activity, + ): + mock_get_table.return_value = mock_activity_table + + await video_service.record_video_view( + vid, viewer_user_id=None, db_table=mock_stats_table + ) + + # record_user_activity should be called with the anonymous sentinel UUID; no activity_id + # because view events auto-generate a uuid1() to satisfy the TimeUUID column + mock_record_user_activity.assert_awaited_once_with( + userid=ANONYMOUS_USER_ID, + activity_type="view", + ) + + +@pytest.mark.asyncio +async def test_record_video_view_user_activity_failure_does_not_break(): + """If record_user_activity raises, the video view still succeeds.""" + vid = uuid4() + mock_stats_table = AsyncMock() + mock_stats_table.find_one.return_value = None + mock_activity_table = AsyncMock() + + with ( + patch( + "app.services.video_service.get_table", new_callable=AsyncMock + ) as mock_get_table, + patch( + "app.services.user_activity_service.record_user_activity", + new_callable=AsyncMock, + ) as mock_record_user_activity, + ): + mock_get_table.return_value = mock_activity_table + mock_record_user_activity.side_effect = Exception("DB error") + + # Should NOT raise despite user_activity failure + await video_service.record_video_view( + vid, viewer_user_id=uuid4(), db_table=mock_stats_table + ) + + # video_activity insert still happened + mock_activity_table.insert_one.assert_called_once() + # ------------------------------------------------------------ # list_latest_videos (delegate to generic) – just verify query call @@ -276,7 +397,7 @@ async def test_list_latest_videos(): mock_list_with_query.assert_called_once_with( {}, 1, - 3, + 10, sort_options={"added_date": -1}, db_table=mock_db, source_table_name=video_service.VIDEOS_TABLE_NAME, @@ -293,20 +414,18 @@ async def test_list_latest_videos(): @pytest.mark.asyncio async def test_search_videos_by_keyword(): mock_db = AsyncMock() - mock_db.find.return_value = [] - mock_db.count_documents.return_value = 0 with 
patch( - "app.services.video_service.list_videos_with_query", + "app.services.video_service.search_videos_by_semantic", new_callable=AsyncMock, - ) as mock_list_with_query: - mock_list_with_query.return_value = ([], 0) + ) as mock_semantic: + mock_semantic.return_value = ([], 0) summaries, total = await video_service.search_videos_by_keyword( query="test", page=1, page_size=10, db_table=mock_db ) - mock_list_with_query.assert_called_once() + mock_semantic.assert_called_once() assert summaries == [] assert total == 0 diff --git a/tests/utils/test_db_helpers.py b/tests/utils/test_db_helpers.py new file mode 100644 index 0000000..f47eca8 --- /dev/null +++ b/tests/utils/test_db_helpers.py @@ -0,0 +1,147 @@ +"""Tests for app.utils.db_helpers.safe_count.""" + +from __future__ import annotations + +import logging + +import pytest +from unittest.mock import AsyncMock + +from astrapy.exceptions.data_api_exceptions import DataAPIResponseException # type: ignore[import] + +from app.utils.db_helpers import safe_count, suppress_astrapy_warnings + +_ASTRAPY_LOGGER = "astrapy.utils.api_commander" + + +def _make_exc(error_code: str) -> DataAPIResponseException: + """Build a DataAPIResponseException whose str() contains the given error code.""" + return DataAPIResponseException( + error_code, + command={}, + raw_response={"errors": [{"errorCode": error_code, "message": error_code}]}, + error_descriptors=[], + warning_descriptors=[], + ) + + +# --------------------------------------------------------------------------- +# Correct fallback behaviour +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_safe_count_returns_fallback_for_unsupported_table_command(): + """Returns fallback_len when the table doesn't support countDocuments.""" + db_table = AsyncMock() + db_table.count_documents.side_effect = _make_exc("UNSUPPORTED_TABLE_COMMAND") + + result = await safe_count(db_table, query_filter={}, fallback_len=7) + + 
assert result == 7 + + +@pytest.mark.asyncio +async def test_safe_count_returns_actual_count_when_supported(): + """Returns the real count when count_documents succeeds.""" + db_table = AsyncMock() + db_table.count_documents.return_value = 42 + + result = await safe_count(db_table, query_filter={"userid": "abc"}, fallback_len=3) + + assert result == 42 + + +@pytest.mark.asyncio +async def test_safe_count_propagates_unexpected_data_api_error(): + """Re-raises DataAPIResponseException for unrelated error codes.""" + db_table = AsyncMock() + db_table.count_documents.side_effect = _make_exc("SOME_OTHER_ERROR") + + with pytest.raises(DataAPIResponseException): + await safe_count(db_table, query_filter={}, fallback_len=5) + + +# --------------------------------------------------------------------------- +# Warning suppression — simulates astrapy's real behaviour (log then raise) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_safe_count_does_not_log_warning_for_unsupported_table_command(caplog): + """UNSUPPORTED_TABLE_COMMAND must NOT produce a WARNING in the logs. + + astrapy logs a WARNING from api_commander *before* raising the exception. + We simulate that by having the mock emit the warning then raise, matching + what happens in production when count_documents hits a CQL table. 
+ """ + astrapy_logger = logging.getLogger(_ASTRAPY_LOGGER) + exc = _make_exc("UNSUPPORTED_TABLE_COMMAND") + + async def _fake_count_documents(*args, **kwargs): + astrapy_logger.warning("APICommander about to raise from: UNSUPPORTED_TABLE_COMMAND") + raise exc + + db_table = AsyncMock() + db_table.count_documents = _fake_count_documents + + with caplog.at_level(logging.WARNING, logger=_ASTRAPY_LOGGER): + await safe_count(db_table, query_filter={}, fallback_len=3) + + unsupported_warnings = [ + r for r in caplog.records + if "UNSUPPORTED_TABLE_COMMAND" in r.getMessage() + and r.levelno >= logging.WARNING + ] + assert unsupported_warnings == [], ( + "safe_count should suppress astrapy's UNSUPPORTED_TABLE_COMMAND warning" + ) + + +# --------------------------------------------------------------------------- +# suppress_astrapy_warnings context manager +# --------------------------------------------------------------------------- + + +def test_suppress_astrapy_warnings_suppresses_matching_codes(caplog): + """Warnings matching any of the specified codes are suppressed.""" + astrapy_logger = logging.getLogger(_ASTRAPY_LOGGER) + + with caplog.at_level(logging.WARNING, logger=_ASTRAPY_LOGGER): + with suppress_astrapy_warnings("ZERO_FILTER_OPERATIONS", "IN_MEMORY_SORTING"): + astrapy_logger.warning("ZERO_FILTER_OPERATIONS on table videos") + astrapy_logger.warning("IN_MEMORY_SORTING due to non-partition key") + + matching = [ + r for r in caplog.records + if ("ZERO_FILTER_OPERATIONS" in r.getMessage() + or "IN_MEMORY_SORTING" in r.getMessage()) + and r.levelno >= logging.WARNING + ] + assert matching == [], "suppress_astrapy_warnings should suppress matching warnings" + + +def test_suppress_astrapy_warnings_passes_unrelated_warnings(caplog): + """Warnings that don't match any specified code still appear.""" + astrapy_logger = logging.getLogger(_ASTRAPY_LOGGER) + + with caplog.at_level(logging.WARNING, logger=_ASTRAPY_LOGGER): + with 
suppress_astrapy_warnings("ZERO_FILTER_OPERATIONS"): + astrapy_logger.warning("SOMETHING_ELSE happened") + + unrelated = [ + r for r in caplog.records + if "SOMETHING_ELSE" in r.getMessage() + ] + assert len(unrelated) == 1, "Unrelated warnings must not be suppressed" + + +def test_suppress_astrapy_warnings_removes_filter_after_exit(): + """The filter is removed from the logger when the context exits.""" + astrapy_logger = logging.getLogger(_ASTRAPY_LOGGER) + baseline = len(astrapy_logger.filters) + + with suppress_astrapy_warnings("ZERO_FILTER_OPERATIONS"): + assert len(astrapy_logger.filters) == baseline + 1 + + assert len(astrapy_logger.filters) == baseline