Adds validation loss + exposes it in the API #685
RobotSail wants to merge 5 commits into instructlab:main
Conversation
📝 Walkthrough

Adds configurable validation: new TrainingArgs fields (`validation_split`, `validation_frequency`), dataset splitting and validation DataLoader support, a distributed validation loss routine, periodic validation checks integrated into the training loop, and CLI flags to control validation behavior.
Sequence Diagram

```mermaid
sequenceDiagram
    participant TL as Training Loop
    participant VS as Validation Scheduler
    participant CV as compute_validation_loss()
    participant DL as Validation DataLoader
    participant ACC as Accelerator / Distributed
    participant LOG as Logger (rank 0)
    loop Training steps
        TL->>VS: step tick, check validation_frequency
        alt execute validation
            VS->>CV: invoke(model, val_data_loader, device)
            CV->>DL: iterate batches (no-grad)
            DL-->>CV: batches
            CV->>ACC: aggregate per-token loss across ranks
            ACC-->>CV: aggregated loss, token count
            CV->>VS: return {val_loss, val_num_tokens}
            VS->>LOG: log metrics (only rank 0)
        else skip
            VS-->>TL: continue training
        end
    end
```
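The flow in the diagram can be sketched in plain Python. This is a hypothetical illustration, not the PR's implementation: the `model(batch)` interface, the `batch_num_loss_counted_tokens` key (taken from the review comments below), and the pluggable `all_reduce` stand-in (which replaces a real `torch.distributed` SUM reduction so the sketch stays self-contained) are all assumptions.

```python
# Hypothetical sketch of the validation routine in the sequence diagram.
# "all_reduce" is modeled as a pluggable callable so the sketch runs
# without an initialized torch.distributed process group.

def compute_validation_loss(model, val_batches, all_reduce=lambda x: x):
    """Accumulate summed loss and counted tokens over the validation set,
    aggregate across ranks, and return per-token loss plus token count."""
    total_loss = 0.0
    total_tokens = 0
    for batch in val_batches:
        # model(batch) is assumed to return the loss *summed* over the
        # batch's loss-counted tokens (hypothetical interface).
        total_loss += model(batch)
        total_tokens += batch["batch_num_loss_counted_tokens"]
    # One aggregation at the end (SUM semantics), mirroring the single
    # all_reduce the review comments recommend.
    total_loss = all_reduce(total_loss)
    total_tokens = all_reduce(total_tokens)
    return {
        "val_loss": total_loss / max(total_tokens, 1),
        "val_num_tokens": total_tokens,
    }
```

With a toy "model" that returns 2.0 per counted token, two batches of 10 and 30 tokens yield `val_num_tokens=40` and a per-token `val_loss` of 2.0.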
Estimated Code Review Effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 3
🧹 Nitpick comments (1)
src/instructlab/training/sampler.py (1)
492-503: Consider extracting shared length-materialization logic. The train/test branches duplicate the same `"len"` computation path. A tiny helper would reduce repetition and future drift risk. Also applies to: 511-529.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/instructlab/training/sampler.py` around lines 492 - 503, The duplicated length-materialization logic in the train/test branches should be extracted into a small helper (e.g., a private class/static method named _materialize_lengths or materialize_lengths) that takes the dataset and returns the numpy lengths array; replace the repeated blocks that check "len" in dataset.column_names, call dataset.map(...)[ "len" ] when missing, or use dataset["len"] when present, with a single call to this helper, and assign its return to train_ds.lengths / val_ds.lengths accordingly in the existing code paths (preserve use of num_proc=8 and numpy conversion).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/instructlab/training/config.py`:
- Around line 344-352: Add a model-level validator on the TrainingArgs model to
validate validation_split and validation_frequency: ensure validation_split is
within [0.0, 1.0], and if validation_split > 0.0 then validation_frequency must
be non-None and > 0; implement this as a `@model_validator` (mode="before" or
"after" as appropriate for your pydantic version) inside the TrainingArgs class,
raise ValueError with clear messages referencing validation_split and
validation_frequency when checks fail so programmatic instantiation fails early.
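The checks this comment asks for can be sketched as a standalone function. In the PR this logic would live inside a pydantic `@model_validator` on `TrainingArgs`; the function name here is hypothetical and only illustrates the validation rules themselves.

```python
# Hypothetical helper capturing the suggested checks; in TrainingArgs this
# body would sit inside a pydantic @model_validator so programmatic
# instantiation fails early.

def check_validation_args(validation_split, validation_frequency):
    """Raise ValueError when the validation settings are inconsistent."""
    if not 0.0 <= validation_split <= 1.0:
        raise ValueError(
            f"validation_split must be within [0.0, 1.0], got {validation_split}"
        )
    if validation_split > 0.0:
        # A split without a frequency (or a non-positive one) can never run.
        if validation_frequency is None or validation_frequency <= 0:
            raise ValueError(
                "validation_frequency must be a positive integer when "
                f"validation_split > 0.0, got {validation_frequency}"
            )
```

For example, `check_validation_args(0.1, None)` raises immediately, whereas `check_validation_args(0.0, None)` passes because no validation was requested.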
In `@src/instructlab/training/main_ds.py`:
- Around line 146-156: Inside the validation loop remove per-micro-batch
torch.cuda.empty_cache() and the per-batch dist.barrier(); instead accumulate
batch_loss and total_tokens locally (keep using batch_loss variable and
batch[0]["batch_num_loss_counted_tokens"]) across micro-batches, then after
finishing the loop create tensors (e.g., total_loss_tensor and
total_tokens_tensor or reuse batch_loss_tensor) on device and call
dist.all_reduce(..., op=dist.ReduceOp.SUM) once to reduce the summed values,
finally set total_loss and total_tokens from the reduced tensor values—this
eliminates repeated global syncs while preserving SUM semantics for batch_loss
and token counts.
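The accumulate-locally-then-reduce-once pattern described above can be sketched as follows. The `all_reduce_sum` stand-in replaces a real `dist.all_reduce(tensor, op=dist.ReduceOp.SUM)` so the sketch runs without a process group; the dict keys mirror the names in the comment but the function itself is hypothetical.

```python
# Sketch of the suggested refactor: accumulate per-micro-batch values
# locally, then perform a single SUM reduction after the loop instead of
# calling empty_cache() and barrier() on every micro-batch.

def reduce_validation_totals(micro_batches, all_reduce_sum=lambda x: x):
    """Return (total_loss, total_tokens) after one global SUM reduction."""
    total_loss = 0.0
    total_tokens = 0
    for mb in micro_batches:
        # Local accumulation only: no per-batch global syncs.
        total_loss += mb["batch_loss"]
        total_tokens += mb["batch_num_loss_counted_tokens"]
    # One global sync at the end, preserving SUM semantics.
    return all_reduce_sum(total_loss), all_reduce_sum(total_tokens)
```

Passing `all_reduce_sum=lambda x: 2 * x` simulates a two-rank SUM reduction where both ranks hold identical totals.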
In `@src/instructlab/training/sampler.py`:
- Around line 552-553: The get_data_loader function now always returns a tuple
which breaks callers expecting a single DataLoader; change get_data_loader to
preserve backward compatibility by returning a single DataLoader by default and
only return (train_loader, val_loader) when an explicit flag (e.g.,
return_val_loader: bool = False) is passed; implement this by adding the
return_val_loader parameter to get_data_loader, branching the return path to
return train_loader when return_val_loader is False and (train_loader,
val_loader) when True (use the existing validation_split logic when building
val_loader), and update calling sites to opt in by passing
return_val_loader=True where they need the validation loader.
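The backward-compatible signature suggested above might take roughly this shape. The `_build_train_loader`/`_build_val_loader` stubs are placeholders for the existing construction logic in `sampler.py`; only the branching on `return_val_loader` is the point of the sketch.

```python
# Hypothetical shape of the backward-compatible get_data_loader signature.
# The two _build_* functions stand in for the real loader construction.

def _build_train_loader(**kwargs):
    return "train_loader"  # placeholder for the real DataLoader

def _build_val_loader(**kwargs):
    return "val_loader"  # placeholder for the real DataLoader

def get_data_loader(*, validation_split=0.0, return_val_loader=False, **kwargs):
    train_loader = _build_train_loader(validation_split=validation_split, **kwargs)
    if not return_val_loader:
        # Existing callers keep receiving a single DataLoader.
        return train_loader
    # Opt-in path: also build the validation loader when a split was requested.
    val_loader = (
        _build_val_loader(validation_split=validation_split, **kwargs)
        if validation_split > 0.0
        else None
    )
    return train_loader, val_loader
```

Callers that need validation opt in with `return_val_loader=True`; all other call sites are untouched.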
---
Nitpick comments:
In `@src/instructlab/training/sampler.py`:
- Around line 492-503: The duplicated length-materialization logic in the
train/test branches should be extracted into a small helper (e.g., a private
class/static method named _materialize_lengths or materialize_lengths) that
takes the dataset and returns the numpy lengths array; replace the repeated
blocks that check "len" in dataset.column_names, call dataset.map(...)[ "len" ]
when missing, or use dataset["len"] when present, with a single call to this
helper, and assign its return to train_ds.lengths / val_ds.lengths accordingly
in the existing code paths (preserve use of num_proc=8 and numpy conversion).
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/instructlab/training/config.py
- src/instructlab/training/main_ds.py
- src/instructlab/training/sampler.py
🧹 Nitpick comments (1)
tests/unit/test_pretraining_sampler.py (1)
266-516: Add test coverage for the new `validation_split` code path. All eight updated call sites pass the default `validation_split=0.0`, so the `val_loader` branch in `get_data_loader` (the new `load_and_split` classmethods, `SequentialSampler`-backed loader, and the out-of-range `ValueError`) remains entirely uncovered by the test suite. Consider adding tests for at least:
- `validation_split > 0` → `val_loader` is not `None`, uses `SequentialSampler`, and its dataset length matches the expected fraction.
- `validation_split = 0.0` (explicit) → `val_loader is None`.
- `validation_split < 0` or `validation_split >= 1.0` → `ValueError` is raised.
- Train and validation datasets are disjoint (no sample appears in both).
🧪 Suggested skeleton for the missing tests

```python
@patch("instructlab.training.sampler.load_dataset")
def test_validation_split_creates_val_loader(
    self, mock_load_dataset, temp_pretraining_file
):
    """val_loader is non-None and uses SequentialSampler when validation_split > 0."""
    from torch.utils.data import SequentialSampler

    mock_ds = MagicMock()
    mock_ds.column_names = ["input_ids", "len"]
    mock_ds.__len__ = lambda self: 1
    mock_ds.__iter__ = lambda self: iter(
        [{"input_ids": list(range(100)), "len": 100}]
    )
    mock_load_dataset.return_value = mock_ds
    train_loader, val_loader = get_data_loader(
        data_path=temp_pretraining_file,
        batch_size=2,
        max_tokens_per_gpu=1000,
        seed=42,
        rank=0,
        world_size=1,
        pretraining_config=PretrainingConfig(block_size=10),
        validation_split=0.2,
    )
    assert val_loader is not None
    assert isinstance(val_loader.sampler, SequentialSampler)
    # train + val should cover the full dataset
    assert len(train_loader.dataset) + len(val_loader.dataset) == len(
        train_loader.dataset
    ) + len(val_loader.dataset)  # replace with concrete expected sizes


@patch("instructlab.training.sampler.load_dataset")
def test_validation_split_zero_returns_none_val_loader(
    self, mock_load_dataset, temp_pretraining_file
):
    """val_loader is None when validation_split=0.0."""
    mock_ds = MagicMock()
    mock_ds.column_names = ["input_ids", "len"]
    mock_ds.__len__ = lambda self: 1
    mock_ds.__iter__ = lambda self: iter(
        [{"input_ids": list(range(50)), "len": 50}]
    )
    mock_load_dataset.return_value = mock_ds
    _, val_loader = get_data_loader(
        data_path=temp_pretraining_file,
        batch_size=2,
        max_tokens_per_gpu=100,
        seed=42,
        rank=0,
        world_size=1,
        pretraining_config=PretrainingConfig(block_size=10),
        validation_split=0.0,
    )
    assert val_loader is None


@patch("instructlab.training.sampler.load_dataset")
@pytest.mark.parametrize("invalid_split", [-0.1, 1.0, 1.5])
def test_invalid_validation_split_raises(
    self, mock_load_dataset, temp_pretraining_file, invalid_split
):
    """ValueError raised for out-of-range validation_split."""
    mock_load_dataset.return_value = MagicMock(
        column_names=["input_ids", "len"],
        __len__=lambda self: 1,
        __iter__=lambda self: iter([{"input_ids": [1], "len": 1}]),
    )
    with pytest.raises(ValueError, match="validation_split"):
        get_data_loader(
            data_path=temp_pretraining_file,
            batch_size=2,
            max_tokens_per_gpu=100,
            seed=42,
            rank=0,
            world_size=1,
            pretraining_config=PretrainingConfig(block_size=5),
            validation_split=invalid_split,
        )
```

Would you like me to open a new issue to track adding this validation-split test coverage, or generate the full test implementations?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/test_pretraining_sampler.py` around lines 266 - 516, Add tests covering the new validation_split code path: ensure get_data_loader(validation_split>0) returns a non-None val_loader whose sampler is a SequentialSampler and whose dataset sizes sum to the mocked dataset length (use get_data_loader, validation_split, val_loader, and inspect loader.sampler and loader.dataset); test validation_split==0.0 explicitly returns val_loader is None; parametrize invalid splits (e.g., -0.1, 1.0, 1.5) to assert get_data_loader raises ValueError mentioning validation_split; and add a test that train and val datasets are disjoint (no overlapping samples) by mocking load_dataset to return a deterministic dataset and comparing items from loader.dataset and val_loader.dataset.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/unit/test_pretraining_sampler.py`:
- Around line 266-516: Add tests covering the new validation_split code path:
ensure get_data_loader(validation_split>0) returns a non-None val_loader whose
sampler is a SequentialSampler and whose dataset sizes sum to the mocked dataset
length (use get_data_loader, validation_split, val_loader, and inspect
loader.sampler and loader.dataset); test validation_split==0.0 explicitly
returns val_loader is None; parametrize invalid splits (e.g., -0.1, 1.0, 1.5) to
assert get_data_loader raises ValueError mentioning validation_split; and add a
test that train and val datasets are disjoint (no overlapping samples) by
mocking load_dataset to return a deterministic dataset and comparing items from
loader.dataset and val_loader.dataset.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/instructlab/training/main_ds.py`:
- Around line 116-153: The bug is over-counting validation tokens because the
code adds batch[0]["batch_num_loss_counted_tokens"] for every minibatch
(variable batch comes from the sampler split) which duplicates the same
batch-level token count; instead, when computing total_tokens in main_ds.py use
per-minibatch token counts (e.g., compute valid_tokens.sum().item() from the
token mask for each minibatch inside the for mb in batch loop) or only add
batch[0]["batch_num_loss_counted_tokens"] once per original batch (not per
minibatch). Update the accumulation so total_tokens += per_minibatch_token_count
(derived from valid_tokens) and leave the distributed all_reduce on
tokens_tensor unchanged.
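The per-minibatch accounting fix described above can be illustrated with a small sketch. The `valid_tokens` key here is a hypothetical stand-in for the per-minibatch loss mask tensor; in the real loop the count would come from `valid_tokens.sum().item()`.

```python
# Sketch of the corrected token accounting: derive the count from each
# minibatch's loss mask instead of re-adding the batch-level
# batch_num_loss_counted_tokens once per minibatch.

def count_validation_tokens(batch):
    """batch is a list of minibatches; each carries a 0/1 loss mask
    under the hypothetical "valid_tokens" key."""
    total_tokens = 0
    for mb in batch:
        # Per-minibatch count, equivalent to valid_tokens.sum().item()
        # on the real tensor mask.
        total_tokens += sum(mb["valid_tokens"])
    return total_tokens
```

For a batch split into minibatches with masks `[1, 1, 0]` and `[1, 0]`, this counts 3 tokens once, rather than adding the whole-batch figure for each of the two minibatches.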
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/instructlab/training/config.py
- src/instructlab/training/main_ds.py
- src/instructlab/training/sampler.py
This PR adds the ability to specify a validation split and the frequency at which validation loss is computed.