diff --git a/doc/code/datasets/1_loading_datasets.ipynb b/doc/code/datasets/1_loading_datasets.ipynb index 2099858448..7b7f1c801b 100644 --- a/doc/code/datasets/1_loading_datasets.ipynb +++ b/doc/code/datasets/1_loading_datasets.ipynb @@ -103,7 +103,7 @@ "output_type": "stream", "text": [ "\r", - "Loading datasets - this can take a few minutes: 0%| | 0/46 [00:00 None: + """ + Initialize the OR-Bench dataset loader. + + Args: + split: Dataset split to load. Defaults to "train". + """ + self.split = split + + async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + """ + Fetch OR-Bench dataset from HuggingFace and return as SeedDataset. + + Args: + cache: Whether to cache the fetched dataset. Defaults to True. + + Returns: + SeedDataset: A SeedDataset containing the OR-Bench prompts. + """ + logger.info(f"Loading OR-Bench dataset from {self.HF_DATASET_NAME} (config={self.CONFIG})") + + data = await self._fetch_from_huggingface( + dataset_name=self.HF_DATASET_NAME, + config=self.CONFIG, + split=self.split, + cache=cache, + ) + + authors = [ + "Justin Cui", + "Wei-Lin Chiang", + "Ion Stoica", + "Cho-Jui Hsieh", + ] + source_url = f"https://huggingface.co/datasets/{self.HF_DATASET_NAME}" + groups = ["UCLA", "UC Berkeley"] + + seed_prompts = [ + SeedPrompt( + value=f"{{% raw %}}{item['prompt']}{{% endraw %}}", + data_type="text", + dataset_name=self.dataset_name, + harm_categories=[item["category"]] if item.get("category") else [], + description=self.DESCRIPTION, + source=source_url, + authors=authors, + groups=groups, + ) + for item in data + ] + + logger.info(f"Successfully loaded {len(seed_prompts)} prompts from OR-Bench dataset") + + return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name) + + +class _ORBench80KDataset(_ORBenchBaseDataset): + """ + Loader for the OR-Bench 80K dataset. + + Contains ~80k over-refusal prompts categorized into 10 common rejection categories. 
+ This is the main comprehensive benchmark for evaluating LLM over-refusal behavior. + """ + + CONFIG: str = "or-bench-80k" + DESCRIPTION: str = ( + "OR-Bench 80K contains ~80k over-refusal prompts categorized into 10 rejection " + "categories. This is the main comprehensive benchmark for evaluating LLM over-refusal." + ) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "or_bench_80k" + + +class _ORBenchHardDataset(_ORBenchBaseDataset): + """ + Loader for the OR-Bench Hard-1K dataset. + + Contains ~1k challenging safe prompts that commonly trigger over-refusal in LLMs. + These are prompts that models should be able to answer without refusing. + """ + + CONFIG: str = "or-bench-hard-1k" + DESCRIPTION: str = ( + "OR-Bench Hard-1K contains ~1k challenging safe prompts that commonly trigger " + "over-refusal in language models. These prompts should be answerable without refusing." + ) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "or_bench_hard" + + +class _ORBenchToxicDataset(_ORBenchBaseDataset): + """ + Loader for the OR-Bench Toxic dataset. + + Contains toxic prompts that language models should correctly refuse. + Used as a contrast set to evaluate whether models can distinguish + genuinely harmful prompts from safe ones. + """ + + CONFIG: str = "or-bench-toxic" + DESCRIPTION: str = ( + "OR-Bench Toxic contains toxic prompts that language models should correctly refuse. " + "Used as a contrast set to evaluate refusal calibration." + ) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "or_bench_toxic" diff --git a/tests/unit/datasets/test_or_bench_dataset.py b/tests/unit/datasets/test_or_bench_dataset.py new file mode 100644 index 0000000000..4b6202fb7c --- /dev/null +++ b/tests/unit/datasets/test_or_bench_dataset.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
"""Unit tests for the OR-Bench dataset loaders (80K, Hard-1K and Toxic configs)."""

from unittest.mock import AsyncMock, patch

import pytest

from pyrit.datasets.seed_datasets.remote.or_bench_dataset import (
    _ORBench80KDataset,
    _ORBenchHardDataset,
    _ORBenchToxicDataset,
)
from pyrit.models import SeedDataset, SeedPrompt


@pytest.fixture
def mock_or_bench_data():
    """Return two rows shaped like the records the loaders read (``prompt`` and ``category`` keys)."""
    return [
        {
            "prompt": "What are some techniques to influence decisions?",
            "category": "deception",
        },
        {
            "prompt": "How can I craft a controversial headline?",
            "category": "deception",
        },
    ]


async def _fetch_with_mocked_hub(loader, rows):
    """
    Run ``loader.fetch_dataset()`` with ``_fetch_from_huggingface`` replaced by an ``AsyncMock``.

    Args:
        loader: An OR-Bench loader instance under test.
        rows: The records the mocked HuggingFace fetch should return.

    Returns:
        tuple: ``(dataset, mock_fetch)`` - the value returned by ``fetch_dataset`` and the mock
        that stood in for ``_fetch_from_huggingface``, so callers can inspect its call arguments.
    """
    with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=rows)) as mock_fetch:
        dataset = await loader.fetch_dataset()
    return dataset, mock_fetch


def _assert_loaded_correctly(dataset, mock_fetch, *, expected_config):
    """
    Apply the assertions every OR-Bench loader must satisfy for the two-row fixture.

    All three loaders share one ``fetch_dataset`` implementation and differ only in the
    ``config`` they request, so the same checks apply to each of them.

    Args:
        dataset: The value returned by ``fetch_dataset``.
        mock_fetch: The ``AsyncMock`` that replaced ``_fetch_from_huggingface``.
        expected_config: The HuggingFace config name the loader is expected to request.
    """
    # The loader must hit the hub exactly once, asking for its own config.
    mock_fetch.assert_awaited_once()
    assert mock_fetch.call_args.kwargs["config"] == expected_config

    # One SeedPrompt per source row, wrapped in a SeedDataset.
    assert isinstance(dataset, SeedDataset)
    assert len(dataset.seeds) == 2
    assert all(isinstance(seed, SeedPrompt) for seed in dataset.seeds)

    # The prompt text and its category must survive the conversion unchanged.
    first_prompt = dataset.seeds[0]
    assert first_prompt.value == "What are some techniques to influence decisions?"
    assert first_prompt.harm_categories == ["deception"]


class TestORBench80KDataset:
    """Test the OR-Bench 80K dataset loader."""

    @pytest.mark.asyncio
    async def test_fetch_dataset(self, mock_or_bench_data):
        """Test fetching OR-Bench 80K dataset."""
        dataset, mock_fetch = await _fetch_with_mocked_hub(_ORBench80KDataset(), mock_or_bench_data)
        _assert_loaded_correctly(dataset, mock_fetch, expected_config="or-bench-80k")

    def test_dataset_name(self):
        """Test dataset_name property."""
        assert _ORBench80KDataset().dataset_name == "or_bench_80k"


class TestORBenchHardDataset:
    """Test the OR-Bench Hard-1K dataset loader."""

    @pytest.mark.asyncio
    async def test_fetch_dataset(self, mock_or_bench_data):
        """Test fetching OR-Bench Hard dataset."""
        dataset, mock_fetch = await _fetch_with_mocked_hub(_ORBenchHardDataset(), mock_or_bench_data)
        _assert_loaded_correctly(dataset, mock_fetch, expected_config="or-bench-hard-1k")

    def test_dataset_name(self):
        """Test dataset_name property."""
        assert _ORBenchHardDataset().dataset_name == "or_bench_hard"


class TestORBenchToxicDataset:
    """Test the OR-Bench Toxic dataset loader."""

    @pytest.mark.asyncio
    async def test_fetch_dataset(self, mock_or_bench_data):
        """Test fetching OR-Bench Toxic dataset."""
        dataset, mock_fetch = await _fetch_with_mocked_hub(_ORBenchToxicDataset(), mock_or_bench_data)
        _assert_loaded_correctly(dataset, mock_fetch, expected_config="or-bench-toxic")

    def test_dataset_name(self):
        """Test dataset_name property."""
        assert _ORBenchToxicDataset().dataset_name == "or_bench_toxic"