
Implement RLHF DPO (Direct Preference Optimization) training#1403

Open
BitcrushedHeart wants to merge 11 commits into Nerogar:master from BitcrushedHeart:RLHF

Conversation

@BitcrushedHeart
Contributor

DPO in OneTrainer

What This Is

OneTrainer's DPO implementation lets you show the model two images for the same prompt - one you prefer, one you don't - and trains an adapter to produce more of the former. The reference model is kept adapter-sized (either the raw base or a frozen snapshot of your existing LoRA), so the whole thing runs on a single consumer GPU without doubling your VRAM budget.

It's model-agnostic. The loss function doesn't know or care whether you're training on Flux, SDXL, SD 1.5, Z-Image, or anything else OneTrainer supports. It hooks into the existing predict() pipeline and works with whatever prediction targets the model already uses.

How DPO Works Here

Standard DPO compares how much more the policy model prefers the chosen sample over the rejected one, relative to how much the reference model does. The loss is:

L = -log(sigmoid(beta * ((score_policy(chosen) - score_ref(chosen))
                        - (score_policy(rejected) - score_ref(rejected)))))

In language models, those scores are token log-probabilities. We don't have those. What we do have is the prediction error from the denoising objective - the same MSE that standard training minimises. A lower MSE means the model "agrees" more with that image at that timestep, so we use negative MSE as a score proxy:

score(image) = -mean((predicted - target)^2)

This is reduced across all non-batch dimensions, giving one scalar per sample. The key insight is that this works regardless of prediction type - epsilon, v-prediction, flow velocity - because predict() already returns the right (predicted, target) pair for each model family. The DPO loss never needs to know which one it's looking at.
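The scoring and loss described above can be condensed into a short sketch. This is illustrative only - the function name, signature, and metric naming here are not OneTrainer's actual `calculate_dpo_loss()`; the inputs are assumed to be 1-D per-sample MSE tensors already reduced over non-batch dimensions:

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_mse, policy_rejected_mse,
             ref_chosen_mse, ref_rejected_mse, beta=5000.0):
    """Sketch of DPO with negative MSE as the score proxy."""
    # score(image) = -MSE, so the policy-vs-reference preference ratio per pair is:
    chosen_ratio = (-policy_chosen_mse) - (-ref_chosen_mse)
    rejected_ratio = (-policy_rejected_mse) - (-ref_rejected_mse)
    logits = beta * (chosen_ratio - rejected_ratio)
    loss = -F.logsigmoid(logits).mean()
    # accuracy: fraction of pairs where the policy prefers chosen
    # over rejected more strongly than the reference does
    accuracy = (chosen_ratio > rejected_ratio).float().mean()
    return loss, accuracy
```

Note that nothing here depends on the prediction type: whatever (predicted, target) pair produced the MSE values, the loss only sees scalars.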

The default beta is 5000, which looks absurdly high compared to the 0.1-0.5 typical in LLM DPO. The reason is scale: MSE values are tiny compared to token log-probs, so you need a correspondingly larger beta to keep the sigmoid in a useful gradient range rather than having it saturate immediately.
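A quick scale check makes the point concrete. The 1e-4 score difference below is an assumed, hypothetical value for MSE-scale preference ratios, not a measured one:

```python
import math

diff = 1e-4  # hypothetical per-pair score difference at MSE scale
for beta in (0.1, 5000):
    logits = beta * diff
    loss = -math.log(1 / (1 + math.exp(-logits)))  # -log(sigmoid(logits))
    print(f"beta={beta}: logits={logits:.5f}, loss={loss:.4f}")
# With beta=0.1 the logits are ~1e-5 for every pair, so the loss is pinned at
# log(2) ~ 0.6931 and carries almost no per-pair signal. With beta=5000 the
# logits land at 0.5, a workable scale for the sigmoid.
```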

Two optional extras sit on top of the base loss. Label smoothing interpolates between the DPO loss and its complement, which helps when your preference labels are noisy (i.e. you weren't fully sure which image was better). Supervised mix adds a standard training loss from the chosen image to prevent the adapter drifting too far from what it already knows. The supervised term reuses the chosen policy forward pass, so it doesn't cost an extra forward pass.
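Both extras can be sketched on top of a precomputed logit tensor (the beta-scaled preference difference). The function name and argument layout here are illustrative assumptions, not the real code:

```python
import torch
import torch.nn.functional as F

def smoothed_mixed_loss(logits, chosen_mse, label_smoothing=0.0, supervised_mix=0.0):
    """Sketch: label smoothing + supervised mix over a DPO logit tensor.

    chosen_mse is the per-sample MSE from the chosen policy pass, which is
    already in hand - so the supervised term costs no extra forward pass."""
    # label smoothing interpolates the DPO loss with its label-flipped complement
    dpo = -(1.0 - label_smoothing) * F.logsigmoid(logits) \
          - label_smoothing * F.logsigmoid(-logits)
    loss = dpo.mean()
    if supervised_mix > 0.0:
        loss = loss + supervised_mix * chosen_mse.mean()  # standard training term
    return loss
```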

Training Types

There are two modes, and you don't pick them directly - the code infers which one you're using from whether you've loaded a base adapter.

New Adapter means no base adapter is loaded. The reference model is the raw base with no adapter applied. The policy is a fresh adapter training from scratch. Use this for general quality preferences - "prefer sharp images over blurry ones" - where you don't have a prior fine-tune to build on.

Existing Adapter means you've loaded a supervised LoRA/OFT checkpoint. The reference model is a frozen snapshot of those adapter weights taken at the start of training. The policy starts from the same weights and diverges during DPO. Use this when you've already got a character or style LoRA and want to refine its outputs - "my character LoRA is good but sometimes fumbles hands, so here's examples of good hands vs bad hands."

In both cases, the output is an adapter file.

The Reference Model

This is the part that makes DPO practical on hobbyist hardware.

For New Adapter mode, the reference pass just temporarily unhooks the adapter from the model, runs a forward pass through the raw base weights, and hooks the adapter back in. Simple.

For Existing Adapter mode, the implementation clones every adapter parameter tensor once at the start of training - this is the frozen reference snapshot. During each reference forward pass, it swaps param.data pointers from the live training weights to the frozen snapshot, runs the forward pass, and swaps back in a finally block. This is O(1) pointer assignment per parameter, not a copy. The base model weights are shared between policy and reference the whole time. The only duplicated data is the adapter tensors themselves - typically 50-100MB for a LoRA.

No second model is loaded. No base weights are duplicated. The "reference model" is a list of frozen tensors and a pointer swap.
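The swap mechanism can be sketched as a context manager. This is a minimal illustration of the technique described above, not OneTrainer's `reference_model` implementation, and the names are made up:

```python
import contextlib
import torch

@contextlib.contextmanager
def frozen_reference(adapter_params, frozen_snapshot):
    """Point adapter params at a frozen snapshot for one forward pass.

    frozen_snapshot holds tensors cloned once at training start. Each swap is
    an O(1) pointer assignment per parameter, not a copy; base model weights
    are never touched."""
    live = [p.data for p in adapter_params]
    try:
        for p, frozen in zip(adapter_params, frozen_snapshot):
            p.data = frozen            # reference pass sees the snapshot
        yield
    finally:
        for p, d in zip(adapter_params, live):
            p.data = d                 # always restore the live training weights
```

A reference forward pass would then run inside `with frozen_reference(...), torch.no_grad():`.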

Forward Pass Scheduling

DPO needs four forward passes per training step: reference-chosen, reference-rejected, policy-chosen, policy-rejected. How you schedule those is a VRAM/speed tradeoff exposed through the execution mode setting.

Sequential (default): Run all four one at a time, deleting activations between each. Both reference passes are no_grad so they don't retain activations anyway. The policy chosen pass runs, its output is deleted, then the policy rejected pass runs. Peak VRAM is roughly the same as standard LoRA training. Slowest, but fits on 24GB.

Policy Concurrent: Reference passes still run sequentially and get cleaned up. Both policy passes keep their activations alive simultaneously - the chosen output isn't deleted before the rejected pass starts. This saves recomputation during backward at the cost of holding two gradient-tracked passes in memory. Uses more VRAM than Sequential.

Full Concurrent: All four outputs stay alive until the scores are computed. Fastest scheduling, highest memory. Uses considerably more VRAM than either of the other modes.

The execution mode doesn't change the maths. Same loss, same gradients, same result - just different memory/speed profiles.

Shared Noise

When shared noise is on (the default), the chosen and rejected forward passes use the same timestep and the same noise. This is achieved through a slightly indirect mechanism: both passes receive TrainProgress objects with the same global_step, which seeds the RNG identically inside predict().

The reasoning is simple - if chosen and rejected get different noise draws, some of the gradient signal comes from the noise difference rather than the preference difference. Sharing noise isolates the preference signal.

When shared noise is off, the rejected pass gets global_step + 1, giving it an independent RNG stream. This can be useful for more diverse gradient signals at the cost of noisier loss curves.

Worth noting: this mechanism depends on predict() seeding its RNG from global_step. If that ever changes, the noise sharing breaks silently. There's a comment in the code about this, but it's a real coupling rather than a guaranteed interface.
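The seeding trick reduces to a few lines. This sketch assumes noise is drawn from a generator seeded directly by the step value; the function name is illustrative:

```python
import torch

def draw_noise(shape, global_step, shared_noise=True, is_rejected=False):
    """Sketch of shared-noise seeding: identical seed -> identical noise draw.

    With shared noise on, chosen and rejected use global_step as the seed and
    get the same tensor; with it off, the rejected pass is bumped to
    global_step + 1 and gets an independent stream."""
    step = global_step if (shared_noise or not is_rejected) else global_step + 1
    generator = torch.Generator().manual_seed(step)
    return torch.randn(shape, generator=generator)
```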

Data Pipeline

Concept Types

DPO uses four explicit concept types: DPO_CHOSEN, DPO_REJECTED, DPO_CHOSEN_VAL, and DPO_REJECTED_VAL. There's no adjacency-based inference and no fallback to standard concepts. If you want DPO, you configure DPO concepts.

Pair resolution is strict. The code collects all enabled chosen and rejected concepts, errors if both sides aren't present, errors if the counts don't match, and zips them in config order.

Filename Matching

Within each concept pair, PairByFilename matches samples by filename stem relative to the concept root, with extensions stripped and path separators normalised. This allows subdirectory structures as long as they mirror each other.

The pairing module fails fast if it finds unmatched files on either side, if prompts differ between matched pairs, or if crop resolutions differ. DPO pairs are supposed to differ only in quality, not in what they depict or how they're framed.
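The matching rule can be sketched as follows. This is an assumed reconstruction of PairByFilename's behaviour from the description above, not its actual code:

```python
from pathlib import Path

def pair_by_stem(chosen_root, chosen_files, rejected_root, rejected_files):
    """Pair files by root-relative stem: extension stripped, separators
    normalised, so mirrored subdirectory structures line up."""
    def stem_key(root, path):
        return Path(path).relative_to(root).with_suffix("").as_posix()

    chosen = {stem_key(chosen_root, f): f for f in chosen_files}
    rejected = {stem_key(rejected_root, f): f for f in rejected_files}
    if chosen.keys() != rejected.keys():       # fail fast on unmatched files
        raise RuntimeError(
            f"unmatched DPO files: {sorted(chosen.keys() ^ rejected.keys())}")
    return [(chosen[k], rejected[k]) for k in sorted(chosen)]
```

Note the extension is ignored, so `sub/a.png` on the chosen side pairs with `sub/a.webp` on the rejected side.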

Augmentation Suppression

All augmentations are disabled for DPO concepts - both in the runtime dataloader and in the concept editor UI. This includes image and text variations, crop jitter, flips, rotation, colour adjustments, mask transforms, tag shuffling, tag dropout, and capitalisation randomisation.

The logic is simple: if chosen and rejected images get different augmentations applied, you're teaching the model to prefer augmentation A over augmentation B, not to prefer the image quality you actually selected for. Augmentations are disabled entirely rather than synchronised across pairs, because synchronisation would add complexity for minimal benefit.

Curation Tool

The DPO Pair Tool (launched from the Tools tab) handles preference pair creation from generated images.

It reads prompt and aspect ratio metadata from image files - PNG text chunks for SwarmUI/ComfyUI metadata, raw byte scanning for JPEG/WebP - and groups images by exact (prompt, aspectratio) match. From there it supports two scoring modes.

ELO mode shows two images side by side and lets you pick a winner. Keyboard driven - left arrow, right arrow, down to skip. Adaptive pairing prefers images with similar ratings for more informative comparisons.

Direct selection mode shows all images in a group and lets you pick the best and worst directly. Faster, less rigorous.

A "Pairs per group" setting (default 1) controls how many pairs are extracted per prompt group before advancing. Images are removed from the pool after being used in a pair, so there's no reuse within a group. If fewer than two unused images remain, the tool advances early.

Export produces chosen/train/, chosen/val/, rejected/train/, rejected/val/ with matched filenames, caption text files, and a concepts.json ready for OneTrainer. Validation splitting is done at the prompt-group level so no prompt appears in both splits.

Validation

When DPO validation is enabled, the trainer doesn't fall back to per-concept reconstruction MSE. Instead it runs calculate_dpo_loss() under torch.no_grad() on the validation pairs and logs:

  • dpo/val_loss - the raw DPO loss term (not the mixed training loss)
  • dpo/val_accuracy - fraction of val pairs where the policy prefers chosen over rejected more than the reference does
  • dpo/val_chosen_reward - mean chosen preference ratio
  • dpo/val_rejected_reward - mean rejected preference ratio

One caveat: dpo/val_loss logs the pure DPO term, not the total loss including supervised mix. If you're running with supervised_mix > 0, the training loss and validation loss are measuring slightly different things. This doesn't affect early stopping since patience is based on accuracy and chosen reward, not val_loss directly.

Early Stopping

Patience is tracked against dpo/val_accuracy and dpo/val_chosen_reward, not reconstruction MSE.

Both best values start at negative infinity. A metric is considered stalled when the new value is less than or equal to the best seen so far. Two patience modes:

  • Either: increment patience counter if accuracy stalls OR chosen reward stalls
  • Both: increment only if both stall simultaneously

Any improvement resets the counter. When the counter hits the configured patience value, training stops.

This is deliberately tied to the DPO validation metrics rather than MSE because you can have both chosen and rejected MSE decreasing (the model gets better at denoising everything) while the preference separation stays flat or degrades. MSE trends tell you about reconstruction quality; DPO metrics tell you about preference learning. Early stopping should care about the latter.
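The patience rule above can be sketched as a small tracker. Class and method names are illustrative; the sketch resolves the "any improvement resets" wording as: the counter resets whenever the mode's stall condition is not met:

```python
class DPOPatienceTracker:
    """Sketch of DPO early stopping: a metric stalls when its new value is
    less than or equal to the best seen so far (best starts at -inf)."""

    def __init__(self, patience, mode="EITHER"):
        self.patience, self.mode = patience, mode
        self.best_accuracy = float("-inf")
        self.best_chosen_reward = float("-inf")
        self.counter = 0

    def should_stop(self, val_accuracy, val_chosen_reward):
        acc_stalled = val_accuracy <= self.best_accuracy
        reward_stalled = val_chosen_reward <= self.best_chosen_reward
        self.best_accuracy = max(self.best_accuracy, val_accuracy)
        self.best_chosen_reward = max(self.best_chosen_reward, val_chosen_reward)
        if self.mode == "EITHER":
            stalled = acc_stalled or reward_stalled
        else:  # "BOTH"
            stalled = acc_stalled and reward_stalled
        self.counter = self.counter + 1 if stalled else 0
        return self.counter >= self.patience
```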

TensorBoard Metrics

Training logs: loss/train_step, loss/dpo (total after smoothing + supervised mix), dpo/raw_loss (pure DPO term), dpo/chosen_reward, dpo/rejected_reward, dpo/accuracy.

Validation logs: dpo/val_loss, dpo/val_accuracy, dpo/val_chosen_reward, dpo/val_rejected_reward.

Diagnostic interpretation:

  • Accuracy approaching 1.0 at initialisation means your pairs are too easy - the policy already prefers chosen over rejected relative to the reference before any training has happened.
  • Accuracy stuck at 0.5 means the model isn't learning the preference - check data quality. Beta doesn't affect accuracy directly, since accuracy is a pure comparison on the ratios, not the scaled logits.
  • If dpo/raw_loss diverges while loss/dpo looks stable, the supervised mix is masking a problem.
  • At beta=0, the loss should equal log(2) regardless of model state.

Startup Validation

If RLHF is enabled but the configuration is wrong, the trainer fails during initialisation rather than mid-run.

The training method gate checks that you're using LoRA - DPO with full fine-tuning raises NotImplementedError immediately. Missing or mismatched DPO concepts raise RuntimeError when the dataloader is built, which happens before any training step. Unmatched filenames raise RuntimeError during dataset initialisation. There's also a defensive check in calculate_dpo_loss() for missing rejected latents, though in practice this shouldn't be reachable if the dataloader is configured correctly.

All four failure paths produce explicit exceptions during startup. None of them result in silent skipping or unhandled mid-training crashes.

Configuration

The DPO config fields, all prefixed rlhf_:

  • rlhf_enabled - master toggle
  • rlhf_mode - currently only DPO, exists for future extensibility
  • rlhf_dpo_beta - temperature, default 5000
  • rlhf_dpo_label_smoothing - for noisy preference labels, default 0
  • rlhf_dpo_ref_mode - NEW_ADAPTER or EXISTING_ADAPTER, auto-derived from whether a base adapter is loaded
  • rlhf_dpo_execution_mode - SEQUENTIAL, POLICY_CONCURRENT, or FULL_CONCURRENT
  • rlhf_supervised_mix - weight for supervised loss on chosen images, default 0
  • rlhf_dpo_shared_noise - share noise between chosen/rejected, default true
  • rlhf_dpo_validation - enable DPO-specific validation
  • rlhf_dpo_validation_percentage - train/val split percentage, default 10
  • rlhf_dpo_patience_enabled - enable early stopping
  • rlhf_dpo_patience_value - number of stalled validations before stopping
  • rlhf_dpo_patience_mode - EITHER or BOTH

rlhf_dpo_ref_mode is resynced from effective_dpo_ref_mode() on every config save/load, so it stays aligned with whether a base adapter is loaded regardless of what's in the JSON.

Testing

Unit tests cover the basics - enum contents, config round-tripping through JSON, loss behaviour (including the check that beta=0 produces log(2)), all three execution modes returning finite loss and metrics, no-grad validation metrics, concept type semantics, and patience logic for both EITHER and BOTH modes.

The heavier coverage runs all six combinations of training type x execution mode (NEW_ADAPTER and EXISTING_ADAPTER x SEQUENTIAL, POLICY_CONCURRENT, FULL_CONCURRENT) using tiny fake tensors and a dummy adapter model on CPU. It checks finite loss and backward, adapter gradients, all five training metrics, reference snapshot integrity under Existing Adapter mode, TensorBoard event tags, and hot-swap correctness. No CUDA required.

Constraints

DPO is adapter-only in the LoRA path. The output is always an adapter file. Training type is derived, not user-selected. Pairing requires explicit chosen/rejected concepts with matching filenames, prompts, and crop resolutions. Augmentations are disabled, not synchronised. Validation uses DPO metrics directly. There's no reward model, no external reference path, and no full-parameter DPO.

Code Touchpoints

The implementation lives across:

  • BaseModelSetup.py - loss function, reference model context manager, metric caching
  • GenericTrainer.py - DPO loss branching, validation, patience, TensorBoard logging
  • TrainConfig.py - config fields, effective mode derivation, serialisation
  • DPORefMode.py, DPOExecutionMode.py, ConceptType.py - enums
  • dpo_curation_util.py - export, pair checking, shared utilities
  • PairByFilename.py - filename-based sample pairing
  • DataLoaderMgdsMixin.py - augmentation suppression, concept type filtering
  • DataLoaderText2ImageMixin.py - DPO output module construction, pair resolution
  • RLHFTab.py - settings UI
  • DPOCurationWindow.py - pair curation tool
  • ConceptWindow.py - DPO concept type handling in the editor

@yamatazen

Any example images?

Images now scale up to fill the display box (not just down), and
prompts are shown in a collapsible expander instead of truncated.
- Fix failing export test to match 7-value return signature
- Add ValueError for unsupported DPO reference modes
- Remove dead hasattr check in reference_model
- Close PIL file handles promptly in DPO curation window
- Read only first 256KB for metadata extraction with fallback
- Add config migration 14 for transfer_* fields
- Add DPO loss math integration tests including beta=0 sanity check
SwarmUI embeds metadata in WebP EXIF UserComment as UTF-16LE
(UNICODE character code), causing prompt extraction to fail.
Fall back to UTF-16LE/BE decoding when UTF-8 finds no metadata.
Output folder is now selected at start. Each pick exports the
chosen/rejected pair to disk instantly and updates a manifest.json,
so closing mid-session loses no work. On restart, completed groups
are auto-skipped via manifest lookup. Groups are shuffled for
variety. Finalization (train/val split + concepts.json) is a
separate optional step at the end.
Metadata scan and dhash dedup now run in a background thread,
feeding a Queue(maxsize=10) of ready groups. The UI shows a live
scan counter and transitions to picking as soon as the first group
is ready. If the user outruns the dedup, a brief "Preparing..."
spinner appears. Groups already completed in the manifest are
skipped before deduping, saving the most expensive work.
After accepting a chosen/rejected pair, if there are still 2+
remaining images in the group, a Yes/No/Cancel dialog lets the
user continue scoring the same prompt or move on. Works in both
ELO and Selection modes. Cancel acts as an undo for mis-clicks.
@Silvicultor

It reads prompt and aspect ratio metadata from image files - PNG text chunks for SwarmUI/ComfyUI metadata

Very interesting PR! I wanna test it, but right now ComfyUI metadata does not seem to work. If I select an input folder with ComfyUI-generated images (from a very basic SDXL workflow), I get "no images with extractable prompt metadata found".
Would be nice if either (or ideally both) ComfyUI or Forge metadata could be supported.

@BitcrushedHeart
Contributor Author

It reads prompt and aspect ratio metadata from image files - PNG text chunks for SwarmUI/ComfyUI metadata

Very interesting PR! I wanna test it, but right now ComfyUI metadata does not seem to work. If I select an input folder with ComfyUI-generated images (from a very basic SDXL workflow), I get "no images with extractable prompt metadata found". Would be nice if either (or ideally both) ComfyUI or Forge metadata could be supported.

If you can send me an image on Discord from ComfyUI (any image, it doesn't matter), I can add this for you now and push it. If you have any from Forge as well I'll do the same - I don't have either of these so I don't know what they expect.

Add save-best option that snapshots model weights when validation accuracy
improves and restores them at end of training. Simplify patience to track
accuracy only (remove DPOPatienceMode). Add review mode to the DPO pair
tool with orphan detection and pair removal. Default execution mode is now
Full Concurrent. Polish curation UI with card layout and progress bar.
@BitcrushedHeart force-pushed the RLHF branch 2 times, most recently from 5b55d90 to e1b8b09 (April 5, 2026 17:14)
Parse plain-text A1111/Forge parameters (positive prompt before
"Negative prompt:" marker) and ComfyUI workflow JSON (trace KSampler
positive conditioning to CLIPTextEncode node, with fallback to longest
text encoder output). Handles SDXL/SD3/Flux text encoder variants.
@BitcrushedHeart
Contributor Author

It reads prompt and aspect ratio metadata from image files - PNG text chunks for SwarmUI/ComfyUI metadata

Very interesting PR! I wanna test it, but right now ComfyUI metadata does not seem to work. If I select an input folder with ComfyUI-generated images (from a very basic SDXL workflow), I get "no images with extractable prompt metadata found". Would be nice if either (or ideally both) ComfyUI or Forge metadata could be supported.

Forge & Comfy now supported in latest commit.

@Silvicultor

I did some testing with the PR:
(1) There is a bug in the UI: Currently only the ELO mode in the DPO pair tool is functional on Linux. In the direct selection mode you can’t view the images fully; if you try, the whole screen becomes white and the OT Python process needs to be terminated to regain control.
Error log:

Exception in Tkinter callback
Traceback (most recent call last):
  File "/usr/lib/python3.12/tkinter/__init__.py", line 1967, in __call__
    return self.func(*args)
           ^^^^^^^^^^^^^^^^
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/venv/lib/python3.12/site-packages/customtkinter/windows/widgets/ctk_button.py", line 554, in _clicked
    self._command()
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/modules/ui/TrainUI.py", line 699, in open_dpo_curation_tool
    DPOCurationWindow(self)
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/modules/ui/DPOCurationWindow.py", line 75, in __init__
    self.grab_set()
  File "/usr/lib/python3.12/tkinter/__init__.py", line 963, in grab_set
    self.tk.call('grab', 'set', self._w)
_tkinter.TclError: grab failed: window not viewable
Exception in Tkinter callback
Traceback (most recent call last):
  File "/usr/lib/python3.12/tkinter/__init__.py", line 1967, in __call__
    return self.func(*args)
           ^^^^^^^^^^^^^^^^
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/modules/ui/DPOCurationWindow.py", line 542, in <lambda>
    label.bind("<Button-1>", lambda e, p=path: self._selection_preview(p))
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/modules/ui/DPOCurationWindow.py", line 552, in _selection_preview
    preview.grab_set()
  File "/usr/lib/python3.12/tkinter/__init__.py", line 963, in grab_set
    self.tk.call('grab', 'set', self._w)
_tkinter.TclError: grab failed: window not viewable

(2) Test training with 15 pairs (NoobAI vpred) worked without error. I see minimal improvement of my concept in sample images. Tho the dpo/val_accuracy graph does not really reflect it. But I guess can’t expect that with literally only 1 val pair.

Patience now uses DPOPatienceMode (EITHER/BOTH) and tracks val_loss
alongside accuracy with rounding to 5 decimal places. Add visual
pair review accessible from RLHF tab that merges train/val splits
and supports orphan detection. Fix grab_set crash on Linux by adding
wait_visibility before grab.
@BitcrushedHeart
Contributor Author

I did some testing with the PR: (1) There is a bug in the UI: Currently only the ELO mode in the DPO pair tool is functional on Linux. In the direct selection mode you can’t view the images fully; if you try, the whole screen becomes white and the OT Python process needs to be terminated to regain control.

(2) Test training with 15 pairs (NoobAI vpred) worked without error. I see minimal improvement of my concept in sample images. Tho the dpo/val_accuracy graph does not really reflect it. But I guess can’t expect that with literally only 1 val pair.

Fixed in 26e3b7a.

Validation accuracy is binary - with 1 validation pair it is literally a coin flip. Val loss would be a better metric for you in this case.
