
Adds Validation Adds validation loss + exposes it in the API#685

Open
RobotSail wants to merge 5 commits into instructlab:main from RobotSail:main

Conversation

@RobotSail (Member) commented Feb 25, 2026

This PR adds the ability to specify a validation loss split and frequency at which validation loss is computed.

Summary by CodeRabbit

  • New Features

    • Configurable validation during training: specify a validation split, evaluation frequency, and CLI options to enable validation.
    • Training will compute and log validation loss at the configured intervals; datasets can be auto-split into training and validation sets.
  • Bug Fixes

    • Validation configuration is now validated at setup time; invalid validation settings produce immediate errors.
  • Tests

    • Unit tests updated to handle the loader returning both train and optional validation data.

@coderabbitai coderabbitai bot commented Feb 25, 2026

📝 Walkthrough

Adds configurable validation: new TrainingArgs fields (validation_split, validation_frequency), dataset splitting and validation DataLoader support, a distributed validation loss routine, periodic validation checks integrated into the training loop, and CLI flags to control validation behavior.

Changes

  • Configuration — src/instructlab/training/config.py
    Added validation_split: float and validation_frequency: Optional[int] to TrainingArgs; added a validate_validation_config post-init validator enforcing 0.0 <= split < 1.0 and requiring a positive frequency when split > 0.
  • Training Workflow — src/instructlab/training/main_ds.py
    Added compute_validation_loss(model, val_data_loader, device); extended the train(...) signature to accept val_data_loader and validation_frequency; integrated periodic validation checks, distributed aggregation of validation loss, and rank-0 logging; propagated the val loader/frequency through run. Added CLI flags --validation_split and --validation_frequency.
  • Data Loading / Samplers — src/instructlab/training/sampler.py
    Added load_and_split() classmethods to PretrainingBlockDataset and TokenDataset for reproducible train/val splits; get_data_loader(...) gains a validation_split param and now returns `(train_loader, val_loader)`, where val_loader is None when no split is requested.
  • Tests — tests/unit/test_pretraining_mode.py, tests/unit/test_pretraining_sampler.py
    Updated tests to unpack the new tuple return from get_data_loader (e.g., loader, _ = get_data_loader(...)) to accommodate the optional validation loader; assertions otherwise unchanged.
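The split/frequency constraint described above can be sketched with a plain dataclass stand-in (illustrative only — the real TrainingArgs is a pydantic model with many more fields; only the two field names mirror the PR):

```python
# Illustrative sketch of the validate_validation_config rule above,
# using a plain dataclass as a stand-in for the pydantic TrainingArgs.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ValidationSettings:
    validation_split: float = 0.0
    validation_frequency: Optional[int] = None

    def __post_init__(self) -> None:
        # enforce 0.0 <= split < 1.0
        if not 0.0 <= self.validation_split < 1.0:
            raise ValueError(
                f"validation_split must be in [0.0, 1.0), got {self.validation_split}"
            )
        # a split without a positive frequency would never run validation
        if self.validation_split > 0.0 and (
            self.validation_frequency is None or self.validation_frequency <= 0
        ):
            raise ValueError(
                "validation_frequency must be a positive integer when "
                "validation_split > 0"
            )
```

Failing at construction time means a bad configuration surfaces before any training work starts, matching the "invalid validation settings produce immediate errors" behavior in the summary.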

Sequence Diagram

sequenceDiagram
    participant TL as Training Loop
    participant VS as Validation Scheduler
    participant CV as compute_validation_loss()
    participant DL as Validation DataLoader
    participant ACC as Accelerator / Distributed
    participant LOG as Logger (rank 0)

    loop Training steps
        TL->>VS: step tick, check validation_frequency
        alt execute validation
            VS->>CV: invoke(model, val_data_loader, device)
            CV->>DL: iterate batches (no-grad)
            DL-->>CV: batches
            CV->>ACC: aggregate per-token loss across ranks
            ACC-->>CV: aggregated loss, token count
            CV->>VS: return {val_loss, val_num_tokens}
            VS->>LOG: log metrics (only rank 0)
        else skip
            VS-->>TL: continue training
        end
    end
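The validation pass in the diagram can be sketched as a no-grad loop with a single distributed reduction at the end (a sketch only: the batch dict keys, the -100 ignore index, and the mean-loss model output are assumptions, not the PR's exact structures):

```python
# Sketch of the compute_validation_loss() flow above: iterate the
# validation loader under no_grad, accumulate loss and token counts
# locally, then perform one all_reduce (skipped when torch.distributed
# is not initialized, e.g. in single-process runs).
import torch
import torch.distributed as dist


def compute_validation_loss(model, val_data_loader, device):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for batch in val_data_loader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            out = model(input_ids=input_ids, labels=labels)
            # count only loss-counted tokens (assumed ignore index -100)
            num_tokens = int((labels != -100).sum().item())
            # assume the model returns a mean loss; rescale to a sum
            total_loss += out.loss.item() * num_tokens
            total_tokens += num_tokens
    # one collective at the end instead of a barrier per batch
    if dist.is_available() and dist.is_initialized():
        t = torch.tensor([total_loss, float(total_tokens)], device=device)
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        total_loss, total_tokens = t[0].item(), int(t[1].item())
    model.train()
    return {
        "val_loss": total_loss / max(total_tokens, 1),
        "val_num_tokens": total_tokens,
    }
```

Summing loss and token counts separately before dividing keeps the per-token loss correct when ranks see different numbers of tokens.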

Estimated Code Review Effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 I nibbled lines and split the hay,
A validation patch to brighten the day,
At steps I pause to count each token,
Across the ranks my loss is spoken,
Hooray — the training hops OK!

🚥 Pre-merge checks | ✅ 1 passed | ❌ 2 failed

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 61.90%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
  • Title Check — ❓ Inconclusive: the title mentions validation loss and API exposure, which are core objectives, but the redundant phrase "Adds Validation Adds" suggests a typo or editing error. Resolution: remove the redundancy, e.g. "Add validation loss computation and expose it in the API".

✅ Passed checks (1 passed)

  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.


@mergify mergify bot added the ci-failure label Feb 25, 2026

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
src/instructlab/training/sampler.py (1)

492-503: Consider extracting shared length-materialization logic.

The train/test branches duplicate the same "len" computation path. A tiny helper would reduce repetition and future drift risk.

Also applies to: 511-529

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/instructlab/training/sampler.py` around lines 492 - 503, The duplicated
length-materialization logic in the train/test branches should be extracted into
a small helper (e.g., a private class/static method named _materialize_lengths
or materialize_lengths) that takes the dataset and returns the numpy lengths
array; replace the repeated blocks that check "len" in dataset.column_names,
call dataset.map(...)[ "len" ] when missing, or use dataset["len"] when present,
with a single call to this helper, and assign its return to train_ds.lengths /
val_ds.lengths accordingly in the existing code paths (preserve use of
num_proc=8 and numpy conversion).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/instructlab/training/config.py`:
- Around line 344-352: Add a model-level validator on the TrainingArgs model to
validate validation_split and validation_frequency: ensure validation_split is
within [0.0, 1.0], and if validation_split > 0.0 then validation_frequency must
be non-None and > 0; implement this as a `@model_validator` (mode="before" or
"after" as appropriate for your pydantic version) inside the TrainingArgs class,
raise ValueError with clear messages referencing validation_split and
validation_frequency when checks fail so programmatic instantiation fails early.

In `@src/instructlab/training/main_ds.py`:
- Around line 146-156: Inside the validation loop remove per-micro-batch
torch.cuda.empty_cache() and the per-batch dist.barrier(); instead accumulate
batch_loss and total_tokens locally (keep using batch_loss variable and
batch[0]["batch_num_loss_counted_tokens"]) across micro-batches, then after
finishing the loop create tensors (e.g., total_loss_tensor and
total_tokens_tensor or reuse batch_loss_tensor) on device and call
dist.all_reduce(..., op=dist.ReduceOp.SUM) once to reduce the summed values,
finally set total_loss and total_tokens from the reduced tensor values—this
eliminates repeated global syncs while preserving SUM semantics for batch_loss
and token counts.

In `@src/instructlab/training/sampler.py`:
- Around line 552-553: The get_data_loader function now always returns a tuple
which breaks callers expecting a single DataLoader; change get_data_loader to
preserve backward compatibility by returning a single DataLoader by default and
only return (train_loader, val_loader) when an explicit flag (e.g.,
return_val_loader: bool = False) is passed; implement this by adding the
return_val_loader parameter to get_data_loader, branching the return path to
return train_loader when return_val_loader is False and (train_loader,
val_loader) when True (use the existing validation_split logic when building
val_loader), and update calling sites to opt in by passing
return_val_loader=True where they need the validation loader.
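The backward-compatible signature suggested above can be sketched as follows (illustrative only; the `_build_loaders` helper is a hypothetical stand-in for the existing construction logic in sampler.py):

```python
# Sketch of the suggested opt-in return: legacy callers get a bare
# loader, and only callers passing return_val_loader=True receive the
# (train_loader, val_loader) tuple. _build_loaders is a hypothetical
# stand-in returning placeholder objects so the control flow runs.
def _build_loaders(validation_split, **kwargs):
    train_loader = ("train_loader", kwargs)
    val_loader = ("val_loader", kwargs) if validation_split > 0.0 else None
    return train_loader, val_loader


def get_data_loader(*, validation_split=0.0, return_val_loader=False, **kwargs):
    train_loader, val_loader = _build_loaders(validation_split, **kwargs)
    if not return_val_loader:
        # legacy call sites keep receiving a single DataLoader
        return train_loader
    return train_loader, val_loader
```

This keeps the new tuple return behind an explicit flag, so existing call sites do not need the `loader, _ = ...` unpacking change.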


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 320b794 and dbde883.

📒 Files selected for processing (3)
  • src/instructlab/training/config.py
  • src/instructlab/training/main_ds.py
  • src/instructlab/training/sampler.py

@mergify mergify bot added the testing Relates to testing label Feb 25, 2026

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/unit/test_pretraining_sampler.py (1)

266-516: Add test coverage for the new validation_split code path.

All eight updated call sites pass the default validation_split=0.0, so the val_loader branch in get_data_loader (the new load_and_split classmethods, SequentialSampler-backed loader, and the out-of-range ValueError) remains entirely uncovered by the test suite. Consider adding tests for at least:

  1. validation_split > 0 → val_loader is not None, uses SequentialSampler, and its dataset length matches the expected fraction.
  2. validation_split = 0.0 (explicit) → val_loader is None.
  3. validation_split < 0 or validation_split >= 1.0 → ValueError is raised.
  4. Train and validation datasets are disjoint (no sample appears in both).
🧪 Suggested skeleton for the missing tests
@patch("instructlab.training.sampler.load_dataset")
def test_validation_split_creates_val_loader(
    self, mock_load_dataset, temp_pretraining_file
):
    """val_loader is non-None and uses SequentialSampler when validation_split > 0."""
    from torch.utils.data import SequentialSampler

    mock_ds = MagicMock()
    mock_ds.column_names = ["input_ids", "len"]
    mock_ds.__len__ = lambda self: 1
    mock_ds.__iter__ = lambda self: iter(
        [{"input_ids": list(range(100)), "len": 100}]
    )
    mock_load_dataset.return_value = mock_ds

    train_loader, val_loader = get_data_loader(
        data_path=temp_pretraining_file,
        batch_size=2,
        max_tokens_per_gpu=1000,
        seed=42,
        rank=0,
        world_size=1,
        pretraining_config=PretrainingConfig(block_size=10),
        validation_split=0.2,
    )

    assert val_loader is not None
    assert isinstance(val_loader.sampler, SequentialSampler)
    # train + val should together cover the full dataset; replace this
    # placeholder with the concrete expected sizes for the mocked data
    assert len(train_loader.dataset) + len(val_loader.dataset) > 0


@patch("instructlab.training.sampler.load_dataset")
def test_validation_split_zero_returns_none_val_loader(
    self, mock_load_dataset, temp_pretraining_file
):
    """val_loader is None when validation_split=0.0."""
    mock_ds = MagicMock()
    mock_ds.column_names = ["input_ids", "len"]
    mock_ds.__len__ = lambda self: 1
    mock_ds.__iter__ = lambda self: iter(
        [{"input_ids": list(range(50)), "len": 50}]
    )
    mock_load_dataset.return_value = mock_ds

    _, val_loader = get_data_loader(
        data_path=temp_pretraining_file,
        batch_size=2,
        max_tokens_per_gpu=100,
        seed=42,
        rank=0,
        world_size=1,
        pretraining_config=PretrainingConfig(block_size=10),
        validation_split=0.0,
    )
    assert val_loader is None


@patch("instructlab.training.sampler.load_dataset")
@pytest.mark.parametrize("invalid_split", [-0.1, 1.0, 1.5])
def test_invalid_validation_split_raises(
    self, mock_load_dataset, temp_pretraining_file, invalid_split
):
    """ValueError raised for out-of-range validation_split."""
    mock_ds = MagicMock()
    mock_ds.column_names = ["input_ids", "len"]
    mock_ds.__len__ = lambda self: 1
    mock_ds.__iter__ = lambda self: iter([{"input_ids": [1], "len": 1}])
    mock_load_dataset.return_value = mock_ds

    with pytest.raises(ValueError, match="validation_split"):
        get_data_loader(
            data_path=temp_pretraining_file,
            batch_size=2,
            max_tokens_per_gpu=100,
            seed=42,
            rank=0,
            world_size=1,
            pretraining_config=PretrainingConfig(block_size=5),
            validation_split=invalid_split,
        )

Would you like me to open a new issue to track adding this validation-split test coverage, or generate the full test implementations?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/test_pretraining_sampler.py` around lines 266 - 516, Add tests
covering the new validation_split code path: ensure
get_data_loader(validation_split>0) returns a non-None val_loader whose sampler
is a SequentialSampler and whose dataset sizes sum to the mocked dataset length
(use get_data_loader, validation_split, val_loader, and inspect loader.sampler
and loader.dataset); test validation_split==0.0 explicitly returns val_loader is
None; parametrize invalid splits (e.g., -0.1, 1.0, 1.5) to assert
get_data_loader raises ValueError mentioning validation_split; and add a test
that train and val datasets are disjoint (no overlapping samples) by mocking
load_dataset to return a deterministic dataset and comparing items from
loader.dataset and val_loader.dataset.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dbde883 and eaf5a88.

📒 Files selected for processing (2)
  • tests/unit/test_pretraining_mode.py
  • tests/unit/test_pretraining_sampler.py


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/instructlab/training/main_ds.py`:
- Around line 116-153: The bug is over-counting validation tokens because the
code adds batch[0]["batch_num_loss_counted_tokens"] for every minibatch
(variable batch comes from the sampler split) which duplicates the same
batch-level token count; instead, when computing total_tokens in main_ds.py use
per-minibatch token counts (e.g., compute valid_tokens.sum().item() from the
token mask for each minibatch inside the for mb in batch loop) or only add
batch[0]["batch_num_loss_counted_tokens"] once per original batch (not per
minibatch). Update the accumulation so total_tokens += per_minibatch_token_count
(derived from valid_tokens) and leave the distributed all_reduce on
tokens_tensor unchanged.
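The over-counting fix described above can be sketched like this (a sketch: the -100 ignore index and the minibatch/label layout are assumptions about the code, not the PR's exact structures):

```python
# Sketch of the token-counting fix above: derive the count from each
# minibatch's own label mask instead of re-adding the batch-level
# batch_num_loss_counted_tokens once per minibatch.
import torch

IGNORE_INDEX = -100  # assumed ignore index for loss-counted tokens


def count_validation_tokens(batch):
    """Sum loss-counted tokens across the minibatches of one batch."""
    total_tokens = 0
    for mb in batch:
        # per-minibatch mask of tokens that actually contribute to loss
        valid_tokens = mb["labels"] != IGNORE_INDEX
        total_tokens += int(valid_tokens.sum().item())
    return total_tokens
```

Counting per minibatch means the total equals the batch-level figure exactly once, rather than once per minibatch in the split.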

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between eaf5a88 and c66c310.

📒 Files selected for processing (3)
  • src/instructlab/training/config.py
  • src/instructlab/training/main_ds.py
  • src/instructlab/training/sampler.py

@RobotSail RobotSail requested a review from Maxusmusti February 26, 2026 04:29

Labels

ci-failure testing Relates to testing

Projects

None yet


1 participant