diff --git a/open_diloco/train_diloco_torch.py b/open_diloco/train_diloco_torch.py index 6fa66d0..cb68c0f 100644 --- a/open_diloco/train_diloco_torch.py +++ b/open_diloco/train_diloco_torch.py @@ -230,6 +230,10 @@ def tokenize_function(data): if eval_steps is not None: eval_dataset = tokenized_datasets["validation"] + # Limit streaming validation to 1000 samples so each evaluation pass is bounded + # (a streaming IterableDataset has no __len__; the loader runs until the stream is exhausted) + if hasattr(eval_dataset, "take"): + eval_dataset = eval_dataset.take(1000) eval_dataloader = DataLoader( eval_dataset, collate_fn=data_collator,