
Commit e1159de

Update README for DISCO example (#57)
- Update installation commands for DISCO example
- Simplify log messages for DISCO example
- Update example logs in README for DISCO example
1 parent 45098b0 commit e1159de

2 files changed: 33 additions & 15 deletions

examples/mmlu_benchmark/README.md

Lines changed: 30 additions & 9 deletions
@@ -1,19 +1,27 @@
 # MMLU Benchmark Example
 
-Evaluate language models on [MMLU (Massive Multitask Language Understanding)](https://arxiv.org/abs/2009.03300) with optional efficient evaluation via [DISCO](https://arxiv.org/abs/2510.07959).
+Evaluate language models on [MMLU (Massive Multitask Language Understanding)](https://arxiv.org/abs/2009.03300) with optional efficient evaluation via [DISCO (Diversifying Sample Condensation)](https://arxiv.org/abs/2510.07959).
 
 ## Installation
 
-For basic MMLU evaluation:
+Install [uv package manager](https://docs.astral.sh/uv/) as described [here](https://docs.astral.sh/uv/getting-started/installation/).
+
+Create Python environment:
+
+```bash
+uv venv --python 3.11
+```
+
+Install dependencies for basic MMLU evaluation:
 
 ```bash
-uv pip install .[mmlu]
+uv sync --extra mmlu
 ```
 
-For DISCO prediction (includes DISCO dependencies):
+Install dependencies for MMLU evaluation with DISCO:
 
 ```bash
-uv pip install .[disco]
+uv sync --extra disco
 ```
 
 ## Run without DISCO (full evaluation)
@@ -31,9 +39,8 @@ Full evaluation results look like:
 Results Summary (Evaluated Tasks)
 ================================================================================
 Total tasks: 14042
-Correct: 8291
-Accuracy (on anchor points): 0.5904
-Accuracy norm (on anchor points): 0.5904
+Correct: 8292
+Accuracy: 0.5905
 ```
 
 ## Run with DISCO (predicted full-benchmark score)
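As a quick consistency check on the updated numbers in the hunk above, the reported accuracy is simply the correct count divided by the total task count. The snippet below is an illustration against the README's example figures, not part of the commit:

```python
# Sanity-check the README's updated summary: accuracy = correct / total,
# rounded to four decimal places by the ":.4f" format spec.
total_tasks = 14042
correct = 8292

accuracy = correct / total_tasks
print(f"Accuracy: {accuracy:.4f}")  # prints "Accuracy: 0.5905"
```

This matches the corrected `Correct: 8292` / `Accuracy: 0.5905` pair in the new log output.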
@@ -47,10 +54,24 @@ uv run python examples/mmlu_benchmark/mmlu_benchmark.py --model_id alignment-han
 Predicted score output:
 
 ```
+================================================================================
+Results Summary (Evaluated Tasks)
+================================================================================
+Total tasks: 100
+Correct: 36
+Accuracy: 0.3600
+
+================================================================================
+DISCO Prediction
+================================================================================
+Computing embeddings and predicting full benchmark accuracy...
+Fetching 9 files: 100%|██████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 19171.53it/s]
+Using: DISCO predictor from Hugging Face (arubique/DISCO-MMLU)
+
 ----------------------------------------
 DISCO Predicted Full Benchmark Accuracy:
 ----------------------------------------
-Model 0: 0.606739
+Model 0 (alignment-handbook/zephyr-7b-sft-full): 0.602309
 ```
 
 ## Arguments

examples/mmlu_benchmark/mmlu_benchmark.py

Lines changed: 3 additions & 6 deletions
@@ -258,8 +258,6 @@ def extract_eval_entries(res):
         print(f"Saved predictions tensor to {output_path}")
         print(f"  Shape: {predictions.shape}")
         print(f"  Dtype: {predictions.dtype}")
-    else:
-        print(f"Built predictions tensor with shape: {predictions.shape}")
 
     return predictions
 
@@ -723,8 +721,7 @@ def main():
     print("=" * 80)
     print(f"Total tasks: {metrics['total_tasks']}")
     print(f"Correct: {metrics['correct_count']}")
-    print(f"Accuracy (on anchor points): {metrics['acc']:.4f}")
-    print(f"Accuracy norm (on anchor points): {metrics['acc_norm']:.4f}")
+    print(f"Accuracy: {metrics['acc']:.4f}")
 
     # Build predictions tensor for DISCO
     predictions = None
@@ -754,8 +751,8 @@ def main():
         print("\n" + "-" * 40)
         print("DISCO Predicted Full Benchmark Accuracy:")
         print("-" * 40)
-        for model_idx, acc in disco_results["predicted_accuracies"].items():
-            print(f"  Model {model_idx}: {acc:.6f}")
+        for model_idx, acc in sorted(disco_results["predicted_accuracies"].items()):
+            print(f"  Model {model_idx} ({args.model_id}): {acc:.6f}")
 
         # Save summary
         summary_data = {
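The hunk above wraps the dict iteration in `sorted(...)`, so multi-model output prints in deterministic model-index order rather than insertion order. A minimal sketch of that behavior, with hypothetical accuracy values (only `0.602309` and the model id come from the example run):

```python
# Hypothetical predicted accuracies keyed by model index, deliberately
# built out of order to show why sorted() matters.
predicted_accuracies = {1: 0.551234, 0: 0.602309}
model_id = "alignment-handbook/zephyr-7b-sft-full"  # from the example run

lines = []
for model_idx, acc in sorted(predicted_accuracies.items()):
    # sorted() on dict items orders pairs by key, so Model 0 prints first
    lines.append(f"  Model {model_idx} ({model_id}): {acc:.6f}")

print("\n".join(lines))
# First line: "  Model 0 (alignment-handbook/zephyr-7b-sft-full): 0.602309"
```

Note that the diffed code formats every row with the single `args.model_id`, which matches the one-model example run shown in the README.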
