diff --git a/contentcuration/kolibri_public/tests/test_channelmetadata_viewset.py b/contentcuration/kolibri_public/tests/test_channelmetadata_viewset.py index 616752d8e6..220a22a56c 100644 --- a/contentcuration/kolibri_public/tests/test_channelmetadata_viewset.py +++ b/contentcuration/kolibri_public/tests/test_channelmetadata_viewset.py @@ -4,7 +4,11 @@ from kolibri_public.tests.utils.mixer import KolibriPublicMixer from le_utils.constants.labels.subjects import SUBJECTSLIST +from contentcuration.models import Channel +from contentcuration.models import ChannelVersion +from contentcuration.models import ContentNode from contentcuration.models import Country +from contentcuration.models import SecretToken from contentcuration.tests import testdata from contentcuration.tests.base import StudioAPITestCase from contentcuration.tests.helpers import reverse_with_query @@ -159,3 +163,193 @@ def test_filter_by_countries(self): self.assertCountEqual(response1["countries"], ["C1", "C3"]) self.assertCountEqual(response2["countries"], ["C1", "C2", "C3"]) + + +class ChannelMetadataTokenFilterTestCase(StudioAPITestCase): + """ + Test cases for token-based filtering in ChannelMetadataViewSet. + """ + + def setUp(self): + super().setUp() + self.user = testdata.user("any@user.com") + self.client.force_authenticate(self.user) + self.categories = [ + SUBJECTSLIST[0], + SUBJECTSLIST[1], + ] + + def _create_channel_with_main_tree(self, mixer): + """ + Helper method to create a Channel with a published main_tree. + """ + root_node = ContentNode.objects.create(published=True) + channel = Channel.objects.create( + actor_id=self.user.id, + deleted=False, + public=False, + main_tree=root_node, + ) + public_root_node = mixer.blend("kolibri_public.ContentNode") + return channel, public_root_node + + def test_filter_by_channel_token(self): + """ + Test that filtering by a channel's secret_token returns the correct channel. 
+ """ + mixer = KolibriPublicMixer() + + channel, public_root_node = self._create_channel_with_main_tree(mixer) + token = SecretToken.objects.create(token="testchanneltokenabc", is_primary=True) + channel.secret_tokens.add(token) + + metadata = mixer.blend( + ChannelMetadata, id=channel.id, root=public_root_node, public=False + ) + + response = self.client.get( + reverse_with_query( + "publicchannel-list", + query={"token": "testchanneltokenabc"}, + ), + ) + + self.assertEqual(response.status_code, 200, response.content) + self.assertEqual(len(response.data), 1) + self.assertEqual(UUID(response.data[0]["id"]), UUID(metadata.id)) + self.assertEqual(response.data[0]["countries"], []) + + def test_filter_by_channel_version_token(self): + """ + Test that filtering by a ChannelVersion's secret_token returns the correct channel + with version-specific data. + """ + mixer = KolibriPublicMixer() + + channel, public_root_node = self._create_channel_with_main_tree(mixer) + channel.version = 5 + channel.save() + + token = SecretToken.objects.create( + token="testversiontokenxyz", is_primary=False + ) + ChannelVersion.objects.create( + channel=channel, + version=3, + secret_token=token, + size=123456789, + resource_count=100, + included_languages=["en", "es"], + included_categories=self.categories, + ) + + metadata = mixer.blend( + ChannelMetadata, + id=channel.id, + root=public_root_node, + published_size=999999999, + total_resource_count=200, + public=False, + ) + + response = self.client.get( + reverse_with_query( + "publicchannel-list", + query={"token": "testversiontokenxyz"}, + ), + ) + + self.assertEqual(response.status_code, 200, response.content) + self.assertEqual(len(response.data), 1) + self.assertEqual(UUID(response.data[0]["id"]), UUID(metadata.id)) + self.assertEqual(response.data[0]["published_size"], 123456789) + self.assertEqual(response.data[0]["total_resource_count"], 100) + self.assertCountEqual(response.data[0]["included_languages"], ["en", "es"]) + 
self.assertCountEqual(response.data[0]["categories"], self.categories) + self.assertEqual(response.data[0]["countries"], []) + + def test_token_filter_disabled_when_token_not_provided(self): + """ + Test that regular filters still work when no token is provided. + """ + mixer = KolibriPublicMixer() + + metadata1 = mixer.blend(ChannelMetadata, public=True) + mixer.blend(ChannelMetadata, public=False) + + response = self.client.get( + reverse_with_query( + "publicchannel-list", + query={"public": "true"}, + ), + ) + + self.assertEqual(response.status_code, 200, response.content) + self.assertEqual(len(response.data), 1) + self.assertEqual(str(UUID(response.data[0]["id"])), str(metadata1.id)) + + def test_token_filter_disables_other_filters(self): + """ + Test that when a token is provided, other query parameters are ignored. + """ + mixer = KolibriPublicMixer() + + channel, public_root_node = self._create_channel_with_main_tree(mixer) + token = SecretToken.objects.create( + token="testignorefilterstoken", is_primary=True + ) + channel.secret_tokens.add(token) + + metadata = mixer.blend( + ChannelMetadata, id=channel.id, root=public_root_node, public=False + ) + + response = self.client.get( + reverse_with_query( + "publicchannel-list", + query={"token": "testignorefilterstoken", "public": "true"}, + ), + ) + + self.assertEqual(response.status_code, 200, response.content) + self.assertEqual(len(response.data), 1) + self.assertEqual(UUID(response.data[0]["id"]), UUID(metadata.id)) + + def test_token_normalization_removes_dashes(self): + """ + Test that tokens are normalized by removing dashes. 
+ """ + mixer = KolibriPublicMixer() + + channel, public_root_node = self._create_channel_with_main_tree(mixer) + token = SecretToken.objects.create(token="abcd1234efgh5678", is_primary=True) + channel.secret_tokens.add(token) + + metadata = mixer.blend( + ChannelMetadata, id=channel.id, root=public_root_node, public=False + ) + + response = self.client.get( + reverse_with_query( + "publicchannel-list", + query={"token": "abcd-1234-efgh-5678"}, + ), + ) + + self.assertEqual(response.status_code, 200, response.content) + self.assertEqual(len(response.data), 1) + self.assertEqual(UUID(response.data[0]["id"]), UUID(metadata.id)) + + def test_nonexistent_token_returns_empty_list(self): + """ + Test that a non-existent token returns an empty list. + """ + response = self.client.get( + reverse_with_query( + "publicchannel-list", + query={"token": "nonexistent-token-12345"}, + ), + ) + + self.assertEqual(response.status_code, 200, response.content) + self.assertEqual(len(response.data), 0) diff --git a/contentcuration/kolibri_public/views.py b/contentcuration/kolibri_public/views.py index 867a677798..ba91f039ba 100644 --- a/contentcuration/kolibri_public/views.py +++ b/contentcuration/kolibri_public/views.py @@ -42,6 +42,8 @@ from contentcuration.middleware.locale import locale_exempt from contentcuration.middleware.session import session_exempt +from contentcuration.models import Channel +from contentcuration.models import ChannelVersion from contentcuration.models import Country from contentcuration.models import generate_storage_url from contentcuration.utils.pagination import ValuesViewsetCursorPagination @@ -176,14 +178,167 @@ class ChannelMetadataViewSet(ReadOnlyValuesViewset): "lang_name": "root__lang__native_name", } + def get_queryset_from_token(self, token): + """ + Retrieve channel data based on a token. + + This method checks both Channel.secret_tokens and ChannelVersion.secret_token + to find matching channels. 
It returns a tuple of a Channel queryset and an optional + per-channel version data dict from the contentcuration models.
_consolidate_token_items(self, items, version_data): + for item in items: + channel_id = str(item["id"]) + data = version_data.get(channel_id) + if data: + if data["published_size"] is not None: + item["published_size"] = data["published_size"] + if data["total_resource_count"] is not None: + item["total_resource_count"] = data["total_resource_count"] + if data["last_updated"] is not None: + item["last_updated"] = data["last_updated"] + if data["categories"]: + item["categories"] = data["categories"] + item["included_languages"] = data["included_languages"] or [] + else: + item["included_languages"] = [] + item["last_published"] = item["last_updated"] + item["countries"] = [] + return items + + def serialize(self, queryset): + + if queryset.model == Channel: + items = self._serialize_token_queryset(queryset) + version_data = getattr(self, "_version_data", None) + if version_data: + items = self._consolidate_token_items(items, version_data) + else: + for item in items: + item["included_languages"] = [] + item["last_published"] = item["last_updated"] + item["countries"] = [] + return items + + return super().serialize(queryset) + def get_queryset(self): + """ + Get the base queryset for the viewset. + + If a 'token' query parameter is present, this will return channels + matching that token from the contentcuration models. Otherwise, returns all channels. + """ + token = self.request.query_params.get("token") + if token: + self._token_queryset, self._version_data = self.get_queryset_from_token( + token + ) + return self._token_queryset + self._version_data = None return models.ChannelMetadata.objects.all() + def filter_queryset(self, queryset): + """ + Filter the queryset. + + If a 'token' query parameter is present, all other filters are disabled + and the queryset is returned unfiltered. Otherwise, applies the normal + filter behavior. 
+ """ + token = self.request.query_params.get("token") + if token: + return queryset + return super().filter_queryset(queryset) + def consolidate(self, items, queryset): - # Only keep a single item for every channel ID, to get rid of possible - # duplicates caused by filtering + """ + Consolidate items for serialization. + + When using token-based access, items are already consolidated in serialize(), + so we just return them as-is for Channel querysets. + """ + # For Channel querysets from token-based access, items are already consolidated + if queryset.model == Channel: + return items + + # For ChannelMetadata querysets, use the default consolidation items = list(OrderedDict((item["id"], item) for item in items).values()) + version_data = getattr(self, "_version_data", None) + if version_data: + return self._consolidate_token_items(items, version_data) + return self._consolidate_regular_items(items, queryset) + def _consolidate_regular_items(self, items, queryset): included_languages = {} for ( channel_id, @@ -196,9 +351,6 @@ def consolidate(self, items, queryset): if channel_id not in included_languages: included_languages[channel_id] = [] included_languages[channel_id].append(language_id) - for item in items: - item["included_languages"] = included_languages.get(item["id"], []) - item["last_published"] = item["last_updated"] countries = {} for (channel_id, country_code) in Country.objects.filter( @@ -209,8 +361,9 @@ def consolidate(self, items, queryset): countries[channel_id].append(country_code) for item in items: + item["included_languages"] = included_languages.get(item["id"], []) + item["last_published"] = item["last_updated"] item["countries"] = countries.get(item["id"], []) - return items