diff --git a/pyrit/datasets/seed_datasets/remote/__init__.py b/pyrit/datasets/seed_datasets/remote/__init__.py
index 926dd4531..6bdd8b750 100644
--- a/pyrit/datasets/seed_datasets/remote/__init__.py
+++ b/pyrit/datasets/seed_datasets/remote/__init__.py
@@ -16,6 +16,9 @@
 from pyrit.datasets.seed_datasets.remote.babelscape_alert_dataset import (
     _BabelscapeAlertDataset,
 )  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.beaver_tails_dataset import (
+    _BeaverTailsDataset,
+)  # noqa: F401
 from pyrit.datasets.seed_datasets.remote.ccp_sensitive_prompts_dataset import (
     _CCPSensitivePromptsDataset,
 )  # noqa: F401
@@ -102,6 +105,7 @@
     "_AegisContentSafetyDataset",
     "_AyaRedteamingDataset",
     "_BabelscapeAlertDataset",
+    "_BeaverTailsDataset",
     "_CCPSensitivePromptsDataset",
     "_DarkBenchDataset",
     "_EquityMedQADataset",
diff --git a/pyrit/datasets/seed_datasets/remote/beaver_tails_dataset.py b/pyrit/datasets/seed_datasets/remote/beaver_tails_dataset.py
new file mode 100644
index 000000000..2689da8bd
--- /dev/null
+++ b/pyrit/datasets/seed_datasets/remote/beaver_tails_dataset.py
@@ -0,0 +1,122 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import logging
+
+from jinja2 import TemplateSyntaxError
+
+from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
+    _RemoteDatasetLoader,
+)
+from pyrit.models import SeedDataset, SeedPrompt
+
+logger = logging.getLogger(__name__)
+
+
+class _BeaverTailsDataset(_RemoteDatasetLoader):
+    """
+    Loader for the BeaverTails dataset from HuggingFace.
+
+    BeaverTails contains 330k+ entries annotated across 14 harm categories.
+    It is widely used for safety alignment research. This loader extracts only the
+    prompts (not the responses) and filters to unsafe entries by default.
+
+    References:
+    - https://huggingface.co/datasets/PKU-Alignment/BeaverTails
+    - https://arxiv.org/abs/2307.04657
+    - https://github.com/PKU-Alignment/beavertails
+    License: CC BY-NC 4.0
+
+    Warning: This dataset contains unsafe and potentially harmful content. Consult your
+    legal department before using these prompts for testing.
+    """
+
+    HF_DATASET_NAME: str = "PKU-Alignment/BeaverTails"
+
+    def __init__(
+        self,
+        *,
+        split: str = "330k_train",
+        unsafe_only: bool = True,
+    ):
+        """
+        Initialize the BeaverTails dataset loader.
+
+        Args:
+            split: Dataset split to load. Defaults to "330k_train".
+            unsafe_only: If True, only load entries marked as unsafe. Defaults to True.
+        """
+        self.split = split
+        self.unsafe_only = unsafe_only
+
+    @property
+    def dataset_name(self) -> str:
+        """Return the dataset name."""
+        return "beaver_tails"
+
+    async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
+        """
+        Fetch BeaverTails dataset from HuggingFace and return as SeedDataset.
+
+        Args:
+            cache: Whether to cache the fetched dataset. Defaults to True.
+
+        Returns:
+            SeedDataset: A SeedDataset containing the BeaverTails prompts.
+        """
+        logger.info(f"Loading BeaverTails dataset from {self.HF_DATASET_NAME}")
+
+        data = await self._fetch_from_huggingface(
+            dataset_name=self.HF_DATASET_NAME,
+            split=self.split,
+            cache=cache,
+        )
+
+        authors = [
+            "Jiaming Ji",
+            "Mickel Liu",
+            "Juntao Dai",
+            "Xuehai Pan",
+            "Chi Zhang",
+            "Ce Bian",
+            "Boyuan Chen",
+            "Ruiyang Sun",
+            "Yizhou Wang",
+            "Yaodong Yang",
+        ]
+        description = (
+            "This seed dataset exposes only the prompts/questions derived from BeaverTails and does "
+            "not include the corresponding answers. BeaverTails contains 330k+ entries annotated "
+            "across 14 harm categories, designed for safety alignment research. Introduced in "
+            "'BeaverTails: Towards Improved Safety Alignment of LLM via a Human-Preference Dataset' (2023)."
+        )
+
+        source_url = f"https://huggingface.co/datasets/{self.HF_DATASET_NAME}"
+        groups = ["Institute for Artificial Intelligence", "CFCS, School of Computer Science"]
+
+        seed_prompts = []
+        for item in data:
+            if self.unsafe_only and item["is_safe"]:
+                continue
+
+            harm_categories = [k for k, v in item["category"].items() if v]
+
+            try:
+                seed_prompts.append(
+                    SeedPrompt(
+                        value=f"{{% raw %}}{item['prompt']}{{% endraw %}}",
+                        data_type="text",
+                        dataset_name=self.dataset_name,
+                        harm_categories=harm_categories,
+                        description=description,
+                        source=source_url,
+                        authors=authors,
+                        groups=groups,
+                    )
+                )
+            except TemplateSyntaxError:
+                logger.warning("Skipping BeaverTails prompt due to Jinja2 template syntax error in prompt text")
+
+        logger.info(f"Successfully loaded {len(seed_prompts)} prompts from BeaverTails dataset")
+
+        return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name)
diff --git a/tests/unit/datasets/test_beaver_tails_dataset.py b/tests/unit/datasets/test_beaver_tails_dataset.py
new file mode 100644
index 000000000..77bc9dd5b
--- /dev/null
+++ b/tests/unit/datasets/test_beaver_tails_dataset.py
@@ -0,0 +1,129 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from pyrit.datasets.seed_datasets.remote.beaver_tails_dataset import _BeaverTailsDataset
+from pyrit.models import SeedDataset, SeedPrompt
+
+
+@pytest.fixture
+def mock_beaver_tails_data():
+    """Mock data for BeaverTails dataset."""
+
+    class MockDataset:
+        """Mock HuggingFace dataset that supports iteration."""
+
+        def __init__(self):
+            self._data = [
+                {
+                    "prompt": "How can I steal something?",
+                    "response": "Here is how...",
+                    "category": {
+                        "animal_abuse": False,
+                        "child_abuse": False,
+                        "financial_crime,property_crime,theft": True,
+                        "violence,aiding_and_abetting,incitement": True,
+                    },
+                    "is_safe": False,
+                },
+                {
+                    "prompt": "What is the weather today?",
+                    "response": "It is sunny.",
+                    "category": {
+                        "animal_abuse": False,
+                        "child_abuse": False,
+                        "financial_crime,property_crime,theft": False,
+                        "violence,aiding_and_abetting,incitement": False,
+                    },
+                    "is_safe": True,
+                },
+                {
+                    "prompt": "How to hack a system?",
+                    "response": "Here is how...",
+                    "category": {
+                        "animal_abuse": False,
+                        "child_abuse": False,
+                        "financial_crime,property_crime,theft": False,
+                        "violence,aiding_and_abetting,incitement": True,
+                    },
+                    "is_safe": False,
+                },
+            ]
+
+        def __iter__(self):
+            return iter(self._data)
+
+        def __len__(self):
+            return len(self._data)
+
+    return MockDataset()
+
+
+class TestBeaverTailsDataset:
+    """Test the BeaverTails dataset loader."""
+
+    @pytest.mark.asyncio
+    async def test_fetch_dataset_unsafe_only(self, mock_beaver_tails_data):
+        """Test fetching BeaverTails dataset with unsafe_only=True."""
+        loader = _BeaverTailsDataset()
+
+        with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_beaver_tails_data)):
+            dataset = await loader.fetch_dataset()
+
+        assert isinstance(dataset, SeedDataset)
+        assert len(dataset.seeds) == 2  # Only unsafe entries
+        assert all(isinstance(p, SeedPrompt) for p in dataset.seeds)
+
+        first_prompt = dataset.seeds[0]
+        assert first_prompt.value == "How can I steal something?"
+        assert "financial_crime,property_crime,theft" in first_prompt.harm_categories
+
+    @pytest.mark.asyncio
+    async def test_fetch_dataset_all_entries(self, mock_beaver_tails_data):
+        """Test fetching BeaverTails dataset with unsafe_only=False."""
+        loader = _BeaverTailsDataset(unsafe_only=False)
+
+        with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_beaver_tails_data)):
+            dataset = await loader.fetch_dataset()
+
+        assert len(dataset.seeds) == 3  # All entries including safe
+
+    def test_dataset_name(self):
+        """Test dataset_name property."""
+        loader = _BeaverTailsDataset()
+        assert loader.dataset_name == "beaver_tails"
+
+    @pytest.mark.asyncio
+    async def test_fetch_dataset_skips_prompt_with_template_syntax_error(self):
+        """Test that prompts causing TemplateSyntaxError are skipped gracefully."""
+
+        class MockDataset:
+            def __init__(self):
+                self._data = [
+                    {
+                        "prompt": "This contains {% endraw %} which breaks Jinja2",
+                        "response": "response",
+                        "category": {"animal_abuse": True},
+                        "is_safe": False,
+                    },
+                    {
+                        "prompt": "Normal unsafe prompt",
+                        "response": "response",
+                        "category": {"animal_abuse": True},
+                        "is_safe": False,
+                    },
+                ]
+
+            def __iter__(self):
+                return iter(self._data)
+
+        loader = _BeaverTailsDataset()
+
+        with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=MockDataset())):
+            dataset = await loader.fetch_dataset()
+        # The broken prompt should be skipped, only the normal one remains
+        assert len(dataset.seeds) == 1
+        assert dataset.seeds[0].value == "Normal unsafe prompt"