Skip to content

Commit 3c76232

Browse files
fix: address code review issues in RLHF DPO implementation
- Fix failing export test to match the 7-value return signature
- Add ValueError for unsupported DPO reference modes
- Remove dead hasattr check in reference_model
- Close PIL file handles promptly in the DPO curation window
- Read only the first 256 KB for metadata extraction, with a full-read fallback
- Add config migration 14 for the transfer_* fields
- Add DPO loss math integration tests, including a beta=0 sanity check
1 parent cc0f749 commit 3c76232

6 files changed

Lines changed: 215 additions & 11 deletions

File tree

modules/modelSetup/BaseModelSetup.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -313,10 +313,7 @@ def reference_model(self, model: BaseModel, config: TrainConfig):
313313
if len(adapters) == 0:
314314
raise RuntimeError("RLHF DPO requires active adapters, but no trainable adapters are attached to the current model.")
315315

316-
if hasattr(config, "effective_dpo_ref_mode"):
317-
ref_mode = config.effective_dpo_ref_mode()
318-
else:
319-
ref_mode = DPORefMode.EXISTING_ADAPTER if config.lora_model_name else DPORefMode.NEW_ADAPTER
316+
ref_mode = config.effective_dpo_ref_mode()
320317

321318
if ref_mode == DPORefMode.NEW_ADAPTER:
322319
for adapter in adapters:
@@ -349,6 +346,8 @@ def reference_model(self, model: BaseModel, config: TrainConfig):
349346
for adapter, policy_ptrs in zip(adapters, policy_data, strict=True):
350347
for param, policy_ptr in zip(adapter.parameters(), policy_ptrs, strict=True):
351348
param.data = policy_ptr
349+
else:
350+
raise ValueError(f"Unsupported DPO reference mode: {ref_mode}")
352351

353352
def _create_model_part_parameters(
354353
self,

modules/ui/DPOCurationWindow.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -308,8 +308,8 @@ def _build_selection_ui(self):
308308
def _display_thumbnail(self, master, path: str, row: int, col: int):
309309
thumb_size = 250
310310
try:
311-
pil_img = Image.open(path)
312-
pil_img = self._fit_image(pil_img, thumb_size, thumb_size)
311+
with Image.open(path) as _raw:
312+
pil_img = self._fit_image(_raw, thumb_size, thumb_size).copy()
313313
ctk_img = ctk.CTkImage(light_image=pil_img, size=pil_img.size)
314314

315315
label = ctk.CTkLabel(master, text="", image=ctk_img)
@@ -328,9 +328,9 @@ def _selection_preview(self, path: str):
328328
preview.focus_set()
329329

330330
try:
331-
pil_img = Image.open(path)
332331
sw, sh = preview.winfo_screenwidth(), preview.winfo_screenheight()
333-
pil_img = self._fit_image(pil_img, sw, sh - 50)
332+
with Image.open(path) as _raw:
333+
pil_img = self._fit_image(_raw, sw, sh - 50).copy()
334334
ctk_img = ctk.CTkImage(light_image=pil_img, size=pil_img.size)
335335

336336
label = ctk.CTkLabel(preview, text="", image=ctk_img)
@@ -416,13 +416,13 @@ def _toggle():
416416

417417
def _display_image(self, master, path: str, row: int, col: int):
418418
try:
419-
pil_img = Image.open(path)
420419
self.update_idletasks()
421420
win_w = self.winfo_width() or self.winfo_screenwidth()
422421
win_h = self.winfo_height() or self.winfo_screenheight()
423422
max_w = max(400, win_w // 2 - 40)
424423
max_h = max(400, win_h - 200)
425-
pil_img = self._fit_image(pil_img, max_w, max_h)
424+
with Image.open(path) as _raw:
425+
pil_img = self._fit_image(_raw, max_w, max_h).copy()
426426
ctk_img = ctk.CTkImage(light_image=pil_img, size=pil_img.size)
427427

428428
label = ctk.CTkLabel(master, text="", image=ctk_img)

modules/util/config/TrainConfig.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,7 @@ def __init__(self, data: list[(str, Any, type, bool)]):
601601
11: self.__migration_11,
602602
12: self.__migration_12,
603603
13: self.__migration_13,
604+
14: self.__migration_14,
604605
}
605606
)
606607

@@ -849,6 +850,14 @@ def __migration_13(self, data: dict) -> dict:
849850
migrated_data.setdefault("rlhf_dpo_execution_mode", "SEQUENTIAL")
850851
return migrated_data
851852

853+
def __migration_14(self, data: dict) -> dict:
    """Backfill the transfer_* settings introduced with config version 14.

    Works on a shallow copy; keys already present in the config keep their
    existing values, only missing keys receive the defaults below.
    """
    transfer_defaults = {
        "transfer_step1": False,
        "transfer_step2": False,
        "transfer_guidance": 3.0,
        "transfer_train_lora": False,
    }
    migrated_data = dict(data)
    for key, default in transfer_defaults.items():
        migrated_data.setdefault(key, default)
    return migrated_data
860+
852861
def effective_dpo_ref_mode(self) -> DPORefMode:
853862
return DPORefMode.EXISTING_ADAPTER if self.lora_model_name else DPORefMode.NEW_ADAPTER
854863

modules/util/image_metadata_util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,11 @@ def strip_angle_bracket_segments(prompt: str) -> str:
5353

5454
def _extract_raw_metadata(path: str) -> dict:
5555
"""Scan raw file bytes for plaintext metadata (JPEG, WebP, etc.)."""
56+
CHUNK = 256 * 1024 # 256 KB covers metadata in most formats
5657
with open(path, 'rb') as f:
57-
raw = f.read()
58+
raw = f.read(CHUNK)
59+
if b'"sui_image_params"' not in raw and b'"prompt"' not in raw:
60+
raw = raw + f.read() # fall back to full read
5861
# Preserve non-ASCII bytes instead of silently dropping them.
5962
text = raw.decode('utf-8', errors='surrogateescape')
6063

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import tempfile
2+
import unittest
3+
from pathlib import Path
4+
5+
from modules.util.dpo_curation_util import export_curated_pairs
6+
7+
from PIL import Image
8+
9+
10+
class DPOCurationUtilExportTest(unittest.TestCase):
    """End-to-end checks for the DPO curation export helpers."""

    def test_export_pairs_copies_images_and_strips_angle_bracket_metadata_from_captions(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            root = Path(temp_dir)
            source_dir = root / "source"
            output_dir = root / "output"
            source_dir.mkdir()

            chosen_src = source_dir / "chosen.png"
            rejected_src = source_dir / "rejected.jpg"
            Image.new("RGB", (8, 8), color="red").save(chosen_src)
            Image.new("RGB", (8, 8), color="blue").save(rejected_src)

            groups = [
                {
                    "prompt": "portrait, <wildcard:hair>, cinematic light, <segment:face>, sharp focus",
                    "aspectratio": "1:1",
                    "images": [str(chosen_src), str(rejected_src)],
                }
            ]
            results = {0: [{"chosen": str(chosen_src), "rejected": str(rejected_src)}]}

            (chosen_dir, rejected_dir, _chosen_val_dir, _rejected_val_dir,
             skipped, _val_count, _train_count) = export_curated_pairs(
                groups, results, str(output_dir), val_percentage=0.0)

            self.assertEqual(skipped, 0)

            out_chosen_img = Path(chosen_dir) / "pair_0000_0000.png"
            out_rejected_img = Path(rejected_dir) / "pair_0000_0000.jpg"
            out_chosen_txt = Path(chosen_dir) / "pair_0000_0000.txt"
            out_rejected_txt = Path(rejected_dir) / "pair_0000_0000.txt"

            # Images must be copied verbatim into the chosen/rejected dirs.
            self.assertTrue(out_chosen_img.exists())
            self.assertTrue(out_rejected_img.exists())
            self.assertEqual(out_chosen_img.read_bytes(), chosen_src.read_bytes())
            self.assertEqual(out_rejected_img.read_bytes(), rejected_src.read_bytes())

            # Angle-bracket segments are stripped from both exported captions.
            expected_caption = "portrait, cinematic light, sharp focus"
            self.assertEqual(out_chosen_txt.read_text(encoding="utf-8"), expected_caption)
            self.assertEqual(out_rejected_txt.read_text(encoding="utf-8"), expected_caption)

    def test_has_existing_exports_detects_prior_output(self):
        from modules.util.dpo_curation_util import has_existing_exports

        with tempfile.TemporaryDirectory() as temp_dir:
            # A fresh directory has no prior exports.
            self.assertFalse(has_existing_exports(temp_dir))

            # A chosen/ directory containing a pair_* file counts as prior output.
            chosen_dir = Path(temp_dir) / "chosen"
            chosen_dir.mkdir()
            (chosen_dir / "pair_0000.png").write_bytes(b"fake")
            self.assertTrue(has_existing_exports(temp_dir))

    def test_has_existing_exports_ignores_non_pair_files(self):
        from modules.util.dpo_curation_util import has_existing_exports

        with tempfile.TemporaryDirectory() as temp_dir:
            chosen_dir = Path(temp_dir) / "chosen"
            chosen_dir.mkdir()
            (chosen_dir / "something_else.png").write_bytes(b"fake")
            self.assertFalse(has_existing_exports(temp_dir))

tests/util/test_dpo_loss.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import math
2+
3+
import torch
4+
import torch.nn.functional as F
5+
6+
7+
def _mse_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Per-sample MSE over all non-batch dims (mirrors calculate_dpo_loss)."""
    non_batch_dims = list(range(1, pred.ndim))
    squared_error = (pred - target) ** 2
    return squared_error.mean(dim=non_batch_dims)
10+
11+
12+
def _dpo_loss(
    policy_chosen_logp: torch.Tensor,
    policy_rejected_logp: torch.Tensor,
    ref_chosen_logp: torch.Tensor,
    ref_rejected_logp: torch.Tensor,
    beta: float,
    label_smoothing: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Reproduce the DPO loss computation from BaseModelSetup.calculate_dpo_loss."""
    # Implicit per-sample rewards: the policy's log-prob gain over the reference.
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp

    preference_margin = chosen_ratio - rejected_ratio
    logits = beta * preference_margin
    dpo_loss = -F.logsigmoid(logits).mean()

    loss = dpo_loss
    if label_smoothing > 0:
        # Conservative DPO: blend in the loss of the flipped preference.
        flipped_loss = -F.logsigmoid(-logits).mean()
        loss = (1 - label_smoothing) * dpo_loss + label_smoothing * flipped_loss

    return (
        loss,
        dpo_loss,
        chosen_ratio.detach().mean(),
        rejected_ratio.detach().mean(),
        (chosen_ratio > rejected_ratio).float().mean(),
    )
37+
38+
39+
class TestDPOLossMath:
    """Integration checks for the DPO loss math helpers above."""

    def test_beta_zero_gives_log2(self):
        """At beta=0, logits collapse to 0 and -log(sigmoid(0)) = log(2)."""
        batch = 4
        # Random log-probs: with beta=0 the inputs cannot influence the loss.
        log_probs = [torch.randn(batch) for _ in range(4)]
        loss, _dpo_loss, _cr, _rr, _acc = _dpo_loss(*log_probs, beta=0.0)
        assert abs(loss.item() - math.log(2)) < 1e-6

    def test_perfect_preference_gives_low_loss(self):
        """When policy strongly prefers chosen over rejected, loss should be low."""
        batch = 4
        loss, _, chosen_reward, rejected_reward, accuracy = _dpo_loss(
            torch.full((batch,), 0.0),     # policy chosen
            torch.full((batch,), -10.0),   # policy rejected
            torch.full((batch,), -5.0),    # reference chosen
            torch.full((batch,), -5.0),    # reference rejected
            beta=5000.0,
        )
        assert accuracy.item() == 1.0
        assert chosen_reward.item() > rejected_reward.item()
        assert loss.item() < 0.01

    def test_inverted_preference_gives_high_loss(self):
        """When policy prefers rejected over chosen, loss should be high."""
        batch = 4
        loss, _, _, _, accuracy = _dpo_loss(
            torch.full((batch,), -10.0),   # policy chosen
            torch.full((batch,), 0.0),     # policy rejected
            torch.full((batch,), -5.0),    # reference chosen
            torch.full((batch,), -5.0),    # reference rejected
            beta=5000.0,
        )
        assert accuracy.item() == 0.0
        assert loss.item() > 10.0

    def test_label_smoothing_reduces_extreme_loss(self):
        """Label smoothing should make loss less extreme for both directions."""
        batch = 4
        log_probs = (
            torch.full((batch,), 0.0),
            torch.full((batch,), -10.0),
            torch.full((batch,), -5.0),
            torch.full((batch,), -5.0),
        )
        loss_plain = _dpo_loss(*log_probs, beta=5000.0, label_smoothing=0.0)[0]
        loss_smoothed = _dpo_loss(*log_probs, beta=5000.0, label_smoothing=0.1)[0]
        assert loss_smoothed.item() > loss_plain.item()

    def test_mse_per_sample_reduces_correctly(self):
        """MSE reduction should produce [B] shape from [B, C, H, W]."""
        shape = (2, 4, 8, 8)
        pred = torch.randn(*shape)
        target = torch.randn(*shape)
        per_sample = _mse_per_sample(pred, target)
        assert per_sample.shape == (shape[0],)
        manual_first = (pred[0] - target[0]).pow(2).mean()
        assert abs(per_sample[0].item() - manual_first.item()) < 1e-6

0 commit comments

Comments (0)