Implement RLHF DPO (Direct Preference Optimization) training #1403
BitcrushedHeart wants to merge 11 commits into Nerogar:master from
Conversation
Any example images?
Images now scale up to fill the display box (not just down), and prompts are shown in a collapsible expander instead of truncated.
- Fix failing export test to match 7-value return signature
- Add ValueError for unsupported DPO reference modes
- Remove dead hasattr check in reference_model
- Close PIL file handles promptly in DPO curation window
- Read only first 256KB for metadata extraction with fallback
- Add config migration 14 for transfer_* fields
- Add DPO loss math integration tests including beta=0 sanity check
SwarmUI embeds metadata in WebP EXIF UserComment as UTF-16LE (UNICODE character code), causing prompt extraction to fail. Fall back to UTF-16LE/BE decoding when UTF-8 finds no metadata.
Output folder is now selected at start. Each pick exports the chosen/rejected pair to disk instantly and updates a manifest.json, so closing mid-session loses no work. On restart, completed groups are auto-skipped via manifest lookup. Groups are shuffled for variety. Finalization (train/val split + concepts.json) is a separate optional step at the end.
Metadata scan and dhash dedup now run in a background thread, feeding a Queue(maxsize=10) of ready groups. The UI shows a live scan counter and transitions to picking as soon as the first group is ready. If the user outruns the dedup, a brief "Preparing..." spinner appears. Groups already completed in the manifest are skipped before deduping, saving the most expensive work.
After accepting a chosen/rejected pair, if there are still 2+ remaining images in the group, a Yes/No/Cancel dialog lets the user continue scoring the same prompt or move on. Works in both ELO and Selection modes. Cancel acts as an undo for mis-clicks.
Very interesting PR! I wanna test it, but right now ComfyUI metadata does not seem to work. If I select an input folder with ComfyUI-generated images (from a very basic SDXL workflow), I get "no images with extractable prompt metadata found".
If you can send me an image on Discord from ComfyUI (any image, it doesn't matter), I can add this for you now and push it. If you have any from Forge as well I'll do the same - I don't have either of these so I don't know what they expect.
Add save-best option that snapshots model weights when validation accuracy improves and restores them at end of training. Simplify patience to track accuracy only (remove DPOPatienceMode). Add review mode to the DPO pair tool with orphan detection and pair removal. Default execution mode is now Full Concurrent. Polish curation UI with card layout and progress bar.
Parse plain-text A1111/Forge parameters (positive prompt before "Negative prompt:" marker) and ComfyUI workflow JSON (trace KSampler positive conditioning to CLIPTextEncode node, with fallback to longest text encoder output). Handles SDXL/SD3/Flux text encoder variants.
Forge & Comfy now supported in latest commit.
I did some testing with the PR: (2) Test training with 15 pairs (NoobAI vpred) worked without error. I see minimal improvement of my concept in sample images, though the dpo/val_accuracy graph does not really reflect it. But I guess I can't expect that with literally only 1 val pair.
Patience now uses DPOPatienceMode (EITHER/BOTH) and tracks val_loss alongside accuracy with rounding to 5 decimal places. Add visual pair review accessible from RLHF tab that merges train/val splits and supports orphan detection. Fix grab_set crash on Linux by adding wait_visibility before grab.
Fixed in 26e3b7a. Validation accuracy is binary - with 1 validation pair it is literally a coin flip. Val loss would be a better metric for you in this case.
DPO in OneTrainer
What This Is
OneTrainer's DPO implementation lets you show the model two images for the same prompt - one you prefer, one you don't - and trains an adapter to produce more of the former. The reference model is kept adapter-sized (either the raw base or a frozen snapshot of your existing LoRA), so the whole thing runs on a single consumer GPU without doubling your VRAM budget.
It's model-agnostic. The loss function doesn't know or care whether you're training on Flux, SDXL, SD 1.5, Z-Image, or anything else OneTrainer supports. It hooks into the existing `predict()` pipeline and works with whatever prediction targets the model already uses.

How DPO Works Here
Standard DPO compares how much more the policy model prefers the chosen sample over the rejected one, relative to how much the reference model does. The loss is:

    L = -log σ( β · [ (s_policy(chosen) − s_policy(rejected)) − (s_ref(chosen) − s_ref(rejected)) ] )

where s(·) is a model's score for a sample.
In language models, those scores are token log-probabilities. We don't have those. What we do have is the prediction error from the denoising objective - the same MSE that standard training minimises. A lower MSE means the model "agrees" more with that image at that timestep, so we use negative MSE as a score proxy:

    s(x) = −MSE(predicted, target)
This is reduced across all non-batch dimensions, giving one scalar per sample. The key insight is that this works regardless of prediction type - epsilon, v-prediction, flow velocity - because `predict()` already returns the right `(predicted, target)` pair for each model family. The DPO loss never needs to know which one it's looking at.

The default beta is 5000, which looks absurdly high compared to the 0.1-0.5 typical in LLM DPO. The reason is scale: MSE values are tiny compared to token log-probs, so you need a correspondingly larger beta to keep the sigmoid in a useful gradient range rather than having it saturate immediately.
Two optional extras sit on top of the base loss. Label smoothing interpolates between the DPO loss and its complement, which helps when your preference labels are noisy (i.e. you weren't fully sure which image was better). Supervised mix adds a standard training loss from the chosen image to prevent the adapter drifting too far from what it already knows. The supervised term reuses the chosen policy forward pass, so it doesn't cost an extra forward pass.
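The pieces above fit together in a few lines. This is a minimal scalar sketch, not the tensor implementation in `BaseModelSetup.py` - the function name, signature, and the exact placement of label smoothing and the supervised term are assumptions for illustration:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def dpo_loss(mse_policy_chosen: float, mse_policy_rejected: float,
             mse_ref_chosen: float, mse_ref_rejected: float,
             beta: float = 5000.0, label_smoothing: float = 0.0,
             supervised_mix: float = 0.0) -> float:
    """Hypothetical scalar sketch of the DPO loss described above."""
    # score = -MSE, so each model's chosen-vs-rejected margin is:
    policy_margin = -mse_policy_chosen - (-mse_policy_rejected)
    ref_margin = -mse_ref_chosen - (-mse_ref_rejected)
    logits = beta * (policy_margin - ref_margin)

    # -log sigmoid(logits), interpolated with its complement when the
    # preference labels are noisy (label smoothing)
    loss = (-math.log(sigmoid(logits)) * (1.0 - label_smoothing)
            - math.log(sigmoid(-logits)) * label_smoothing)

    # supervised mix: add plain denoising loss on the chosen sample
    return loss + supervised_mix * mse_policy_chosen

# beta=0 sanity check from the tests: loss collapses to -log(0.5) = log(2)
assert abs(dpo_loss(0.01, 0.02, 0.015, 0.015, beta=0.0) - math.log(2)) < 1e-9
```

With MSE-scale margins around 0.01, a beta in the thousands is what moves the logit into a range where the sigmoid still has gradient, which is the scale argument made above.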
Training Types
There are two modes, and you don't pick them directly - the code infers which one you're using from whether you've loaded a base adapter.
New Adapter means no base adapter is loaded. The reference model is the raw base with no adapter applied. The policy is a fresh adapter training from scratch. Use this for general quality preferences - "prefer sharp images over blurry ones" - where you don't have a prior fine-tune to build on.
Existing Adapter means you've loaded a supervised LoRA/OFT checkpoint. The reference model is a frozen snapshot of those adapter weights taken at the start of training. The policy starts from the same weights and diverges during DPO. Use this when you've already got a character or style LoRA and want to refine its outputs - "my character LoRA is good but sometimes fumbles hands, so here's examples of good hands vs bad hands."
The output is always an adapter file in both cases.
The Reference Model
This is the part that makes DPO practical on hobbyist hardware.
For New Adapter mode, the reference pass just temporarily unhooks the adapter from the model, runs a forward pass through the raw base weights, and hooks the adapter back in. Simple.
For Existing Adapter mode, the implementation clones every adapter parameter tensor once at the start of training - this is the frozen reference snapshot. During each reference forward pass, it swaps `param.data` pointers from the live training weights to the frozen snapshot, runs the forward pass, and swaps back in a `finally` block. This is O(1) pointer assignment per parameter, not a copy. The base model weights are shared between policy and reference the whole time. The only duplicated data is the adapter tensors themselves - typically 50-100MB for a LoRA.

No second model is loaded. No base weights are duplicated. The "reference model" is a list of frozen tensors and a pointer swap.
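The swap-and-restore pattern can be sketched as a context manager. This is an illustrative stand-in, not the PR's code: `Param` stands in for a torch parameter (only `.data` matters), and the helper name `as_reference` is hypothetical:

```python
from contextlib import contextmanager

class Param:
    """Stand-in for a torch parameter; only the .data attribute matters."""
    def __init__(self, data):
        self.data = data

@contextmanager
def as_reference(adapter_params, frozen_snapshot):
    """Point each adapter param at its frozen snapshot tensor for the
    duration of the reference forward pass, then restore the live weights."""
    live = [p.data for p in adapter_params]          # remember live pointers
    try:
        for p, frozen in zip(adapter_params, frozen_snapshot):
            p.data = frozen                          # O(1) pointer swap, no copy
        yield
    finally:
        for p, saved in zip(adapter_params, live):
            p.data = saved                           # restore even on error

# one-time snapshot at training start (the real code clones tensors)
params = [Param([1.0, 2.0])]
snapshot = [list(p.data) for p in params]
params[0].data[0] = 9.0                              # training updates live weights
with as_reference(params, snapshot):
    assert params[0].data == [1.0, 2.0]              # reference sees frozen weights
assert params[0].data == [9.0, 2.0]                  # live weights restored
```

The `finally` block is what makes the swap safe: an exception inside the reference pass still restores the live pointers.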
Forward Pass Scheduling
DPO needs four forward passes per training step: reference-chosen, reference-rejected, policy-chosen, policy-rejected. How you schedule those is a VRAM/speed tradeoff exposed through the execution mode setting.
Sequential (default): Run all four one at a time, deleting activations between each. Both reference passes are `no_grad` so they don't retain activations anyway. The policy chosen pass runs, its output is deleted, then the policy rejected pass runs. Peak VRAM is roughly the same as standard LoRA training. Slowest, but fits on 24GB.

Policy Concurrent: Reference passes still run sequentially and get cleaned up. Both policy passes keep their activations alive simultaneously - the chosen output isn't deleted before the rejected pass starts. This saves recomputation during backward at the cost of holding two gradient-tracked passes in memory. Uses more VRAM than Sequential.
Full Concurrent: All four outputs stay alive until the scores are computed. Fastest scheduling, highest memory. Uses considerably more VRAM than either of the other modes.
The execution mode doesn't change the maths. Same loss, same gradients, same result - just different memory/speed profiles.
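The three modes differ only in which of the four outputs coexist in memory. A toy sketch of that grouping (illustrative only - the real scheduling lives in the training loop, and these group lists are an assumption for explanation, not an actual data structure in the PR):

```python
def schedule_passes(mode: str):
    """Which forward-pass outputs are held alive together per execution
    mode. Each inner list is a group whose outputs coexist in memory."""
    if mode == "SEQUENTIAL":
        # one pass at a time; each output is freed before the next runs
        return [["ref_chosen"], ["ref_rejected"],
                ["policy_chosen"], ["policy_rejected"]]
    if mode == "POLICY_CONCURRENT":
        # reference passes stay sequential; both policy outputs coexist
        return [["ref_chosen"], ["ref_rejected"],
                ["policy_chosen", "policy_rejected"]]
    if mode == "FULL_CONCURRENT":
        # all four outputs live until the scores are computed
        return [["ref_chosen", "ref_rejected",
                 "policy_chosen", "policy_rejected"]]
    raise ValueError(f"unknown execution mode: {mode}")

# peak number of coexisting outputs grows with the mode
assert max(map(len, schedule_passes("SEQUENTIAL"))) == 1
assert max(map(len, schedule_passes("FULL_CONCURRENT"))) == 4
```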
Shared Noise
When shared noise is on (the default), the chosen and rejected forward passes use the same timestep and the same noise. This is achieved through a slightly indirect mechanism: both passes receive `TrainProgress` objects with the same `global_step`, which seeds the RNG identically inside `predict()`.

The reasoning is simple - if chosen and rejected get different noise draws, some of the gradient signal comes from the noise difference rather than the preference difference. Sharing noise isolates the preference signal.
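The effect of step-seeded noise can be demonstrated with a plain RNG. The real seeding happens inside `predict()` from `TrainProgress.global_step`; the helper names below are hypothetical stand-ins:

```python
import random

def noise_seed(global_step: int, shared_noise: bool, is_rejected: bool) -> int:
    """Rejected pass reuses global_step when shared noise is on, and gets
    global_step + 1 for an independent stream when it's off."""
    if is_rejected and not shared_noise:
        return global_step + 1
    return global_step

def draw_noise(seed: int, n: int = 8):
    # stand-in for the noise draw inside predict(), seeded from the step
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(n)]

chosen = draw_noise(noise_seed(100, shared_noise=True, is_rejected=False))
rejected = draw_noise(noise_seed(100, shared_noise=True, is_rejected=True))
assert chosen == rejected          # shared noise: identical draws

rejected_indep = draw_noise(noise_seed(100, shared_noise=False, is_rejected=True))
assert chosen != rejected_indep    # independent stream from step + 1
```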
When shared noise is off, the rejected pass gets `global_step + 1`, giving it an independent RNG stream. This can be useful for more diverse gradient signals at the cost of noisier loss curves.

Worth noting: this mechanism depends on `predict()` seeding its RNG from `global_step`. If that ever changes, the noise sharing breaks silently. There's a comment in the code about this, but it's a real coupling rather than a guaranteed interface.

Data Pipeline
Concept Types
DPO uses four explicit concept types: `DPO_CHOSEN`, `DPO_REJECTED`, `DPO_CHOSEN_VAL`, and `DPO_REJECTED_VAL`. There's no adjacency-based inference and no fallback to standard concepts. If you want DPO, you configure DPO concepts.

Pair resolution is strict. The code collects all enabled chosen and rejected concepts, errors if both sides aren't present, errors if the counts don't match, and zips them in config order.
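The strict resolution reads roughly like this sketch. The function name and the `(type, path)` tuple shape are assumptions for illustration - the real logic lives in the dataloader mixin:

```python
def resolve_dpo_pairs(concepts):
    """Strict DPO pair resolution: both sides required, counts must
    match, pairing follows config order. `concepts` is a list of
    (concept_type, path) tuples using the ConceptType names."""
    chosen = [path for ctype, path in concepts if ctype == "DPO_CHOSEN"]
    rejected = [path for ctype, path in concepts if ctype == "DPO_REJECTED"]
    if not chosen or not rejected:
        raise RuntimeError("DPO requires both chosen and rejected concepts")
    if len(chosen) != len(rejected):
        raise RuntimeError(
            f"DPO concept count mismatch: {len(chosen)} chosen "
            f"vs {len(rejected)} rejected")
    return list(zip(chosen, rejected))  # paired in config order

assert resolve_dpo_pairs(
    [("DPO_CHOSEN", "good"), ("DPO_REJECTED", "bad")]) == [("good", "bad")]
```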
Filename Matching
Within each concept pair, `PairByFilename` matches samples by filename stem relative to the concept root, with extensions stripped and path separators normalised. This allows subdirectory structures as long as they mirror each other.

The pairing module fails fast if it finds unmatched files on either side, if prompts differ between matched pairs, or if crop resolutions differ. DPO pairs are supposed to differ only in quality, not in what they depict or how they're framed.
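A minimal version of the matching key, assuming the normalisation described above (this is an illustrative helper, not the actual `PairByFilename` code):

```python
from pathlib import PurePosixPath

def pair_key(path: str, concept_root: str) -> str:
    """Matching key: path relative to the concept root, separators
    normalised to '/', extension stripped."""
    norm = PurePosixPath(path.replace("\\", "/"))
    root = PurePosixPath(concept_root.replace("\\", "/"))
    return norm.relative_to(root).with_suffix("").as_posix()

# mirrored subdirectories on both sides produce the same key, so a .png
# on the chosen side matches a .jpg on the rejected side
assert pair_key("data/chosen/sub/img_001.png", "data/chosen") == "sub/img_001"
assert pair_key("data\\rejected\\sub\\img_001.jpg", "data/rejected") == "sub/img_001"
```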
Augmentation Suppression
All augmentations are disabled for DPO concepts - both in the runtime dataloader and in the concept editor UI. This includes image and text variations, crop jitter, flips, rotation, colour adjustments, mask transforms, tag shuffling, tag dropout, and capitalisation randomisation.
The logic is simple: if chosen and rejected images get different augmentations applied, you're teaching the model to prefer augmentation A over augmentation B, not to prefer the image quality you actually selected for. Augmentations are disabled entirely rather than synchronised across pairs, because synchronisation would add complexity for minimal benefit.
Curation Tool
The DPO Pair Tool (launched from the Tools tab) handles preference pair creation from generated images.
It reads prompt and aspect ratio metadata from image files - PNG text chunks for SwarmUI/ComfyUI metadata, raw byte scanning for JPEG/WebP - and groups images by exact `(prompt, aspect ratio)` match. From there it supports two scoring modes.

ELO mode shows two images side by side and lets you pick a winner. Keyboard driven - left arrow, right arrow, down to skip. Adaptive pairing prefers images with similar ratings for more informative comparisons.
Direct selection mode shows all images in a group and lets you pick the best and worst directly. Faster, less rigorous.
A "Pairs per group" setting (default 1) controls how many pairs are extracted per prompt group before advancing. Images are removed from the pool after being used in a pair, so there's no reuse within a group. If fewer than two unused images remain, the tool advances early.
Export produces `chosen/train/`, `chosen/val/`, `rejected/train/`, `rejected/val/` with matched filenames, caption text files, and a `concepts.json` ready for OneTrainer. Validation splitting is done at the prompt-group level so no prompt appears in both splits.

Validation
When DPO validation is enabled, the trainer doesn't fall back to per-concept reconstruction MSE. Instead it runs `calculate_dpo_loss()` under `torch.no_grad()` on the validation pairs and logs:

- `dpo/val_loss` - the raw DPO loss term (not the mixed training loss)
- `dpo/val_accuracy` - fraction of val pairs where the policy prefers chosen over rejected more than the reference does
- `dpo/val_chosen_reward` - mean chosen preference ratio
- `dpo/val_rejected_reward` - mean rejected preference ratio

One caveat: `dpo/val_loss` logs the pure DPO term, not the total loss including supervised mix. If you're running with `supervised_mix > 0`, the training loss and validation loss are measuring slightly different things. This doesn't affect early stopping since patience is based on accuracy and chosen reward, not val_loss directly.

Early Stopping
Patience is tracked against `dpo/val_accuracy` and `dpo/val_chosen_reward`, not reconstruction MSE.

Both best values start at negative infinity. A metric is considered stalled when the new value is less than or equal to the best seen so far. Two patience modes, `EITHER` and `BOTH`, control how many stalled metrics it takes to advance the counter.
Any improvement resets the counter. When the counter hits the configured patience value, training stops.
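Condensed into a sketch (the state layout, function name, and the exact `EITHER`/`BOTH` semantics are assumptions - here `BOTH` advances the counter only when both metrics stall, `EITHER` when at least one does):

```python
import math

def patience_step(state, accuracy, chosen_reward, mode="BOTH"):
    """One validation round of the patience logic: compare each metric
    to its best-so-far, update the bests, advance or reset the counter."""
    acc_stalled = accuracy <= state["best_accuracy"]
    rew_stalled = chosen_reward <= state["best_chosen_reward"]
    state["best_accuracy"] = max(state["best_accuracy"], accuracy)
    state["best_chosen_reward"] = max(state["best_chosen_reward"], chosen_reward)
    stalled = (acc_stalled and rew_stalled) if mode == "BOTH" \
        else (acc_stalled or rew_stalled)
    state["counter"] = state["counter"] + 1 if stalled else 0
    return state["counter"]

state = {"best_accuracy": -math.inf, "best_chosen_reward": -math.inf, "counter": 0}
assert patience_step(state, 0.6, 0.1) == 0   # first round always improves on -inf
assert patience_step(state, 0.6, 0.1) == 1   # equal to best -> stalled
assert patience_step(state, 0.7, 0.1) == 0   # accuracy improved -> counter resets
```

Training stops once the returned counter reaches the configured patience value.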
This is deliberately tied to the DPO validation metrics rather than MSE because you can have both chosen and rejected MSE decreasing (the model gets better at denoising everything) while the preference separation stays flat or degrades. MSE trends tell you about reconstruction quality; DPO metrics tell you about preference learning. Early stopping should care about the latter.
TensorBoard Metrics
Training logs: `loss/train_step`, `loss/dpo` (total after smoothing + supervised mix), `dpo/raw_loss` (pure DPO term), `dpo/chosen_reward`, `dpo/rejected_reward`, `dpo/accuracy`.

Validation logs: `dpo/val_loss`, `dpo/val_accuracy`, `dpo/val_chosen_reward`, `dpo/val_rejected_reward`.

Diagnostic interpretation: accuracy approaching 1.0 at initialisation means your pairs are too easy - the policy already prefers chosen over rejected relative to the reference before any training has happened. Beta doesn't affect accuracy directly since accuracy is a pure comparison on the ratios, not the scaled logits. Accuracy stuck at 0.5 means the model isn't learning the preference - check data quality. If `dpo/raw_loss` diverges while `loss/dpo` looks stable, the supervised mix is masking a problem. At beta=0, the loss should equal log(2) regardless of model state.

Startup Validation
If RLHF is enabled but the configuration is wrong, the trainer fails during initialisation rather than mid-run.
The training method gate checks that you're using LoRA - DPO with full fine-tuning raises `NotImplementedError` immediately. Missing or mismatched DPO concepts raise `RuntimeError` when the dataloader is built, which happens before any training step. Unmatched filenames raise `RuntimeError` during dataset initialisation. There's also a defensive check in `calculate_dpo_loss()` for missing rejected latents, though in practice this shouldn't be reachable if the dataloader is configured correctly.

All four failure paths produce explicit exceptions during startup. None of them result in silent skipping or unhandled mid-training crashes.
Configuration
The DPO config fields, all prefixed `rlhf_`:

- `rlhf_enabled` - master toggle
- `rlhf_mode` - currently only `DPO`, exists for future extensibility
- `rlhf_dpo_beta` - temperature, default 5000
- `rlhf_dpo_label_smoothing` - for noisy preference labels, default 0
- `rlhf_dpo_ref_mode` - `NEW_ADAPTER` or `EXISTING_ADAPTER`, auto-derived from whether a base adapter is loaded
- `rlhf_dpo_execution_mode` - `SEQUENTIAL`, `POLICY_CONCURRENT`, or `FULL_CONCURRENT`
- `rlhf_supervised_mix` - weight for supervised loss on chosen images, default 0
- `rlhf_dpo_shared_noise` - share noise between chosen/rejected, default true
- `rlhf_dpo_validation` - enable DPO-specific validation
- `rlhf_dpo_validation_percentage` - train/val split percentage, default 10
- `rlhf_dpo_patience_enabled` - enable early stopping
- `rlhf_dpo_patience_value` - number of stalled validations before stopping
- `rlhf_dpo_patience_mode` - `EITHER` or `BOTH`

`rlhf_dpo_ref_mode` is resynced from `effective_dpo_ref_mode()` on every config save/load, so it stays aligned with whether a base adapter is loaded regardless of what's in the JSON.

Testing
Unit tests cover the basics - enum contents, config round-tripping through JSON, loss behaviour including the beta=0-produces-log(2) sanity check, all three execution modes returning finite loss and metrics, no-grad validation metrics, concept type semantics, and patience logic for both EITHER and BOTH modes.

The heavier coverage runs all six combinations of training type x execution mode (NEW_ADAPTER and EXISTING_ADAPTER x SEQUENTIAL, POLICY_CONCURRENT, FULL_CONCURRENT) using tiny fake tensors and a dummy adapter model on CPU. It checks finite loss and backward, adapter gradients, all five training metrics, reference snapshot integrity under Existing Adapter mode, TensorBoard event tags, and hot-swap correctness. No CUDA required.
Constraints
DPO is adapter-only in the LoRA path. The output is always an adapter file. Training type is derived, not user-selected. Pairing requires explicit chosen/rejected concepts with matching filenames, prompts, and crop resolutions. Augmentations are disabled, not synchronised. Validation uses DPO metrics directly. There's no reward model, no external reference path, and no full-parameter DPO.
Code Touchpoints
The implementation lives across:

- `BaseModelSetup.py` - loss function, reference model context manager, metric caching
- `GenericTrainer.py` - DPO loss branching, validation, patience, TensorBoard logging
- `TrainConfig.py` - config fields, effective mode derivation, serialisation
- `DPORefMode.py`, `DPOExecutionMode.py`, `ConceptType.py` - enums
- `dpo_curation_util.py` - export, pair checking, shared utilities
- `PairByFilename.py` - filename-based sample pairing
- `DataLoaderMgdsMixin.py` - augmentation suppression, concept type filtering
- `DataLoaderText2ImageMixin.py` - DPO output module construction, pair resolution
- `RLHFTab.py` - settings UI
- `DPOCurationWindow.py` - pair curation tool
- `ConceptWindow.py` - DPO concept type handling in the editor