# The enormous PR to support larger DLMs and a minimal eval harness (#36)

**Draft** · dhruvdcoder wants to merge 4 commits.
## PR: LLM Evaluation Support for xlm-core

**Branch:** `dhruvesh/LLMs`

This PR extends xlm-core to support evaluation of large language models (LLMs), primarily targeting decoder-only models like Dream. The changes make the training harness flexible enough to handle eval-only runs (no loss function, no training metrics) and add concrete support for the MATH-500 benchmark as a first downstream task.
## Summary of Changes (`src/xlm/`)

### 1. Eval-Only Harness Support (`harness.py`)

The `Harness` class previously assumed a loss function, diagnostic metrics, and reported metrics were always present. This holds for training but breaks eval-only workflows, where we only need prediction plus post-hoc scoring.

Changes:

- `loss_function` is now `Optional[LossFunction]`. If `loss` is missing from the config, the harness skips loss instantiation. A guard on `compute_loss` raises a clear error if an LM-loss dataloader is used without a configured loss.
- `diagnostic_metrics` and `reported_metrics` configs now default to `null`. `instantiate_metrics` uses `OmegaConf.select` with a `default=None` fallback so missing keys don't crash Hydra instantiation (see the sketch after this list).
- `torch.compile` setup is skipped when `loss_function is None`, since there is nothing to compile in eval-only mode.
- `compute_post_hoc_metrics` no longer mutates the original predictions JSONL. Instead, results (per-sample scores + aggregated metrics) are written to a separate `results_epoch=…_step=….json` file. The `update_logged_predictions` parameter has been removed.
- Post-hoc evaluators now receive `dataloader_name` as a kwarg, enabling routing in composite evaluators.
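For illustration, a minimal sketch of the optional-metrics lookup, assuming an OmegaConf config in which the metric keys may simply be absent (the `cfg` layout here is hypothetical):

```python
from omegaconf import OmegaConf

# Hypothetical eval-only config: no loss, no metric keys at all.
cfg = OmegaConf.create({"model": {"name": "dream-7b"}})

# OmegaConf.select returns `default` instead of raising when the key is
# missing, so eval-only configs don't need explicit null entries.
diagnostic_cfg = OmegaConf.select(cfg, "diagnostic_metrics", default=None)
reported_cfg = OmegaConf.select(cfg, "reported_metrics", default=None)

if diagnostic_cfg is None:
    print("no diagnostic metrics configured; skipping instantiation")
```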
### 2. `EvalDatasetManager` (`datamodule.py`)

A new lightweight `EvalDatasetManager` class for eval-only dataset splits.

Motivation: The existing `DatasetManager` carries a lot of machinery for training (iterable datasets, manual disk caching, on-the-fly processors, DDP sharding). Eval datasets are small, map-style, and need none of that. The manual cache is particularly problematic for eval because cache paths don't distinguish tokenizer/model, leading to stale-cache bugs when switching models.

Key properties:

- No manual disk cache; relies on the HF `Dataset.map` auto-cache (controlled via `load_from_cache_file`).
- Preprocessing happens in `prepare_data` via `preprocess_function`.
- Same interface as `DatasetManager` (`prepare_data`, `setup`, `get_dataloader`, `set_epoch`), so `TextDataModule` can use it transparently.
- A `stages` parameter controls which Lightning stages this dataset participates in (defaults to `["validate", "test"]`).

Config: A `default_eval.yaml` Hydra config is provided as a base for eval dataset managers.

Type updates in `TextDataModule`: The `dataset_managers` dict now accepts `Union[DatasetManager, EvalDatasetManager, UnconditionalGenerationDatasetManager]`, and all iteration sites (`prepare_data`, `setup`, `set_epoch`) have updated type annotations.
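As a usage sketch only: the exact constructor signature isn't shown in this summary, so the import paths and argument names below are assumptions inferred from the properties listed above.

```python
# Hypothetical usage; import paths and argument names are inferred from
# the description above, not taken from the actual signature.
from xlm.datamodule import EvalDatasetManager
from xlm.tasks.math500 import math500_preprocess_fn

manager = EvalDatasetManager(
    preprocess_function=math500_preprocess_fn,  # runs during prepare_data
    load_from_cache_file=False,   # bypass the HF auto-cache if preprocessing changed
    stages=["validate", "test"],  # Lightning stages this split participates in
)
```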
### 3. MATH-500 Evaluation Task (`tasks/math500.py`)

A new evaluation task module for the MATH-500 benchmark.

`Math500Eval` (post-hoc evaluator):

- Reads `text` (model generation) and `truth`/`answer` (gold answer) from logged predictions.
- Uses `math_verify.parse` to extract structured mathematical expressions and `math_verify.verify` to check equivalence (see the sketch after this list).
- Reports `math_equal_at_1` accuracy, parse failure counts, verify failure counts, and optional timing/step statistics.

`math500_preprocess_fn` (dataset preprocessor):

- Builds the prompt (`Problem: … Answer: …`).
- Tokenizes into `prompt_ids`/`target_ids` for the prediction pipeline.
- Passes the `answer` field through to predictions via `additional_fields_from_batch`.

Configs:

- `datasets/math500.yaml` — `DatasetManager`-based config for MATH-500.
- `datasets/math500_test.yaml` — `EvalDatasetManager`-based config (lighter-weight, recommended).
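The scoring core relies on the `math-verify` library's parse-then-verify flow; a minimal sketch (the example strings are illustrative):

```python
from math_verify import parse, verify

gold = parse("$\\frac{1}{2}$")    # structured expression for the gold answer
pred = parse("Answer: $0.5$")     # extract the expression from a model generation

# verify() checks mathematical equivalence (gold first, prediction second);
# 1/2 and 0.5 are equivalent, so this should print True.
print(verify(gold, pred))
```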
### 4. Composite Post-Hoc Evaluator (`tasks/composite_eval.py`)

Warning: I have not tested this yet.

A new `CompositePostHocEvaluator` that routes `eval()` calls to task-specific evaluators based on dataloader name.

Motivation: When running eval on multiple tasks simultaneously (e.g. MATH-500 + GSM8k), each prediction dataloader needs different evaluation logic. The composite evaluator maps dataloader name patterns to evaluator instances.

Behavior:

- Configured with `{pattern: evaluator}` pairs.
- On `eval()`, matches the `dataloader_name` against patterns (substring match) and delegates to the first matching evaluator (sketched below).
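A minimal sketch of the routing behavior described above; the class and method names come from the PR, but the body is an assumption, not the actual implementation:

```python
from typing import Any, Mapping


class CompositePostHocEvaluator:
    """Routes eval() calls to task-specific evaluators by dataloader name."""

    def __init__(self, evaluators: Mapping[str, Any]) -> None:
        # {pattern: evaluator} pairs; patterns are matched as substrings.
        self.evaluators = dict(evaluators)

    def eval(self, predictions: Any, dataloader_name: str, **kwargs: Any) -> Any:
        # Delegate to the first evaluator whose pattern appears in the name.
        for pattern, evaluator in self.evaluators.items():
            if pattern in dataloader_name:
                return evaluator.eval(
                    predictions, dataloader_name=dataloader_name, **kwargs
                )
        raise KeyError(f"No evaluator matches dataloader {dataloader_name!r}")
```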
### 5. Sharded Safetensors Support (`utils/hf_hub.py`)

Extended the HF Hub weight-loading pipeline to handle sharded safetensors checkpoints (common for large LLMs like Qwen-7B, LLaMA, etc.).

- New download path: `download_model_weights` now tries (1) a single `model.safetensors`, (2) sharded safetensors (`model.safetensors.index.json` + shard files), (3) `pytorch_model.bin`.
- Shard-by-shard loading: `_load_sharded_safetensors_into_model` loads one shard at a time into the model, keeping peak RAM at approximately model size + one shard (instead of model size + full checkpoint). A strict-mode check is performed at the end by comparing checkpoint vs model keys. See the sketch after this list.
- Full state-dict fallback: `_state_dict_from_safetensors_index` merges all shards into a single state dict for the legacy `load_model_state_dict_from_file` path (higher peak RAM, but needed for existing callers).
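A sketch of the shard-by-shard strategy, assuming the standard HF index layout (`model.safetensors.index.json` maps tensor names to shard files); this is not the PR's exact code:

```python
import json
from pathlib import Path

import torch
from safetensors.torch import load_file


def load_sharded_safetensors(model: torch.nn.Module, checkpoint_dir: str) -> None:
    """Load a sharded safetensors checkpoint one shard at a time.

    Peak RAM stays near model size + one shard, rather than
    model size + the full checkpoint.
    """
    index_path = Path(checkpoint_dir) / "model.safetensors.index.json"
    index = json.loads(index_path.read_text())
    shard_files = sorted(set(index["weight_map"].values()))

    loaded_keys: set[str] = set()
    for shard in shard_files:
        state = load_file(str(Path(checkpoint_dir) / shard))
        # strict=False because each shard covers only part of the model.
        model.load_state_dict(state, strict=False)
        loaded_keys.update(state.keys())
        del state  # free the shard before loading the next one

    # Strict-mode check at the end: checkpoint keys vs model keys.
    missing = set(model.state_dict().keys()) - loaded_keys
    if missing:
        raise RuntimeError(f"Missing keys in checkpoint: {sorted(missing)[:5]} …")
```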
### 6. `skip_init_weights` & `init_dtype` for Large Model Loading (`commands/lightning_train.py`, `utils/model_loading.py`)

- `skip_init_weights`: When loading pretrained weights, random initialization of the model is wasteful (and slow for large models). Setting `skip_init_weights: true` in the config wraps model instantiation in `transformers.modeling_utils.no_init_weights()`, skipping random init entirely. Only activates when a checkpoint path is provided.
- `init_dtype`: A new config key that sets `torch.set_default_dtype` during model construction, allowing models to be instantiated directly in `float16` or `bfloat16` instead of `float32` (important for fitting large models in GPU memory).

Both are supported in the training entrypoint (`lightning_train.py`) and in the inference loader (`model_loading.py`). A combined sketch follows below.
### 7. Sampling Temperature Support (`utils/nn.py`)

`sample_from_top_k` and `sample_from_top_p` now accept an optional `temperature` parameter (default 1.0). Temperature scaling is applied before top-k/top-p filtering, following standard practice.
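For reference, a minimal sketch of the temperature-then-top-k ordering; the function name and exact filtering style are illustrative, not the PR's code:

```python
import torch


def sample_top_k(logits: torch.Tensor, k: int, temperature: float = 1.0) -> torch.Tensor:
    # Scale logits by temperature first, then filter: a low temperature
    # sharpens the distribution before top-k keeps the k best tokens.
    if temperature != 1.0:
        logits = logits / temperature
    topk_vals, topk_idx = logits.topk(k, dim=-1)
    probs = torch.softmax(topk_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return topk_idx.gather(-1, choice).squeeze(-1)


tokens = sample_top_k(torch.randn(2, 32000), k=50, temperature=0.7)  # shape (2,)
```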
### 8. Generalized Trailing Token Removal (`utils/text.py`)

`remove_trailing_pads` is generalized to strip arbitrary trailing special tokens, not just `pad_token`. It now accepts an optional `tokens_to_remove` list and matches longest-suffix-first to handle cases like `" </s>"` vs `"</s>"`.
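A sketch of the longest-suffix-first idea (the names are illustrative): ordering candidates by length ensures `" </s>"` wins over the bare `"</s>"`.

```python
def strip_trailing_tokens(text: str, tokens_to_remove: list[str]) -> str:
    # Sort candidates longest-first so " </s>" is stripped as a unit
    # rather than leaving a dangling space after removing "</s>".
    candidates = sorted(tokens_to_remove, key=len, reverse=True)
    stripped = True
    while stripped:
        stripped = False
        for tok in candidates:
            if text.endswith(tok):
                text = text[: -len(tok)]
                stripped = True
                break  # restart from the longest candidate
    return text


print(strip_trailing_tokens("answer </s></s>", ["</s>", " </s>"]))  # "answer"
```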
### 9. `Seq2SeqCollatorInput` Update & `additional_fields_from_batch` Fix (`datamodule.py`, `log_predictions.py`)

- `Seq2SeqCollatorInput`: Refactored to use `NotRequired` (PEP 655) for `input_ids` and `target_ids`, supporting prompt-only batches (where the model generates the full response) alongside traditional prompt+target batches (see the sketch below). Updated docstring clarifies the role of each field.
- `additional_fields_from_batch`: When `LogPredictions` is configured with `additional_fields_from_batch` (e.g. to carry `answer` through for post-hoc eval), those fields are now automatically appended to each writer's `fields_to_keep_in_output` list so they aren't dropped during output serialization.
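A sketch of the PEP 655 pattern, with the `typing_extensions` fallback for Python < 3.11 that is mentioned later in this PR; the field set and value types shown are abridged assumptions:

```python
try:  # NotRequired landed in typing in Python 3.11
    from typing import NotRequired, TypedDict
except ImportError:
    from typing_extensions import NotRequired, TypedDict


class Seq2SeqCollatorInput(TypedDict):
    prompt_ids: list[int]
    # Optional under PEP 655: prompt-only batches omit these keys
    # and the model generates the full response.
    input_ids: NotRequired[list[int]]
    target_ids: NotRequired[list[int]]


prompt_only: Seq2SeqCollatorInput = {"prompt_ids": [1, 2, 3]}  # valid
```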
### 10. DatasetManager Cache Control (`datamodule.py`)

Added `preprocess_load_from_cache_file` and `on_the_fly_load_from_cache_file` parameters to `DatasetManager`, passed through to `Dataset.map(load_from_cache_file=…)`. This allows disabling the HF auto-cache when the preprocessing function has changed but its hash hasn't (common with tokenizer-dependent transforms).
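For context, the underlying HF `datasets` behavior: `map` caches by a fingerprint of the function and its arguments, so `load_from_cache_file=False` forces recomputation. A small self-contained example:

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["1+1", "2+2"]})


def preprocess(example):
    # Imagine this depends on a tokenizer that was swapped out: the
    # fingerprint may not change, so a stale cache would be reused.
    return {"length": len(example["text"])}


# Force recomputation instead of trusting the fingerprint-based cache.
fresh = ds.map(preprocess, load_from_cache_file=False)
```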
### 11. Miscellaneous

- `DeNovoEval`/`FragmentEval`: Added `**kwargs` to `eval()` signatures for forward-compatibility with `CompositePostHocEvaluator` (which passes `dataloader_name`).
- `diagnostic_metrics: null` and `reported_metrics: null` added to `config.yaml` so they don't need to be explicitly set in eval-only experiments.
- `default_hf.yaml` callback config: New callback defaults preset (excludes HF-specific callbacks).
- `requirements.txt`: Added the `math-verify` dependency.
- `requirements/llm_eval.txt`: New requirements file for LLM evaluation with the ANTLR4 variant of math-verify.

## Summary of Changes (`xlm-models/`)

### 12. Dream Model Architecture (`dream/`)

The Dream model type consists of two layers on top of the upstream `DreamModelCore` backbone (from `xlm.backbones.dream`):

- `DreamModel`: Inherits from `DreamModelCore`, sets `config_class = DreamConfig`, and serves as the pure forward-pass model (`input_ids -> MaskedLMOutput`). No generation logic lives in the model itself — all generation is handled by the predictor.
- `DreamBackbone`: Subclass of `DreamModel` that adapts the forward signature to the xlm-core MLM predictor protocol (`x_t, attention_mask, positions -> logits`). It converts a 2D `attention_mask` to the 4D format Dream's attention layers expect (see the sketch below) and returns raw logits (not `MaskedLMOutput`). HF checkpoints (e.g. Dream-v0-Instruct-7B) load directly without key-prefix issues, since the weight names are identical to `DreamModelCore`.
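For intuition, a sketch of a common 2D-to-4D attention-mask expansion; this is the standard additive-mask recipe, assumed here, not the PR's exact conversion:

```python
import torch


def expand_attention_mask(mask_2d: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # (batch, seq) 0/1 mask -> (batch, 1, 1, seq) additive float mask:
    # 0.0 where attention is allowed, a large negative value where masked.
    mask_4d = mask_2d[:, None, None, :].to(dtype)
    return (1.0 - mask_4d) * torch.finfo(dtype).min


mask = torch.tensor([[1, 1, 1, 0]])
print(expand_attention_mask(mask, torch.float32).shape)  # torch.Size([1, 1, 1, 4])
```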
### 13. Dream Predictor (`dream/predictor_dream.py`)

`DreamPredictor` is a fully-featured diffusion predictor (~600 lines) built on the same architecture as `MLMPredictor`. It implements the xlm-core `Predictor[MLMBatch, MLMPredictionDict]` protocol so that all harness features (logging, post-hoc evaluation, callbacks) work out of the box.

Key features:

- Sampling: `temperature`, `top_k`, `top_p`, and greedy (`temperature=0`) sampling, using the same `sample_from_top_k`/`sample_from_top_p`/`sample_from_logits` utilities as MLM.
- Confidence-based position selection: `top_prob`, `prob_diff`, and `entropy` confidence metrics, with a configurable `confidence_temperature` for stochastic vs greedy position ordering.
- `LogitsShiftBy1`: A callable logits hook that shifts next-token logits to align with per-position predictions (Dream-style `logits[i] = backbone_output[i+1]`). Defined in `mlm.predictor_mlm` and re-exported from `dream/__init__.py`.
- `BlockAllowSchedule`: Unmasking schedule that reveals one block of positions at a time (e.g. 16 tokens per block), preventing the model from attending to not-yet-ready future blocks. Requires `max_new_tokens` to be a multiple of `block_size` and `block_size` to divide evenly into diffusion steps.
- `generate_until` / early stopping: Accepts a list of stop strings (e.g. `<|endoftext|>`, `"Problem:"`). At each step, tokenized patterns are checked against the current output using an efficient batched `unfold`-based match (sketched after this list). Generation stops when any pattern appears in the output with no mask tokens before it. A `force_fill_max` counter ensures termination even if masks remain (defaults to 3 hits).
- `max_new_tokens`: Appends mask tokens to the prompt, enabling open-ended generation where the suffix length is not predetermined.
- Debug output: when `flags.DEBUG_PRINT_PREDS` is set, prints intermediate and final decoded outputs per step.
- Timing: records `time_taken` and `steps_taken` per sample in the prediction dict.
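A sketch of the batched `unfold`-based stop-pattern match; the function name is illustrative, and the real implementation additionally checks that no mask tokens precede the match:

```python
import torch


def find_stop_pattern(tokens: torch.Tensor, pattern: torch.Tensor) -> torch.Tensor:
    """Batched substring search: does `pattern` occur in each row of `tokens`?

    tokens: (batch, seq) token ids; pattern: (plen,) token ids.
    Returns a (batch,) bool tensor.
    """
    plen = pattern.numel()
    if tokens.size(1) < plen:
        return torch.zeros(tokens.size(0), dtype=torch.bool)
    # unfold produces all length-plen windows: (batch, seq - plen + 1, plen)
    windows = tokens.unfold(dimension=1, size=plen, step=1)
    return (windows == pattern).all(dim=-1).any(dim=-1)


out = torch.tensor([[5, 7, 9, 2], [1, 2, 3, 4]])
print(find_stop_pattern(out, torch.tensor([9, 2])))  # tensor([ True, False])
```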
### 14. MLM Predictor Enhancements (`mlm/predictor_mlm.py`)

Changes to the MLM predictor that mirror the Dream predictor's new features. These are backward-compatible — all new parameters have defaults that preserve existing behavior.

New parameters (all optional with backward-compatible defaults):

- `temperature` (default `1.0`): Token sampling temperature. `temperature=0` enables greedy decoding via `top_k=1`.
- `confidence_temperature` (default `1.0`): Controls stochastic vs greedy position selection when using confidence-based unmasking. Previously only threshold-based selection was supported.
- `logits_hook` (default `None`): Arbitrary transform on model logits before sampling. Enables Dream-style shifted logits on MLM models without code changes.

Refactored internals:

- `_require_tokenizer()`: Lazy tokenizer validation (replaces the eager check in `__init__`), allowing `tokenizer=None` at construction time (the Harness sets it later).
- `_compute_confidence()`: Factored-out confidence score computation (shared with the Dream predictor); see the sketch after this list.
- Position selection now has two paths: (1) legacy threshold-based (`confidence` + `threshold`) for backward compatibility, and (2) new score-based (`confidence` without `threshold`) using `select_random_indices` with confidence scores and temperature. The `threshold is None` case was previously an error; it now routes to the new path.
- `attention_mask` handling: Uses `.get("attention_mask")` with a fallback to all-ones, instead of assuming the key exists. Casts to `bool` dtype explicitly.
- Same `flags.DEBUG_PRINT_PREDS` support as the Dream predictor.
- `LogitsShiftBy1`: New class (also in `mlm/predictor_mlm.py`) — see the description in the Dream section above. Exported from `mlm/__init__.py`.
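A sketch of the three confidence metrics named above, computed from per-position logits; the helper name and exact formulas are assumptions based on the metric names:

```python
import torch


def compute_confidence(logits: torch.Tensor, metric: str) -> torch.Tensor:
    """Per-position confidence scores from (batch, seq, vocab) logits."""
    probs = torch.softmax(logits, dim=-1)
    if metric == "top_prob":
        # Probability of the most likely token.
        return probs.max(dim=-1).values
    if metric == "prob_diff":
        # Gap between the top-1 and top-2 probabilities.
        top2 = probs.topk(2, dim=-1).values
        return top2[..., 0] - top2[..., 1]
    if metric == "entropy":
        # Negative entropy: lower entropy means higher confidence.
        return (probs * torch.log(probs.clamp_min(1e-12))).sum(dim=-1)
    raise ValueError(f"unknown confidence metric: {metric}")


scores = compute_confidence(torch.randn(2, 8, 100), "top_prob")  # shape (2, 8)
```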
### 15. Seq2Seq Collator Flexibility (`mlm/datamodule_mlm.py`)

The seq2seq collators (`MLMSeq2SeqTrainCollator`, `MLMSeq2SeqCollator`, `MLMSeq2SeqPredCollator`, `_MLMSeq2SeqPredCollator`) previously hard-coded the field names `prompt_ids` and `input_ids`. They now support configurable field names and pass-through fields.

New collator parameters (all with backward-compatible defaults):

- `prompt_field` (default `"prompt_ids"`): Key for the prompt token ids in each example.
- `target_field` (default `"target_ids"`): Key for the suffix/target token ids. Falls back to `input_ids` when `target_field="target_ids"` and the row has no `target_ids` key — this preserves backward compatibility with existing datasets that use `input_ids` as the suffix field.
- `pass_through_fields` (default `["answer", "target"]`): List of keys to copy from the raw examples into the batch dict unmodified. This is how gold answers (e.g. MATH-500 `answer`) flow from the dataset through collation into the batch so `LogPredictions.additional_fields_from_batch` can pick them up.
- `seq2seq_suffix_ids()` helper: Centralized suffix extraction with the `target_ids` -> `input_ids` fallback logic.

Other changes:

- Prompt-only batches in `MLMSeq2SeqPredCollator`: When every example in the batch has an empty suffix (e.g. open-ended generation with no target), the collator now returns a batch without `target_ids` instead of crashing. Mixed batches (some empty, some non-empty suffixes) raise `ValueError`.
- `print_batch_mlm`: Now guards `target_ids` access with `if "target_ids" in batch`, avoiding a `KeyError` on prompt-only batches.
- Batch construction: Collators now build batch dicts via plain `dict` literals + `_merge_pass_through()` instead of calling `MLMBatch(...)` directly. This is needed because pass-through fields (like `answer`) are not part of the `MLMBatch` TypedDict. See the sketch after this list.
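A sketch of the suffix fallback and pass-through merge described above; `seq2seq_suffix_ids` and `_merge_pass_through` are named in the PR, but these bodies are assumptions:

```python
from typing import Any


def seq2seq_suffix_ids(row: dict[str, Any], target_field: str = "target_ids") -> list[int]:
    # Backward-compatible fallback: old datasets store the suffix
    # under "input_ids" instead of "target_ids".
    if target_field == "target_ids" and "target_ids" not in row:
        return row["input_ids"]
    return row[target_field]


def _merge_pass_through(
    batch: dict[str, Any],
    rows: list[dict[str, Any]],
    pass_through_fields: list[str],
) -> dict[str, Any]:
    # Copy listed keys (e.g. the gold "answer") from raw rows into the
    # batch dict unmodified, one list entry per example.
    for field in pass_through_fields:
        if all(field in row for row in rows):
            batch[field] = [row[field] for row in rows]
    return batch
```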
### 16. `MLMBatch`/`MLMSeq2SeqPredictionBatch` Type Updates (`mlm/types_mlm.py`)

- `MLMSeq2SeqPredictionBatch.target_ids` is now `NotRequired` (PEP 655), reflecting that prediction batches may be prompt-only (no suffix to predict against).
- `NotRequired` is imported with a `typing_extensions` fallback for Python < 3.11.
### 17. Dream Model Registration & Configs

- `xlm_models.json`: Added a `"dream": "dream"` entry so the Dream model type is discoverable by the xlm-core plugin system.

New Hydra configs:

- `model_type/dream_base.yaml` — Shared base: sets `Harness` as the lightning module and `DreamPredictor` as the predictor with the `LogitsShiftBy1` logits hook.
- `model_type/dream.yaml` — Training variant: composes `dream_base` + `accumulated_loss` reported metrics for train/val/test. Existing config, now simplified to a 6-line file composing `dream_base`.
- `model_type/dream_eval.yaml` — Eval-only variant: composes `dream_base` with no additional metrics.
- `model/dream_7b.yaml` — Model config for `DreamBackbone` matching the Dream-v0-Instruct-7B architecture (Qwen2-based, 28 layers, GQA, RoPE).
- `collator/math500_pred_dream.yaml` — `MLMSeq2SeqPredCollator` configured for MATH-500 with `pass_through_fields: [answer, target]`.
- `datamodule/math500_dream.yaml` — Datamodule composing the MATH-500 eval dataset with the Dream collator for val and test splits.
- `experiment/math500_dream_eval.yaml` — Full experiment config for MATH-500 eval on Dream-7B: loads the Dream tokenizer, sets `init_dtype: bfloat16`, `skip_init_weights: true`, 512 diffusion steps, block-based unmasking (`block_size=16`), `top_prob` confidence, `generate_until` stop tokens, and the `Math500Eval` post-hoc evaluator (an abridged sketch follows after this list).
- `debug/math500_debug.yaml` — Debug variant with reduced parameters for local testing.
- `dream/__init__.py`: Exports `DreamConfig`, `DreamModel`, `DreamBackbone`, `DreamPredictor`, `LogitsShiftBy1`, and `print_batch_dream`.
- `dream/datamodule_dream.py`: New file with `print_batch_dream` — a debug batch printer that logs the first example's decoded prompt for Dream/MLM prediction batches.
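A hypothetical, abridged sketch of the experiment config's shape: only values explicitly named in this summary appear, and the key nesting is an assumption, not the actual `math500_dream_eval.yaml` layout.

```python
from omegaconf import OmegaConf

# Hypothetical structure; key names under `predictor` are assumed.
experiment = OmegaConf.create(
    """
    init_dtype: bfloat16
    skip_init_weights: true
    predictor:
      num_steps: 512            # diffusion steps
      block_size: 16            # block-based unmasking
      confidence: top_prob
      generate_until: ["<|endoftext|>", "Problem:"]
    post_hoc_evaluator: Math500Eval
    """
)
print(OmegaConf.to_yaml(experiment))
```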
### 18. Backward Compatibility Summary (MLM)

All MLM changes are designed to be backward-compatible. The table below summarizes the risks:

| Change | Why existing behavior is preserved |
| --- | --- |
| `confidence` without `threshold` no longer errors | `confidence` + `threshold` paths are unchanged |
| `MLMSeq2SeqPredCollator` may omit `target_ids` | Happens only for prompt-only batches (every suffix empty) |
| Batches built via `dict()`, not `MLMBatch()` | Core batch keys are unchanged; pass-through fields are additive |
| `tokenizer=None` allowed at init | `_require_tokenizer()` raises if used before set |
| `target_field` defaults to `"target_ids"` with `input_ids` fallback | Existing datasets use `input_ids` and the fallback handles it; `target_field="input_ids"` can be set if needed |

## Running debug Dream eval
@Durga-Prasad1

Install math-verify with `pip install -r requirements/llm_eval.txt`, and then you can run: