-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathatom.xml
More file actions
633 lines (350 loc) · 390 KB
/
atom.xml
File metadata and controls
633 lines (350 loc) · 390 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
<title>拾荒志</title>
<subtitle>虚怀若谷,大智若愚</subtitle>
<link href="/atom.xml" rel="self"/>
<link href="https://murphypei.github.io/"/>
<updated>2026-03-04T06:11:49.807Z</updated>
<id>https://murphypei.github.io/</id>
<author>
<name>AngryBirds</name>
</author>
<generator uri="https://hexo.io/">Hexo</generator>
<entry>
<title>《FQ-Eval Building Evaluation Dataset for User-centered Follow-up Question Generation》</title>
<link href="https://murphypei.github.io/blog/2026/03/paper-qa-eval.html"/>
<id>https://murphypei.github.io/blog/2026/03/paper-qa-eval.html</id>
<published>2026-03-04T08:20:00.000Z</published>
<updated>2026-03-04T06:11:49.807Z</updated>
<content type="html"><![CDATA[<p>FQ-Eval 这是一篇关于对话中的 Follow ups questions 的论文。</p><span id="more"></span><h2 id="一、论文基本信息"><a href="#一、论文基本信息" class="headerlink" title="一、论文基本信息"></a>一、论文基本信息</h2><h2 id="报告基本信息"><a href="#报告基本信息" class="headerlink" title="报告基本信息"></a>报告基本信息</h2><p><strong>论文标题</strong></p><p>FQ-Eval: Building Evaluation Dataset for User-centered Follow-up Question Generation</p><p><strong>发表会议</strong></p><p>2025 Conference on Empirical Methods in Natural Language Processing (EMNLP 2025) Industry Track</p><p><strong>论文地址</strong></p><p><a href="https://aclanthology.org/2025.emnlp-industry.188.pdf">https://aclanthology.org/2025.emnlp-industry.188.pdf</a></p><p><strong>开源地址</strong></p><p><a href="https://github.com/LGAI-Research/FQ-Eval">https://github.com/LGAI-Research/FQ-Eval</a></p><h2 id="二、研究背景与问题提出"><a href="#二、研究背景与问题提出" class="headerlink" title="二、研究背景与问题提出"></a>二、研究背景与问题提出</h2><h3 id="2-1-研究背景"><a href="#2-1-研究背景" class="headerlink" title="2.1 研究背景"></a>2.1 研究背景</h3><p>随着 ChatGPT、Copilot、Perplexity 等对话式 LLM 服务的规模化落地,<strong>后续问题生成(Follow-up Question Generation)</strong> 已成为提升用户体验的核心功能:</p><ol><li>大量用户难以清晰、完整地向 LLM 表达自身潜在意图与对话目标,模糊查询易导致对话偏离用户初衷,显著降低使用满意度;</li><li>高质量的后续问题可降低用户持续交互的门槛,引导用户深化对话、达成核心目标,直接影响商用 LLM 产品的用户留存与活跃度。</li></ol><h3 id="2-2-现有研究的核心缺陷"><a href="#2-2-现有研究的核心缺陷" class="headerlink" title="2.2 现有研究的核心缺陷"></a>2.2 现有研究的核心缺陷</h3><p>现有后续问题生成相关研究存在三大关键短板,无法适配真实商用 LLM 场景的需求:</p><ol><li><strong>评估维度严重片面</strong>:绝大多数研究仅聚焦于「信息检索能力」「主题相关性」两个维度,完全忽视了后续问题对用户内在需求的满足、对用户对话目标的支撑,生成的问题与用户真实预期严重脱节;</li><li><strong>缺乏标准化评估框架</strong>:领域内尚无统一、可复现的用户中心式评估方法,不同研究的评估指标、数据集差异极大,无法横向对比模型性能;</li><li><strong>数据集与真实场景脱节</strong>:现有主流数据集(如 FQ-Bank、FollowupQG)多基于合成文本、社交平台非正式对话构建,无法反映商用 LLM 服务中用户的真实使用场景与交互模式。</li></ol><h3 id="2-3-核心研究问题"><a href="#2-3-核心研究问题" class="headerlink" title="2.3 核心研究问题"></a>2.3 核心研究问题</h3><p>如何构建一套<strong>贴合真实商用 LLM 使用场景、以用户为中心</strong>的后续问题生成评估体系,包括标准化的评估准则、高质量的评估数据集,以及可复现的评估方法,全面衡量后续问题对用户的真实价值。</p><h2 id="三、核心研究内容:任务定义与评估准则"><a href="#三、核心研究内容:任务定义与评估准则" class="headerlink" title="三、核心研究内容:任务定义与评估准则"></a>三、核心研究内容:任务定义与评估准则</h2><h3 id="3-1-任务形式化定义"><a href="#3-1-任务形式化定义" class="headerlink" title="3.1 任务形式化定义"></a>3.1 任务形式化定义</h3><p>论文首次正式定义了商用 LLM 场景下的后续问题生成任务与评估任务,明确了任务边界:</p><ol><li><strong>生成任务</strong>:给定单轮用户问题 $q \in Q$ 与对应 LLM 回答 $a \in A$ ,通过生成函数 $F$ 输出有价值的后续问题 $f$,即:$f=F(q,a), F:Q \times A \to F$</li><li><strong>评估任务</strong>:分为两种可互补的范式</li></ol><ul><li>相对评估(n-way 选择):给定 QA 对、评估准则与多个候选后续问题,评估函数 $E$ 选出最符合准则的最优问题;</li><li>绝对评估(分数制):给定 QA 对、评估准则与单个候选后续问题,评估函数 $E$ 输出 1-5 分的质量评分,分数越高代表与准则的对齐度越好。</li></ul><h3 id="3-2-五大用户对齐评估准则(核心创新)"><a href="#3-2-五大用户对齐评估准则(核心创新)" class="headerlink" title="3.2 五大用户对齐评估准则(核心创新)"></a>3.2 五大用户对齐评估准则(核心创新)</h3><p>论文通过<strong>三阶段用户研究</strong>(半结构化访谈→焦点小组讨论→亲和图分析),招募 9 名有丰富 LLM 使用经验、覆盖教育 / 设计 / 编程 / 产品等多元背景的参与者,最终提炼出 5 个完全贴合用户真实预期的评估准则,彻底突破了现有研究仅关注相关性的局限。</p><div class="table-container"><table><thead><tr><th>准则编号</th><th>准则名称</th><th>核心定义</th><th>关键评估维度</th><th>论文示例(恐龙睡前故事场景)</th></tr></thead><tbody><tr><td>C1</td><td>探索范围(Exploratory Scope)</td><td>评估后续问题能否拓宽、深化用户与 LLM 的对话,引导用户发现被忽略的视角与子话题</td><td>多视角探究、相关概念引入、深度细分挖掘、话题范围精细化</td><td>How can I make the bedtime story more interactive so my child can participate as we read?</td></tr><tr><td>C2</td><td>语境相关性(Contextual Relevance)</td><td>评估后续问题能否与前文对话、用户原始意图保持一致,避免话题偏离</td><td>话题连续性、用户意图对齐、核心焦点保持、上下文承接</td><td>How can I use ideas from the story, like counting stars or listening to crickets, to help my child fall asleep?</td></tr><tr><td>C3</td><td>创意跳跃(Creative Leap)</td><td>评估后续问题能否突破常规框架,激发用户的原创思考、意外洞见与想象力探索</td><td>新颖视角、跨领域关联、想象力引导、交互趣味性提升</td><td>Can you tell a bedtime story where a dinosaur and a dragon become best friends and share an adventure?</td></tr><tr><td>C4</td><td>LLM 赋能(LLM Enablement)</td><td>评估后续问题能否引导用户充分利用 LLM 的多元能力,掌握更高效的交互方法</td><td>LLM 功能展示、具体使用场景示例、提示词优化引导、高阶用法挖掘</td><td>How can I adapt AI-generated stories to be appropriate for both younger and slightly older kids?</td></tr><tr><td>C5</td><td>引导入门(Guided Onboarding)</td><td>评估后续问题能否帮助用户快速开启新主题、陌生领域的探索,降低入门门槛</td><td>核心概念高亮、入门关键词推荐、探索路径指引、基础背景补充</td><td>What are some age-appropriate adventure themes for bedtime stories for 5-year-olds?</td></tr></tbody></table></div><h2 id="四、FQ-Eval-数据集构建"><a href="#四、FQ-Eval-数据集构建" class="headerlink" title="四、FQ-Eval 数据集构建"></a>四、FQ-Eval 数据集构建</h2><p>论文构建的 FQ-Eval 是领域内首个基于真实生成式 AI 使用场景、以用户为中心的后续问题生成评估数据集,构建流程全程遵循「真实场景锚定 - 人工精制校准 - 多轮质量管控」原则,确保数据集的可靠性与代表性。</p><h3 id="4-1-数据集整体规模"><a href="#4-1-数据集整体规模" class="headerlink" title="4.1 数据集整体规模"></a>4.1 数据集整体规模</h3><ul><li>核心单元:200 个单轮 QA 对,覆盖 6 大真实生成式 AI 使用类别;</li><li>标注内容:1000 个人工精制的后续问题,每个 QA 对对应 5 个问题,分别与 5 大评估准则一一对应;</li><li>对话类型:单轮对话式 QA,适配商用 LLM 的主流交互场景。</li></ul><h3 id="4-2-分阶段构建流程"><a href="#4-2-分阶段构建流程" class="headerlink" title="4.2 分阶段构建流程"></a>4.2 分阶段构建流程</h3><h4 id="阶段-1:种子-QA-对生成(真实场景锚定)"><a href="#阶段-1:种子-QA-对生成(真实场景锚定)" class="headerlink" title="阶段 1:种子 QA 对生成(真实场景锚定)"></a>阶段 1:种子 QA 对生成(真实场景锚定)</h4><ol><li><strong>场景来源</strong>:基于哈佛商业评论(HBR)2025 年发布的《生成式 AI 真实使用报告》,覆盖 100 个真实用户生成式 AI 使用场景,分为内容创作、创意娱乐、学习教育、个人 / 职业支持、研究分析、技术支持 6 大类别;</li><li><strong>候选生成</strong>:为每个场景设计简单、复杂 2 个难度等级,通过 GPT-4.1 为每个难度生成 5 个候选用户问题,总计 1000 个候选种子问题;</li><li><strong>人工筛选修订</strong>:3 名训练有素的标注员独立评估,从每个难度的 5 个候选中选出 1 个最贴合真实用户语言习惯、场景对齐度最高的问题,最终形成 200 个高质量种子问题;</li><li><strong>回答生成与质检</strong>:通过 GPT-4.1 为每个种子问题生成对应回答,标注员审核回答的上下文连贯性与完整性,形成最终的 200 个 QA 对。</li></ol><h4 id="阶段-2:后续问题筛选与精制(准则对齐)"><a href="#阶段-2:后续问题筛选与精制(准则对齐)" class="headerlink" title="阶段 2:后续问题筛选与精制(准则对齐)"></a>阶段 2:后续问题筛选与精制(准则对齐)</h4><ol><li><strong>候选生成</strong>:针对每个 QA 对 + 每个评估准则,通过 GPT-4.1 生成 8 个严格贴合准则定义的候选后续问题,每个 QA 对总计生成 40 个候选;</li><li><strong>人工筛选与修订</strong>:7 名专业标注员(6 名英语母语者 + 1 名双语专家)独立完成两项工作:<ul><li>筛选:从每个准则的 8 个候选中,选出最能体现该准则核心特征的 1 个问题;</li><li>修订:优化选中问题的准则对齐度、上下文连贯性与语言质量,保留核心语义不变;</li></ul></li><li><strong>质量管控</strong>:通过 GPT-4.1 对所有最终问题进行自动打分,得分低于 2 分的问题触发人工复检与重标注,确保所有问题与准则的高度对齐。</li></ol><h3 id="4-3-与现有主流数据集的对比"><a href="#4-3-与现有主流数据集的对比" class="headerlink" title="4.3 与现有主流数据集的对比"></a>4.3 与现有主流数据集的对比</h3><div class="table-container"><table><thead><tr><th>核心属性</th><th>FQ-Bank</th><th>FollowupQG</th><th>FQ-Eval(本文)</th></tr></thead><tbody><tr><td>对话类型</td><td>多轮</td><td>单轮</td><td>单轮</td></tr><tr><td>QA 对数量</td><td>2132</td><td>501</td><td>200</td></tr><tr><td>后续问题数量</td><td>2132</td><td>501</td><td>1000</td></tr><tr><td>领域</td><td>基于篇章的 QA</td><td>社交 QA</td><td>对话式 QA</td></tr><tr><td>数据来源</td><td>OrQuAC 合成语料</td><td>Reddit 社交平台</td><td>HBR 真实生成式 AI 使用场景</td></tr><tr><td>核心特征</td><td>合成 / 人工设计</td><td>非正式口语化</td><td>真实场景、用户中心、准则对齐</td></tr></tbody></table></div><hr><h2 id="五、实验设计与核心结果"><a href="#五、实验设计与核心结果" class="headerlink" title="五、实验设计与核心结果"></a>五、实验设计与核心结果</h2><p>论文通过多组互补实验,全面验证了 FQ-Eval 数据集的有效性、评估框架的可靠性,同时揭示了当前主流 LLM 在后续问题生成上的核心短板。</p><h3 id="5-1-实验基础设置"><a href="#5-1-实验基础设置" class="headerlink" title="5.1 实验基础设置"></a>5.1 实验基础设置</h3><ol><li><strong>评估范式</strong>:采用「n-way 选择任务 + 分数制评估任务」双范式,确保评估结果的全面性;</li><li><strong>评估器设置</strong>:<ul><li>主力评估器:GPT-4.1(temperature=0.0,贪心解码,确保结果可复现);</li><li>金标准验证:独立人类标注员盲测,与数据集构建阶段的标注员完全隔离,避免偏见;</li><li>鲁棒性验证:额外采用 Claude Opus 4、Mistral Large、Gemini 2.5 Flash 作为备选评估器;</li></ul></li><li><strong>对比模型</strong>:覆盖 OpenAI、Anthropic、Mistral、Google 四大厂商的 12 款主流 LLM,包含旗舰大模型与中小尺寸模型;</li><li><strong>工业场景验证</strong>:对商用产品 Perplexity 的后续问题生成能力进行实测评估。</li></ol><h3 id="5-2-核心实验结果"><a href="#5-2-核心实验结果" class="headerlink" title="5.2 核心实验结果"></a>5.2 核心实验结果</h3><h4 id="实验-1:5-way-选择任务(核心验证实验)"><a href="#实验-1:5-way-选择任务(核心验证实验)" class="headerlink" title="实验 1:5-way 选择任务(核心验证实验)"></a>实验 1:5-way 选择任务(核心验证实验)</h4><p><strong>实验设置</strong>:对每个 QA 对 + 每个准则,构建 5 候选池(1 个 FQ-Eval 基准问题 + 4 款旗舰 LLM 生成的问题),评估者在盲测环境下选出最符合准则的问题,统计各候选的选择率。</p><p><strong>核心结果</strong>:</p><ol><li>FQ-Eval 在 5 个准则上的平均选择率显著高于所有主流 LLM,在<strong>创意跳跃</strong>维度,LLM-judge 选择率 99.5%、人类选择率 92%,几乎碾压所有对比模型;</li><li>在<strong>引导入门、LLM 赋能、探索范围</strong>三个维度,FQ-Eval 的 LLM-judge 选择率均超过 79%,大幅领先对比模型;</li><li>仅在<strong>语境相关性</strong>维度,Claude Opus 4 等旗舰模型展现出竞争力(人类选择率 28.5%,略高于 FQ-Eval 的 24%),验证了现有 LLM 仅在相关性维度有成熟能力,与论文的核心假设一致。</li></ol><h4 id="实验-2:2-way-两两对比任务(全模型覆盖)"><a href="#实验-2:2-way-两两对比任务(全模型覆盖)" class="headerlink" title="实验 2:2-way 两两对比任务(全模型覆盖)"></a>实验 2:2-way 两两对比任务(全模型覆盖)</h4><p><strong>实验设置</strong>:对每个 QA 对 + 每个准则,构建 2 候选池(1 个 FQ-Eval 基准问题 + 1 款待评估 LLM 生成的问题),评估者二选一,统计 FQ-Eval 的胜率。</p><p><strong>核心结果</strong>:</p><ol><li>FQ-Eval 在 12 款 LLM 的全量对比中,<strong>平均胜率达 86.4%</strong>,在所有准则上均占据绝对优势;</li><li>在<strong>创意跳跃</strong>维度,FQ-Eval 对所有模型的胜率均超过 99%,证明当前主流 LLM 完全无法生成符合用户预期的高创意性后续问题;</li><li>同厂商内的中小尺寸模型,在部分准则上的表现优于旗舰大模型,核心原因是中小模型的输出更简洁,更适配简单场景的用户需求;</li><li>随着用户问题复杂度提升,FQ-Eval 的胜率略有下降,说明旗舰大模型在高复杂度场景下能更好地发挥能力。</li></ol><h4 id="实验-3:分数制绝对评估任务"><a href="#实验-3:分数制绝对评估任务" class="headerlink" title="实验 3:分数制绝对评估任务"></a>实验 3:分数制绝对评估任务</h4><p><strong>实验设置</strong>:对每个候选后续问题,基于 5 大准则进行 1-5 分的独立绝对评分,分数越高代表质量越好。</p><p><strong>核心结果(平均分)</strong>:</p><div class="table-container"><table><thead><tr><th>模型</th><th>C1 探索范围</th><th>C2 语境相关性</th><th>C3 创意跳跃</th><th>C4 LLM 赋能</th><th>C5 引导入门</th></tr></thead><tbody><tr><td>FQ-Eval</td><td>4.04</td><td>4.96</td><td>4.35</td><td>4.21</td><td>4.24</td></tr><tr><td>GPT-4.1</td><td>3.23</td><td>4.88</td><td>1.97</td><td>3.28</td><td>3.13</td></tr><tr><td>Claude Opus 4</td><td>3.53</td><td>4.97</td><td>2.00</td><td>3.40</td><td>3.26</td></tr><tr><td>Mistral Large</td><td>3.45</td><td>4.75</td><td>2.01</td><td>3.05</td><td>3.16</td></tr><tr><td>Gemini 2.5 Flash</td><td>3.18</td><td>4.86</td><td>1.92</td><td>3.13</td><td>3.08</td></tr></tbody></table></div><p>关键结论:</p><ul><li>FQ-Eval 在所有 5 个准则上的平均分均高于主流旗舰 LLM,总分领先幅度显著;</li><li>语境相关性维度,FQ-Eval 与 Claude Opus 4 几乎持平,再次验证现有模型在该维度的成熟度;</li><li>创意跳跃维度,FQ-Eval 平均分是对比模型的 2 倍以上,差距极为显著。</li></ul><h4 id="实验-4:真实商用服务评估"><a href="#实验-4:真实商用服务评估" class="headerlink" title="实验 4:真实商用服务评估"></a>实验 4:真实商用服务评估</h4><p><strong>核心发现</strong>:Perplexity 的后续问题仅在信息检索相关的探索范围、语境相关性维度表现尚可,但在<strong>LLM 赋能、引导入门、创意跳跃</strong>三个维度表现极差,与 FQ-Eval 的基准水平差距显著。证明 FQ-Eval 可直接用于商用产品的能力诊断,帮助开发者定位优化方向,平衡产品的质量、成本与延迟。</p><h3 id="5-3-鲁棒性验证"><a href="#5-3-鲁棒性验证" class="headerlink" title="5.3 鲁棒性验证"></a>5.3 鲁棒性验证</h3><ol><li><strong>LLM-judge 一致性</strong>:4 款不同厂商的 LLM 作为评估器时,两两之间的皮尔逊相关系数均 > 0.8(p<0.001),整体评分相邻一致性(分差≤1 分)达 97.2%,证明评估结果稳定可靠,无系统性偏差;</li><li><strong>人类与 LLM-judge 对齐性</strong>:人类标注结果与 GPT-4.1 的评估结果高度一致,无显著差异,证明 LLM-judge 的自动化评估完全贴合人类真实偏好;</li><li><strong>循环性偏见规避</strong>:针对 “GPT-4.1 同时参与数据集构建与评估” 的潜在偏见,论文通过人工修订消除模型风格偏差、人类评估对齐、多模型交叉验证三种方式,证明结果无循环性偏见。</li></ol>]]></content>
<summary type="html">
<p>FQ-Eval 这是一篇关于对话中的 Follow ups questions 的论文。</p>
</summary>
<category term="论文阅读" scheme="https://murphypei.github.io/categories/%E8%AE%BA%E6%96%87%E9%98%85%E8%AF%BB/"/>
<category term="论文" scheme="https://murphypei.github.io/tags/%E8%AE%BA%E6%96%87/"/>
<category term="FQ-Eval" scheme="https://murphypei.github.io/tags/FQ-Eval/"/>
<category term="Evaluation" scheme="https://murphypei.github.io/tags/Evaluation/"/>
<category term="Dataset" scheme="https://murphypei.github.io/tags/Dataset/"/>
<category term="Follow-up" scheme="https://murphypei.github.io/tags/Follow-up/"/>
</entry>
<entry>
<title>大模型训练方法:DAPO</title>
<link href="https://murphypei.github.io/blog/2025/12/llm-dapo.html"/>
<id>https://murphypei.github.io/blog/2025/12/llm-dapo.html</id>
<published>2025-12-01T08:20:00.000Z</published>
<updated>2026-01-27T04:21:57.272Z</updated>
<content type="html"><![CDATA[<p>DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization) 作为大模型强化学习训练的新兴算法,通过四项核心改进有效解决了GRPO在长序列优化中的痛点问题,在数学推理等复杂任务中取得了显著的性能提升。</p><span id="more"></span><h2 id="DAPO-简介"><a href="#DAPO-简介" class="headerlink" title="DAPO 简介"></a>DAPO 简介</h2><p>在大模型强化学习(RL)训练领域,OpenAI 的 O1、DeepSeek 的 R1 等模型凭借出色的复杂任务表现,证明了大规模 RL 训练的巨大价值。但这些顶尖模型的核心训练技术细节长期未公开,而主流的 GRPO 算法在长链式思维(CoT)等复杂场景中,又频繁面临熵崩溃、训练不稳定等问题。</p><p><strong>DAPO(Decoupled Clip and Dynamic Sampling Policy Optimization)</strong> 算法针对这些问题,提出了四项关键改进,在保持训练稳定性的同时显著提升了模型在复杂推理任务上的表现。本文将深入剖析 DAPO 的技术原理,重点解析其核心改进点及实际应用价值。</p><h2 id="DAPO-的技术背景与核心定位"><a href="#DAPO-的技术背景与核心定位" class="headerlink" title="DAPO 的技术背景与核心定位"></a>DAPO 的技术背景与核心定位</h2><h3 id="传统-RL-训练的痛点"><a href="#传统-RL-训练的痛点" class="headerlink" title="传统 RL 训练的痛点"></a>传统 RL 训练的痛点</h3><p>GRPO 作为大模型 RL 训练的主流算法,虽能提升模型性能,但在数学推理等长 CoT 场景中暴露出诸多缺陷:</p><ol><li><strong>熵崩溃</strong>:固定剪裁范围限制了低概率 token 的探索,导致策略快速收敛,生成内容同质化严重</li><li><strong>梯度失效</strong>:训练后期易出现大量全优或全差样本,优势函数趋近于零,梯度信号消失,浪费训练资源 </li><li><strong>梯度稀释</strong>:长序列样本的梯度经样本级平均后被弱化,关键推理步骤的优化信号无法有效传递</li><li><strong>奖励噪声</strong>:对超长响应的刚性截断惩罚,易误判有效推理内容,干扰模型对奖励信号的学习</li></ol><h3 id="DAPO-的核心目标"><a href="#DAPO-的核心目标" class="headerlink" title="DAPO 的核心目标"></a>DAPO 的核心目标</h3><p>DAPO 以”解决长序列 RL 优化难题、提供开源可复现方案、突破复杂任务性能上限”为核心目标。其算法设计围绕”释放模型探索能力、提升训练信号质量、精准优化长序列”三大方向展开,在 AIME 等数学推理任务中取得了显著的性能提升。</p><h2 id="DAPO-的四大核心改进点深度解析"><a href="#DAPO-的四大核心改进点深度解析" class="headerlink" title="DAPO 的四大核心改进点深度解析"></a>DAPO 的四大核心改进点深度解析</h2><p>DAPO 通过解耦剪辑策略、动态采样机制、token 级损失计算和长度感知奖励修正四项革新,针对性解决了 GRPO 的核心痛点。以下结合技术原理、解决路径和实际效果展开说明:</p><h3 id="改进一:Clip-Higher——解耦高低剪辑范围,平衡探索与利用"><a href="#改进一:Clip-Higher——解耦高低剪辑范围,平衡探索与利用" class="headerlink" title="改进一:Clip-Higher——解耦高低剪辑范围,平衡探索与利用"></a>改进一:Clip-Higher——解耦高低剪辑范围,平衡探索与利用</h3><h4 id="传统算法的局限"><a href="#传统算法的局限" class="headerlink" title="传统算法的局限"></a>传统算法的局限</h4><p>GRPO 和 PPO 均采用固定的对称剪辑范围(如 $\epsilon=0.2$),将新策略与旧策略的概率比值限制在 $[1-\epsilon, 1+\epsilon]$ 区间。传统的剪辑函数为:</p><script type="math/tex; mode=display">clip(r_t, 1-\epsilon, 1+\epsilon) = \max(\min(r_t, 1+\epsilon), 1-\epsilon)</script><p>其中 $r<em>t = \frac{\pi</em>\theta(a<em>t|s_t)}{\pi</em>{\theta_{old}}(a_t|s_t)}$ 是重要性采样比率。</p><p>这种设计虽能避免策略突变,但会严重抑制低概率关键 token 的优化。当模型偶然生成一个对推理至关重要但概率极低的 token 时,其概率比值极易超出剪辑上限而被截断,导致该 token 无法获得有效强化,最终模型探索能力衰退,陷入熵崩溃。</p><h4 id="DAPO-的解决方案"><a href="#DAPO-的解决方案" class="headerlink" title="DAPO 的解决方案"></a>DAPO 的解决方案</h4><p>DAPO 提出<strong>解耦式剪辑策略</strong>,将上下剪辑阈值拆分为独立参数:</p><script type="math/tex; mode=display">clip_{higher}(r_t, \epsilon_{low}, \epsilon_{high}) = \max(\min(r_t, 1+\epsilon_{high}), 1-\epsilon_{low})</script><p>其中:</p><ul><li>低阈值 $\epsilon_{low}=0.2$:保持与传统算法一致,抑制高概率 token 的过度利用,避免模型陷入局部最优</li><li>高阈值 $\epsilon_{high}=0.28$:放宽对高比值的限制,为低概率关键 token 提供足够的上涨空间</li></ul><p>这种非对称设计既保证了训练稳定性,又释放了模型的探索潜力。例如在数学推理中,模型生成的”辅助变量假设”等低概率关键推理步骤,能通过该机制获得有效强化,逐步形成更完整的推理链条。</p><h4 id="实际效果"><a href="#实际效果" class="headerlink" title="实际效果"></a>实际效果</h4><p>该改进使模型生成多样性显著提升,训练过程中熵值保持稳定。通过允许更大的上行剪辑范围,模型能够更好地学习低概率但高价值的 token,有效缓解了熵崩溃问题。</p><h3 id="改进二:Dynamic-Sampling——动态过滤无效样本,强化梯度信号"><a href="#改进二:Dynamic-Sampling——动态过滤无效样本,强化梯度信号" class="headerlink" title="改进二:Dynamic Sampling——动态过滤无效样本,强化梯度信号"></a>改进二:Dynamic Sampling——动态过滤无效样本,强化梯度信号</h3><h4 id="传统采样的核心问题"><a href="#传统采样的核心问题" class="headerlink" title="传统采样的核心问题"></a>传统采样的核心问题</h4><p>训练过程中,当模型性能提升到一定阶段,易出现大量全正确或全错误的样本。这些样本的优势函数值为零,对应的梯度信号也会消失,导致”采样量大但有效信息少”的资源浪费。</p><p>在强化学习中,优势函数定义为:</p><script type="math/tex; mode=display">A(s_t, a_t) = Q(s_t, a_t) - V(s_t)</script><p>当奖励为极端值(0 或 1)时,$A(s_t, a_t) \approx 0$,导致策略梯度:</p><script type="math/tex; mode=display">\nabla_\theta J(\theta) = \mathbb{E}[A(s_t, a_t) \nabla_\theta \log \pi_\theta(a_t|s_t)]</script><p>趋近于零,训练效率严重下降。</p><h4 id="DAPO-的动态采样策略"><a href="#DAPO-的动态采样策略" class="headerlink" title="DAPO 的动态采样策略"></a>DAPO 的动态采样策略</h4><p>DAPO 引入<strong>梯度有效性筛选机制</strong>,具体算法如下:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">dynamic_sampling</span>(<span class="params">model, prompts, target_batch_size</span>):</span><br><span class="line"> effective_samples = []</span><br><span class="line"> sampling_attempts = <span class="number">0</span></span><br><span class="line"> </span><br><span class="line"> <span class="keyword">while</span> <span class="built_in">len</span>(effective_samples) < target_batch_size:</span><br><span class="line"> <span class="comment"># 生成候选样本</span></span><br><span class="line"> candidates = sample_responses(model, prompts)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 筛选有效样本(奖励不为极端值)</span></span><br><span class="line"> <span class="keyword">for</span> sample <span class="keyword">in</span> candidates:</span><br><span class="line"> <span class="keyword">if</span> <span class="number">0</span> < sample.reward < <span class="number">1</span>: <span class="comment"># 排除极端奖励</span></span><br><span class="line"> effective_samples.append(sample)</span><br><span class="line"> </span><br><span class="line"> sampling_attempts += <span class="number">1</span></span><br><span class="line"> <span class="keyword">if</span> sampling_attempts > max_attempts:</span><br><span class="line"> <span class="keyword">break</span></span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> effective_samples[:target_batch_size]</span><br></pre></td></tr></table></figure><p>该机制的核心逻辑是:若当前采样结果中无有效梯度样本,则触发二次采样,直至批次中充满具备梯度价值的样本。由于过滤操作仅需校验奖励值,计算开销极低,不会显著增加训练时间。</p><h4 id="实际效果-1"><a href="#实际效果-1" class="headerlink" title="实际效果"></a>实际效果</h4><p>动态采样使模型收敛速度大幅提升,通过确保每个训练批次都包含有效的梯度信号,避免了训练后期因梯度失效导致的模型性能震荡。实验表明,该机制能够显著提升训练效率。</p><h3 id="改进三:Token-Level-Policy-Gradient-Loss——精准优化长序列,避免梯度稀释"><a href="#改进三:Token-Level-Policy-Gradient-Loss——精准优化长序列,避免梯度稀释" class="headerlink" title="改进三:Token-Level Policy Gradient Loss——精准优化长序列,避免梯度稀释"></a>改进三:Token-Level Policy Gradient Loss——精准优化长序列,避免梯度稀释</h3><h4 id="长序列优化的核心痛点"><a href="#长序列优化的核心痛点" class="headerlink" title="长序列优化的核心痛点"></a>长序列优化的核心痛点</h4><p>GRPO 采用”样本内平均+批次内平均”的双层损失计算方式。传统的损失函数为:</p><script type="math/tex; mode=display">L = \frac{1}{B} \sum_{i=1}^{B} \left[ \frac{1}{T_i} \sum_{t=1}^{T_i} \ell_t^{(i)} \right]</script><p>其中 $B$ 是批次大小,$T_i$ 是第 $i$ 个样本的序列长度,$\ell_t^{(i)}$ 是第 $i$ 个样本第 $t$ 个 token 的损失。</p><p>这种方式对长序列极不友好:一个 200 token 的长回答与一个 10 token 的短回答,在样本级平均后,每个 token 的梯度权重分别为 $1/200$ 和 $1/10$。长回答中的关键推理 token(如公式推导、逻辑转折)的梯度被严重稀释,模型难以捕捉长链式推理的核心规律。</p><h4 id="DAPO-的-token-级损失计算"><a href="#DAPO-的-token-级损失计算" class="headerlink" title="DAPO 的 token 级损失计算"></a>DAPO 的 token 级损失计算</h4><p>DAPO 将损失聚合方式从<strong>样本级</strong>改为<strong>全局 token 级</strong>:</p><script type="math/tex; mode=display">L_{token} = \frac{1}{\sum_{i=1}^{B} T_i} \sum_{i=1}^{B} \sum_{t=1}^{T_i} w_t^{(i)} \cdot \ell_t^{(i)}</script><p>其中 $w_t^{(i)}$ 是第 $i$ 个样本第 $t$ 个 token 的权重。</p><p>具体优化包括:</p><ol><li><strong>取消样本内梯度平均</strong>:直接计算每个 token 的独立损失</li><li><strong>全局归一化</strong>:按批次内所有 token 的总数进行归一化,而非按样本数量归一化 </li><li><strong>重要性权重</strong>:对长序列中的关键 token(如数学公式、逻辑连词)可额外增加权重</li></ol><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">token_level_loss</span>(<span class="params">batch_outputs, advantages</span>):</span><br><span class="line"> total_tokens = <span class="built_in">sum</span>(<span class="built_in">len</span>(output) <span class="keyword">for</span> output <span class="keyword">in</span> batch_outputs)</span><br><span class="line"> token_losses = []</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">for</span> output, advantage <span class="keyword">in</span> <span class="built_in">zip</span>(batch_outputs, advantages):</span><br><span class="line"> <span class="keyword">for</span> token_id, token_advantage <span class="keyword">in</span> <span class="built_in">zip</span>(output, advantage):</span><br><span class="line"> <span class="comment"># 每个token独立计算损失,无样本内平均</span></span><br><span class="line"> token_loss = -log_prob(token_id) * token_advantage</span><br><span class="line"> token_losses.append(token_loss)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 全局token级归一化</span></span><br><span class="line"> <span class="keyword">return</span> <span class="built_in">sum</span>(token_losses) / total_tokens</span><br></pre></td></tr></table></figure><p>这确保长、短回答的每个 token 拥有公平的梯度权重,长序列中的有效信息不再被稀释。</p><h4 id="实际效果-2"><a href="#实际效果-2" class="headerlink" title="实际效果"></a>实际效果</h4><p>该改进让模型的长序列生成能力显著增强。在需要多步骤推理的数学题求解中,模型能稳定生成结构完整的推理过程,重复、乱码等低质量模式大幅减少,训练稳定性也得到明显提升。</p><h3 id="改进四:Overlong-Reward-Shaping——软惩罚机制,降低奖励噪声"><a href="#改进四:Overlong-Reward-Shaping——软惩罚机制,降低奖励噪声" class="headerlink" title="改进四:Overlong Reward Shaping——软惩罚机制,降低奖励噪声"></a>改进四:Overlong Reward Shaping——软惩罚机制,降低奖励噪声</h3><h4 id="传统长度惩罚的缺陷"><a href="#传统长度惩罚的缺陷" class="headerlink" title="传统长度惩罚的缺陷"></a>传统长度惩罚的缺陷</h4><p>传统算法对超长响应采用刚性惩罚(如超过长度阈值直接扣 1 分),这种方式存在两大问题:</p><ol><li><strong>误判有效内容</strong>:将有效长推理与无效冗余内容同等惩罚,例如完整的数学推导因超长被误判为低质量</li><li><strong>引入奖励噪声</strong>:突然的惩罚会引入剧烈奖励噪声,干扰模型对有效推理模式的学习</li></ol><h4 id="DAPO-的长度感知奖励修正"><a href="#DAPO-的长度感知奖励修正" class="headerlink" title="DAPO 的长度感知奖励修正"></a>DAPO 的长度感知奖励修正</h4><p>DAPO 设计了<strong>分段式软惩罚函数</strong>,根据响应长度动态调整惩罚力度:</p><script type="math/tex; mode=display">r_{\text{length}}(y) = \begin{cases}0, & |y| \leq l_{\text{max}} - l_{\text{cache}} \\\frac{(l_{\text{max}} - l_{\text{cache}}) - |y|}{l_{\text{cache}}}, & l_{\text{max}} - l_{\text{cache}} < |y| \leq l_{\text{max}} \\-1, & |y| > l_{\text{max}}\end{cases}</script><p>其中 $l<em>{\text{max}}$ 为最大长度阈值,$l</em>{\text{cache}}$ 为缓冲长度。该函数实现三层惩罚逻辑:</p><ol><li><strong>正常长度区间</strong>:无惩罚,鼓励完整推理</li><li><strong>缓冲区间</strong>:惩罚随长度线性增加,避免轻微超长被过度惩罚 </li><li><strong>超长区间</strong>:刚性惩罚,过滤无意义的冗余内容</li></ol><p>最终的奖励函数为:</p><script type="math/tex; mode=display">r_{\text{final}}(y) = r_{\text{original}}(y) + \alpha \cdot r_{\text{length}}(y)</script><p>其中 $\alpha$ 是长度惩罚的权重系数。</p><p>同时,DAPO 会过滤截断样本的损失计算,进一步减少噪声对训练的干扰:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">compute_length_penalty</span>(<span class="params">response_length, max_length=<span class="number">2048</span>, cache_length=<span class="number">512</span></span>):</span><br><span class="line"> <span class="keyword">if</span> response_length <= max_length - cache_length:</span><br><span class="line"> <span class="keyword">return</span> <span class="number">0.0</span></span><br><span class="line"> <span class="keyword">elif</span> response_length <= max_length:</span><br><span class="line"> <span class="keyword">return</span> (max_length - cache_length - response_length) / cache_length</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="keyword">return</span> -<span class="number">1.0</span></span><br></pre></td></tr></table></figure><h4 id="实际效果-3"><a href="#实际效果-3" class="headerlink" title="实际效果"></a>实际效果</h4><p>软惩罚机制使训练波动大幅降低,模型既能生成足够长度的推理内容,又能有效抑制无意义的冗余输出,实现了推理完整性与简洁性的平衡。</p><h2 id="DAPO-的整体训练流程与性能表现"><a href="#DAPO-的整体训练流程与性能表现" class="headerlink" title="DAPO 的整体训练流程与性能表现"></a>DAPO 的整体训练流程与性能表现</h2><h3 id="完整训练流程"><a href="#完整训练流程" class="headerlink" title="完整训练流程"></a>完整训练流程</h3><p>DAPO 的训练流程整合了四项核心改进,具体步骤如下:</p><ol><li><strong>数据预处理</strong>:输入 prompt 与高奖励回复数据,标注 token 级别重要性与响应长度信息</li><li><strong>动态采样</strong>:过滤无效样本,构建梯度有效的训练批次</li><li><strong>策略优化</strong>:基于 Clip-Higher 计算策略比值,通过 token 级损失函数反向传播梯度</li><li><strong>奖励修正</strong>:采用软惩罚函数调整超长响应的奖励值,更新模型参数</li><li><strong>迭代优化</strong>:循环上述步骤,直至模型熵值稳定且任务准确率达到目标</li></ol><h3 id="算法伪代码"><a href="#算法伪代码" class="headerlink" title="算法伪代码"></a>算法伪代码</h3><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">dapo_training</span>(<span class="params">model, dataset, epochs</span>):</span><br><span class="line"> <span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(epochs):</span><br><span class="line"> <span class="comment"># 1. 动态采样</span></span><br><span class="line"> effective_batch = dynamic_sampling(model, dataset.prompts)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 2. 计算重要性采样比率 (Clip-Higher)</span></span><br><span class="line"> ratios = []</span><br><span class="line"> <span class="keyword">for</span> sample <span class="keyword">in</span> effective_batch:</span><br><span class="line"> old_prob = old_model.log_prob(sample)</span><br><span class="line"> new_prob = model.log_prob(sample)</span><br><span class="line"> ratio = exp(new_prob - old_prob)</span><br><span class="line"> clipped_ratio = clip_higher(ratio, eps_low=<span class="number">0.2</span>, eps_high=<span class="number">0.28</span>)</span><br><span class="line"> ratios.append(clipped_ratio)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 3. Token级损失计算</span></span><br><span class="line"> loss = compute_token_level_loss(effective_batch, ratios)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 4. 长度感知奖励修正</span></span><br><span class="line"> adjusted_rewards = []</span><br><span class="line"> <span class="keyword">for</span> sample <span class="keyword">in</span> effective_batch:</span><br><span class="line"> length_penalty = compute_length_penalty(<span class="built_in">len</span>(sample))</span><br><span class="line"> adjusted_reward = sample.reward + length_penalty</span><br><span class="line"> adjusted_rewards.append(adjusted_reward)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 5. 反向传播更新</span></span><br><span class="line"> optimizer.zero_grad()</span><br><span class="line"> loss.backward()</span><br><span class="line"> optimizer.step()</span><br></pre></td></tr></table></figure><h3 id="性能表现分析"><a href="#性能表现分析" class="headerlink" title="性能表现分析"></a>性能表现分析</h3><p>实验结果显示 DAPO 在数学推理任务上的显著提升:</p><div class="table-container"><table><thead><tr><th>技术组件</th><th>基线性能</th><th>性能提升</th><th>累计提升</th></tr></thead><tbody><tr><td>基础 GRPO</td><td>30%</td><td>-</td><td>-</td></tr><tr><td>+ 动态采样</td><td>34%</td><td>+4%</td><td>+4%</td></tr><tr><td>+ Clip-Higher</td><td>36%</td><td>+2%</td><td>+6%</td></tr><tr><td>+ Token级损失</td><td>39%</td><td>+3%</td><td>+9%</td></tr><tr><td>+ 软惩罚机制</td><td>42%</td><td>+3%</td><td>+12%</td></tr><tr><td>完整 DAPO 系统</td><td>45%</td><td>+3%</td><td>+15%</td></tr></tbody></table></div><p>从数据可见,DAPO 的四项改进形成了协同效应,每项改进都对最终性能有积极贡献。特别值得注意的是,训练过程中模型还自发演化出”反思修正”等高级推理能力,证明其优化方向贴合复杂任务的核心需求。</p><h2 id="总结与展望"><a href="#总结与展望" class="headerlink" title="总结与展望"></a>总结与展望</h2><h3 id="核心贡献"><a href="#核心贡献" class="headerlink" title="核心贡献"></a>核心贡献</h3><p>DAPO 的核心价值体现在以下几个方面:</p><ol><li><strong>系统性解决方案</strong>:针对 GRPO 在长序列训练中的四大痛点,提出了对应的解决方案</li><li><strong>技术创新</strong>:解耦剪辑策略和 token 级损失计算等创新设计,为强化学习优化提供了新思路</li><li><strong>实用性</strong>:显著的性能提升证明了算法的实际价值,特别是在数学推理等复杂任务上</li></ol><h3 id="技术优势总结"><a href="#技术优势总结" class="headerlink" title="技术优势总结"></a>技术优势总结</h3><div class="table-container"><table><thead><tr><th>改进点</th><th>解决问题</th><th>技术手段</th><th>核心优势</th></tr></thead><tbody><tr><td>Clip-Higher</td><td>熵崩溃</td><td>非对称剪辑范围</td><td>平衡探索与稳定性</td></tr><tr><td>Dynamic Sampling</td><td>梯度失效</td><td>有效样本筛选</td><td>提升训练效率</td></tr><tr><td>Token-Level Loss</td><td>梯度稀释</td><td>全局token归一化</td><td>公平优化长序列</td></tr><tr><td>Soft Length Penalty</td><td>奖励噪声</td><td>分段式软惩罚</td><td>平滑奖励信号</td></tr></tbody></table></div><h3 id="未来发展方向"><a href="#未来发展方向" class="headerlink" title="未来发展方向"></a>未来发展方向</h3><p>DAPO 的优化方向可能集中在以下几个方面:</p><ol><li><strong>自适应参数调整</strong>:根据训练阶段和任务特性,动态调整 Clip-Higher 的阈值参数</li><li><strong>智能权重分配</strong>:结合注意力机制,为不同类型的 token 分配更精准的权重</li><li><strong>多任务适配</strong>:扩展到更多领域的复杂推理任务,验证算法的通用性</li><li><strong>计算优化</strong>:进一步优化动态采样的计算效率,降低训练开销</li></ol><h3 id="实际应用价值"><a href="#实际应用价值" class="headerlink" title="实际应用价值"></a>实际应用价值</h3><p>对于研究人员而言,DAPO 提供了一套完整的长序列强化学习优化框架,为探索更高效的大模型训练路径提供了重要参考。对于工业界来说,其高效的训练策略为降低大规模模型的训练成本、提升复杂任务性能提供了实用的解决方案。</p><p>通过解决长序列强化学习的核心技术难题,DAPO 为大模型在复杂推理任务上的应用奠定了更加坚实的基础。</p>]]></content>
<summary type="html">
<p>DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization) 作为大模型强化学习训练的新兴算法,通过四项核心改进有效解决了GRPO在长序列优化中的痛点问题,在数学推理等复杂任务中取得了显著的性能提升。</p>
</summary>
<category term="LLM" scheme="https://murphypei.github.io/categories/LLM/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="强化学习" scheme="https://murphypei.github.io/tags/%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0/"/>
<category term="GRPO" scheme="https://murphypei.github.io/tags/GRPO/"/>
<category term="DAPO" scheme="https://murphypei.github.io/tags/DAPO/"/>
<category term="策略优化" scheme="https://murphypei.github.io/tags/%E7%AD%96%E7%95%A5%E4%BC%98%E5%8C%96/"/>
</entry>
<entry>
<title>LLM 训练:GSPO 算法详解与 GRPO 对比</title>
<link href="https://murphypei.github.io/blog/2025/11/llm-gspo.html"/>
<id>https://murphypei.github.io/blog/2025/11/llm-gspo.html</id>
<published>2025-11-18T08:20:00.000Z</published>
<updated>2026-01-13T06:51:17.655Z</updated>
<content type="html"><![CDATA[<p>GSPO(Group Sequence Policy Optimization,群组序列策略优化)是对 GRPO 的重要改进。通过将优化粒度从 <strong>token 级</strong> 提升到 <strong>序列级</strong>,GSPO 从根本上解决了 GRPO 在处理长文本、MoE 模型时的训练不稳定问题,同时保持了轻量化的优势(无需 Critic 模型)。</p><span id="more"></span><p>在 <a href="https://murphypei.github.io/blog/2025/07/llm-grpo.html">GRPO 博文</a>中,我们详细介绍了 GRPO 的 token 级优化方法。本文深入探讨 GSPO 的算法原理,以及与 GRPO 的核心区别。</p><p><strong>GSPO 的核心创新</strong>:</p><ol><li><strong>序列级重要性采样</strong>:不计算每个 token 的重要性比率,而是计算整个序列的重要性比率</li><li><strong>几何平均归一化</strong>:通过开方操作消除序列长度对概率乘积的影响</li><li><strong>好坏序列对比学习</strong>:直接通过好序列和坏序列的对比来计算优势</li><li><strong>更稳定的训练</strong>:避免 token 级别方差大的问题</li></ol><h2 id="GRPO-vs-GSPO:概念回顾"><a href="#GRPO-vs-GSPO:概念回顾" class="headerlink" title="GRPO vs GSPO:概念回顾"></a>GRPO vs GSPO:概念回顾</h2><h3 id="GRPO-的-Token-级优化"><a href="#GRPO-的-Token-级优化" class="headerlink" title="GRPO 的 Token 级优化"></a>GRPO 的 Token 级优化</h3><p>GRPO(Group Relative Policy Optimization)采用 <strong>token 级别</strong> 的优化策略:</p><p><strong>核心特点</strong>:</p><ol><li><strong>Token 级重要性采样</strong>:对每个 token 计算 $r<em>t^{\text{(i)}} = \frac{\pi</em>\theta(a<em>t^{(i)}|s_t^{(i)})}{\pi</em>{\text{ref}}(a_t^{(i)}|s_t^{(i)})}$</li><li><strong>奖励分解分配</strong>:将序列奖励 $r_{\text{RM}}$ 分配到每个 token(通常采用折扣分配)</li><li><strong>Token 级优势</strong>:每个 token 有独立的优势值 $\hat{A}_t^{(i)}$</li><li><strong>Token 级裁剪</strong>:对每个 token 的重要性比率进行 PPO 风格裁剪</li><li><strong>损失聚合</strong>:先对单序列内的 token 平均,再对 batch 平均</li></ol><p><strong>优点</strong>:</p><ul><li>贴合模型逐 token 生成的过程</li><li>提供精细的梯度信号</li><li>理论基础清晰(PPO 风格)</li></ul><p><strong>缺点</strong>:</p><ul><li>Token 级方差大,特别是在长序列中</li><li>奖励分摊到单个 token 可能出现矛盾</li><li>MoE 模型中某些专家可能退化(处理少数低质量 token)</li></ul><h3 id="GSPO-的序列级优化"><a href="#GSPO-的序列级优化" class="headerlink" title="GSPO 的序列级优化"></a>GSPO 的序列级优化</h3><p>GSPO(Group Sequence Policy Optimization)采用 <strong>序列级别</strong> 的优化策略:</p><p><strong>核心特点</strong>:</p><ol><li><strong>序列级重要性采样</strong>:计算整个序列的重要性比率 $s(y) = \left(\frac{\pi<em>\theta(y|x)}{\pi</em>{\text{ref}}(y|x)}\right)^{1/L}$</li><li><strong>不分摊奖励</strong>:直接使用完整序列的整体奖励</li><li><strong>几何平均归一化</strong>:用 $1/L$ 次方消除序列长度差异</li><li><strong>好坏序列对比</strong>:通过比较好序列和坏序列的平均重要性比率计算优势</li><li><strong>损失聚合</strong>:直接对序列计算损失,然后平均</li></ol><p><strong>优点</strong>:</p><ul><li>优化目标与人类评价逻辑一致(关注整体质量)</li><li>序列级方差小,训练更稳定</li><li>支持长文本和 MoE 模型更好</li><li>避免分摊奖励的矛盾</li></ul><p><strong>缺点</strong>:</p><ul><li>失去 token 级的精细控制</li><li>无法区分序列内各 token 的贡献</li><li>不同长度序列的比较可能仍存在偏差</li></ul><hr><h2 id="GSPO-的算法原理"><a href="#GSPO-的算法原理" class="headerlink" title="GSPO 的算法原理"></a>GSPO 的算法原理</h2><h3 id="1-序列级重要性采样比率"><a href="#1-序列级重要性采样比率" class="headerlink" title="1. 序列级重要性采样比率"></a>1. 序列级重要性采样比率</h3><p>GSPO 的核心创新是 <strong>序列级重要性采样比率</strong>,这解决了序列长度不同导致概率乘积差异大的问题。</p><h4 id="序列概率计算"><a href="#序列概率计算" class="headerlink" title="序列概率计算"></a>序列概率计算</h4><p>对于长度为 $L$ 的序列 $y = (y_1, y_2, \ldots, y_L)$,整个序列的生成概率是所有 token 条件概率的乘积:</p><script type="math/tex; mode=display">\pi_\theta(y|x) = \prod_{t=1}^{L} \pi_\theta(y_t | x, y_1:t-1)</script><p>类似地,参考模型的序列概率为:</p><script type="math/tex; mode=display">\pi_{\text{ref}}(y|x) = \prod_{t=1}^{L} \pi_{\text{ref}}(y_t | x, y_1:t-1)</script><h4 id="几何平均归一化"><a href="#几何平均归一化" class="headerlink" title="几何平均归一化"></a>几何平均归一化</h4><p>简单的概率比 $\frac{\pi<em>\theta(y|x)}{\pi</em>{\text{ref}}(y|x)}$ 存在一个问题:<strong>长序列的概率乘积会更小,导致长序列的重要性比率被低估</strong>。</p><p>例如:</p><ul><li>短序列(长度 3):$0.5 \times 0.5 \times 0.5 = 0.125$</li><li>长序列(长度 10):$0.5^{10} \approx 0.001$</li></ul><p>即使两个序列的 token 概率分布相同,长序列的重要性比率也会被人为降低。</p><p><strong>GSPO 的解决方案</strong>:使用 <strong>几何平均</strong>(geometric mean),即开 $L$ 次方:</p><script type="math/tex; mode=display">s(y) = \left(\frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}\right)^{1/L}</script><p><strong>直观理解</strong>:</p><ul><li>这相当于计算 “平均每个 token 的概率增长倍数”</li><li>长序列和短序列可以公平比较</li><li>与序列长度无关</li></ul><p><strong>数值例子</strong>:</p><p>假设一个长度为 $L$ 的序列中,每个 token 的概率都从 0.3 增加到 0.4:</p><script type="math/tex; mode=display">\frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)} = \left(\frac{0.4}{0.3}\right)^L = (1.333)^L</script><p>不同长度下的比率:</p><ul><li>$L=3$:$(1.333)^3 \approx 2.37 \rightarrow$ 开 3 次方 $= 1.333$</li><li>$L=10$:$(1.333)^{10} \approx 33.9 \rightarrow$ 开 10 次方 $= 1.333$</li></ul><p><strong>经过几何平均,得到相同的 $s(y) = 1.333$</strong>,公平地反映了两个序列的改进程度。</p><h4 id="对数概率视角"><a href="#对数概率视角" class="headerlink" title="对数概率视角"></a>对数概率视角</h4><p>为了数值稳定性,实践中通常用对数形式:</p><script type="math/tex; mode=display">\log s(y) = \frac{1}{L} \sum_{t=1}^{L} \left(\log \pi_\theta(y_t|x, y_{1:t-1}) - \log \pi_{\text{ref}}(y_t|x, y_{1:t-1})\right)</script><p>或者:</p><script type="math/tex; mode=display">\log s(y) = \frac{1}{L} \left(\log \pi_\theta(y|x) - \log \pi_{\text{ref}}(y|x)\right)</script><p>这样避免了浮点数下溢。</p><h3 id="2-组内优势计算"><a href="#2-组内优势计算" class="headerlink" title="2. 组内优势计算"></a>2. 组内优势计算</h3><p>与 GRPO 不同,GSPO 不计算单个序列的绝对优势,而是计算 <strong>组内相对优势</strong>——即好序列相对于坏序列的优势。</p><h4 id="序列分组"><a href="#序列分组" class="headerlink" title="序列分组"></a>序列分组</h4><p>对于一个 batch 中的 $B$ 个 prompt,每个 prompt 采样 $G$ 个序列,得到 $B \times G$ 个序列。</p><p>对于每个 prompt,按奖励将 $G$ 个序列分成两组:</p><ul><li><strong>好序列组</strong> $G_{\text{good}}$:奖励高于组内平均的序列</li><li><strong>坏序列组</strong> $G_{\text{bad}}$:奖励低于组内平均的序列</li></ul><p><strong>分组标准</strong>:</p><script type="math/tex; mode=display">\bar{r} = \frac{1}{G} \sum_{i=1}^{G} r_i</script><ul><li>若 $r_i > \bar{r}$:序列 $i$ 属于好序列组</li><li>若 $r_i < \bar{r}$:序列 $i$ 属于坏序列组</li></ul><h4 id="组内平均重要性比率"><a href="#组内平均重要性比率" class="headerlink" title="组内平均重要性比率"></a>组内平均重要性比率</h4><p>计算两组的平均重要性比率:</p><script type="math/tex; mode=display">\bar{s}_{\text{good}} = \frac{1}{|G_{\text{good}}|} \sum_{y \in G_{\text{good}}} s(y)</script><script type="math/tex; mode=display">\bar{s}_{\text{bad}} = \frac{1}{|G_{\text{bad}}|} \sum_{y \in G_{\text{bad}}} s(y)</script><h4 id="序列级优势"><a href="#序列级优势" class="headerlink" title="序列级优势"></a>序列级优势</h4><p>序列级优势定义为两组平均重要性比率的差:</p><script type="math/tex; mode=display">A = \bar{s}_{\text{good}} - \bar{s}_{\text{bad}}</script><p><strong>含义</strong>:</p><ul><li>若 $A > 0$:新策略目前更倾向生成好序列(方向正确)</li><li>若 $A < 0$:新策略目前更倾向生成坏序列(需要调整)</li></ul><p><strong>注意</strong>:这里的优势是 <strong>全局的</strong>,所有好序列和坏序列之间的平均比较。</p><h3 id="3-损失函数"><a href="#3-损失函数" class="headerlink" title="3. 损失函数"></a>3. 损失函数</h3><p>GSPO 的损失函数包含两部分:策略损失和 KL 约束。</p><h4 id="策略损失"><a href="#策略损失" class="headerlink" title="策略损失"></a>策略损失</h4><script type="math/tex; mode=display">L_{\text{Policy}} = -\left(\bar{s}_{\text{good}} - \bar{s}_{\text{bad}}\right) \cdot \min(A, A_{\text{clip}})</script><p>或使用 PPO 风格的裁剪:</p><script type="math/tex; mode=display">L_{\text{Policy}}^{\text{CLIP}} = -\min\left(\bar{s}_{\text{good}} - \bar{s}_{\text{bad}}, \text{clip}(\bar{s}_{\text{good}} - \bar{s}_{\text{bad}}, 1-\epsilon, 1+\epsilon)\right) \cdot A</script><p>其中 $\epsilon = 0.2$ 是裁剪范围。</p><p><strong>目标</strong>:</p><ul><li>提高好序列的整体生成概率($\bar{s}_{\text{good}}$ 增大)</li><li>降低坏序列的整体生成概率($\bar{s}_{\text{bad}}$ 减小)</li><li>限制更新幅度(通过裁剪)</li></ul><h4 id="KL-约束"><a href="#KL-约束" class="headerlink" title="KL 约束"></a>KL 约束</h4><script type="math/tex; mode=display">L_{\text{KL}} = \frac{1}{B \times G} \sum_{y} D_{\text{KL}}(\pi_\theta(\cdot|x) \| \pi_{\text{ref}}(\cdot|x))</script><p>这里对整个序列计算 KL 散度,防止策略偏离参考模型太远。</p><h4 id="总损失"><a href="#总损失" class="headerlink" title="总损失"></a>总损失</h4><script type="math/tex; mode=display">L_{\text{total}} = L_{\text{Policy}} + \beta \cdot L_{\text{KL}}</script><p>其中 $\beta$ 是 KL 权重(通常 0.01-0.1)。</p><hr><h2 id="GSPO-vs-GRPO-的详细对比"><a href="#GSPO-vs-GRPO-的详细对比" class="headerlink" title="GSPO vs GRPO 的详细对比"></a>GSPO vs GRPO 的详细对比</h2><h3 id="重要性采样粒度对比"><a href="#重要性采样粒度对比" class="headerlink" title="重要性采样粒度对比"></a>重要性采样粒度对比</h3><div class="table-container"><table><thead><tr><th>维度</th><th>GRPO</th><th>GSPO</th></tr></thead><tbody><tr><td><strong>采样粒度</strong></td><td>Token 级别</td><td><strong>序列级别</strong></td></tr><tr><td><strong>重要性比率定义</strong></td><td>$r<em>t = \frac{\pi</em>\theta(a_t\</td><td>s<em>t)}{\pi</em>{\text{ref}}(a_t\</td><td>s_t)}$</td><td>$s(y) = \left(\frac{\pi_\theta(y\</td><td>x)}{\pi_{\text{ref}}(y\</td><td>x)}\right)^{1/L}$</td></tr><tr><td><strong>计算复杂度</strong></td><td>低(每个 token)</td><td>中(每个序列)</td></tr><tr><td><strong>长序列影响</strong></td><td>方差大</td><td><strong>方差小</strong>(几何平均)</td></tr><tr><td><strong>归一化</strong></td><td>无需(token 独立)</td><td><strong>有(几何平均)</strong></td></tr></tbody></table></div><h3 id="优势计算对比"><a href="#优势计算对比" class="headerlink" title="优势计算对比"></a>优势计算对比</h3><div class="table-container"><table><thead><tr><th>维度</th><th>GRPO</th><th>GSPO</th></tr></thead><tbody><tr><td><strong>优势定义</strong></td><td>组内平均 - 单序列奖励</td><td><strong>好序列均值 - 坏序列均值</strong></td></tr><tr><td><strong>粒度</strong></td><td>Token 级别</td><td><strong>序列级别</strong></td></tr><tr><td><strong>奖励处理</strong></td><td>分摊到 token</td><td><strong>直接使用序列奖励</strong></td></tr><tr><td><strong>聚合方式</strong></td><td>按位置折扣</td><td><strong>组内平均</strong></td></tr><tr><td><strong>长度依赖性</strong></td><td>高(分摊依赖长度)</td><td><strong>低(几何平均消除)</strong></td></tr></tbody></table></div><h3 id="梯度更新对比"><a href="#梯度更新对比" class="headerlink" title="梯度更新对比"></a>梯度更新对比</h3><div class="table-container"><table><thead><tr><th>维度</th><th>GRPO</th><th>GSPO</th></tr></thead><tbody><tr><td><strong>损失计算</strong></td><td>$\sum_t \min(r_t A_t, \text{clip}(r_t) A_t)$</td><td><strong>好序列梯度 - 坏序列梯度</strong></td></tr><tr><td><strong>裁剪方式</strong></td><td>Token 级 PPO 裁剪</td><td><strong>序列级裁剪</strong></td></tr><tr><td><strong>梯度信号</strong></td><td>每个 token 独立</td><td><strong>序列作为整体</strong></td></tr><tr><td><strong>方差来源</strong></td><td>Token 级波动</td><td><strong>序列级波动</strong>(更小)</td></tr></tbody></table></div><h3 id="性能对比"><a href="#性能对比" class="headerlink" title="性能对比"></a>性能对比</h3><div class="table-container"><table><thead><tr><th>指标</th><th>GRPO</th><th>GSPO</th></tr></thead><tbody><tr><td><strong>训练稳定性</strong></td><td>中(长序列波动大)</td><td><strong>高</strong>(序列级平均)</td></tr><tr><td><strong>收敛速度</strong></td><td>中等</td><td><strong>更快</strong>(更稳定的梯度)</td></tr><tr><td><strong>长文本支持</strong></td><td>较差</td><td><strong>很好</strong></td></tr><tr><td><strong>MoE 模型</strong></td><td>不稳定(专家利用率差异大)</td><td><strong>稳定</strong></td></tr><tr><td><strong>内存占用</strong></td><td>低</td><td><strong>低</strong>(相同)</td></tr><tr><td><strong>计算复杂度</strong></td><td>低</td><td><strong>中</strong>(但更稳定)</td></tr></tbody></table></div><hr><h2 id="GSPO-的实现细节"><a href="#GSPO-的实现细节" class="headerlink" title="GSPO 的实现细节"></a>GSPO 的实现细节</h2><h3 id="1-序列级重要性比率计算"><a href="#1-序列级重要性比率计算" class="headerlink" title="1. 序列级重要性比率计算"></a>1. 序列级重要性比率计算</h3><h4 id="伪代码"><a href="#伪代码" class="headerlink" title="伪代码"></a>伪代码</h4><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">compute_sequence_importance_ratio</span>(<span class="params">actor_logits, ref_logits, </span></span><br><span class="line"><span class="params"> token_ids, sequence_lengths</span>):</span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> 计算序列级重要性比率</span></span><br><span class="line"><span class="string"> </span></span><br><span class="line"><span class="string"> Args:</span></span><br><span class="line"><span class="string"> actor_logits: shape (B, G, T, V) - Actor 模型 logits</span></span><br><span class="line"><span class="string"> ref_logits: shape (B, G, T, V) - Reference 模型 logits</span></span><br><span class="line"><span class="string"> token_ids: shape (B, G, T) - 生成的 token IDs</span></span><br><span class="line"><span class="string"> sequence_lengths: shape (B, G) - 每个序列的实际长度</span></span><br><span class="line"><span class="string"> </span></span><br><span class="line"><span class="string"> Returns:</span></span><br><span class="line"><span class="string"> seq_ratios: shape (B, G) - 序列级重要性比率</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> B, G, T, V = actor_logits.shape</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 计算 log 概率</span></span><br><span class="line"> actor_log_probs = log_softmax(actor_logits, dim=-<span class="number">1</span>) <span class="comment"># shape (B, G, T, V)</span></span><br><span class="line"> ref_log_probs = log_softmax(ref_logits, dim=-<span class="number">1</span>)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 提取生成 token 的 log 概率</span></span><br><span class="line"> actor_seq_log_prob = torch.zeros(B, G)</span><br><span class="line"> ref_seq_log_prob = torch.zeros(B, G)</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">for</span> b <span class="keyword">in</span> <span class="built_in">range</span>(B):</span><br><span class="line"> <span class="keyword">for</span> g <span class="keyword">in</span> <span class="built_in">range</span>(G):</span><br><span class="line"> L = sequence_lengths[b, g]</span><br><span class="line"> <span class="comment"># 求和所有 token 的 log 概率</span></span><br><span class="line"> <span class="keyword">for</span> t <span class="keyword">in</span> <span class="built_in">range</span>(L):</span><br><span class="line"> token_id = token_ids[b, g, t]</span><br><span class="line"> actor_seq_log_prob[b, g] += actor_log_probs[b, g, t, token_id]</span><br><span class="line"> ref_seq_log_prob[b, g] += ref_log_probs[b, g, t, token_id]</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 计算序列级重要性比率(取幂化为线性)</span></span><br><span class="line"> log_ratio = actor_seq_log_prob - ref_seq_log_prob</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 几何平均:除以序列长度</span></span><br><span class="line"> seq_ratios = torch.exp(log_ratio / sequence_lengths)</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> seq_ratios</span><br></pre></td></tr></table></figure><h3 id="2-好坏序列分组与优势计算"><a href="#2-好坏序列分组与优势计算" class="headerlink" title="2. 好坏序列分组与优势计算"></a>2. 好坏序列分组与优势计算</h3><h4 id="伪代码-1"><a href="#伪代码-1" class="headerlink" title="伪代码"></a>伪代码</h4><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">compute_sequence_advantage</span>(<span class="params">seq_ratios, rewards, sequence_lengths</span>):</span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> 计算序列级优势</span></span><br><span class="line"><span class="string"> </span></span><br><span class="line"><span class="string"> Args:</span></span><br><span class="line"><span class="string"> seq_ratios: shape (B, G) - 序列级重要性比率</span></span><br><span class="line"><span class="string"> rewards: shape (B, G) - 每个序列的奖励</span></span><br><span class="line"><span class="string"> sequence_lengths: shape (B, G) - 序列长度</span></span><br><span class="line"><span class="string"> </span></span><br><span class="line"><span class="string"> Returns:</span></span><br><span class="line"><span class="string"> advantages: shape (B, G) - 序列级优势</span></span><br><span class="line"><span class="string"> group_advantages: shape (B,) - 每个 prompt 的组优势</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> B, G = seq_ratios.shape</span><br><span class="line"> </span><br><span class="line"> advantages = torch.zeros(B, G)</span><br><span class="line"> group_advantages = torch.zeros(B)</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">for</span> b <span class="keyword">in</span> <span class="built_in">range</span>(B):</span><br><span class="line"> <span class="comment"># 计算该 prompt 的平均奖励</span></span><br><span class="line"> mean_reward = rewards[b].mean()</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 分组</span></span><br><span class="line"> good_mask = rewards[b] >= mean_reward</span><br><span class="line"> bad_mask = ~good_mask</span><br><span class="line"> </span><br><span class="line"> good_ratios = seq_ratios[b, good_mask]</span><br><span class="line"> bad_ratios = seq_ratios[b, bad_mask]</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 计算好坏序列的平均重要性比率</span></span><br><span class="line"> <span class="keyword">if</span> <span class="built_in">len</span>(good_ratios) > <span class="number">0</span>:</span><br><span class="line"> mean_good_ratio = good_ratios.mean()</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> mean_good_ratio = <span class="number">0.0</span></span><br><span class="line"> </span><br><span class="line"> <span class="keyword">if</span> <span class="built_in">len</span>(bad_ratios) > <span class="number">0</span>:</span><br><span class="line"> mean_bad_ratio = bad_ratios.mean()</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> mean_bad_ratio = <span class="number">0.0</span></span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 组优势</span></span><br><span class="line"> group_adv = mean_good_ratio - mean_bad_ratio</span><br><span class="line"> group_advantages[b] = group_adv</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 每个序列的优势(同组内所有序列相同)</span></span><br><span class="line"> <span class="keyword">for</span> g <span class="keyword">in</span> <span class="built_in">range</span>(G):</span><br><span class="line"> <span class="keyword">if</span> good_mask[g]:</span><br><span class="line"> advantages[b, g] = group_adv</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> advantages[b, g] = -group_adv <span class="comment"># 坏序列的优势为负</span></span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> advantages, group_advantages</span><br></pre></td></tr></table></figure><h3 id="3-损失计算"><a href="#3-损失计算" class="headerlink" title="3. 损失计算"></a>3. 损失计算</h3><h4 id="伪代码-2"><a href="#伪代码-2" class="headerlink" title="伪代码"></a>伪代码</h4><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">compute_loss</span>(<span class="params">actor_logits, ref_logits, token_ids, </span></span><br><span class="line"><span class="params"> seq_ratios, advantages, sequence_lengths,</span></span><br><span class="line"><span class="params"> beta=<span class="number">0.05</span>, epsilon=<span class="number">0.2</span></span>):</span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> 计算 GSPO 损失</span></span><br><span class="line"><span class="string"> </span></span><br><span class="line"><span class="string"> Args:</span></span><br><span class="line"><span class="string"> actor_logits: shape (B, G, T, V)</span></span><br><span class="line"><span class="string"> ref_logits: shape (B, G, T, V)</span></span><br><span class="line"><span class="string"> token_ids: shape (B, G, T)</span></span><br><span class="line"><span class="string"> seq_ratios: shape (B, G) - 序列级重要性比率</span></span><br><span class="line"><span class="string"> advantages: shape (B, G) - 序列级优势</span></span><br><span class="line"><span class="string"> sequence_lengths: shape (B, G)</span></span><br><span class="line"><span class="string"> beta: KL 权重</span></span><br><span class="line"><span class="string"> epsilon: 裁剪范围</span></span><br><span class="line"><span class="string"> </span></span><br><span class="line"><span class="string"> Returns:</span></span><br><span class="line"><span class="string"> total_loss: 标量损失</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> B, G, T, V = actor_logits.shape</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 计算 log 概率</span></span><br><span class="line"> actor_log_probs = log_softmax(actor_logits, dim=-<span class="number">1</span>)</span><br><span class="line"> ref_log_probs = log_softmax(ref_logits, dim=-<span class="number">1</span>)</span><br><span class="line"> </span><br><span class="line"> policy_loss = <span class="number">0.0</span></span><br><span class="line"> kl_loss = <span class="number">0.0</span></span><br><span class="line"> </span><br><span class="line"> <span class="keyword">for</span> b <span class="keyword">in</span> <span class="built_in">range</span>(B):</span><br><span class="line"> <span class="keyword">for</span> g <span class="keyword">in</span> <span class="built_in">range</span>(G):</span><br><span class="line"> L = sequence_lengths[b, g]</span><br><span class="line"> adv = advantages[b, g]</span><br><span class="line"> ratio = seq_ratios[b, g]</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># PPO 风格的裁剪</span></span><br><span class="line"> clipped_ratio = clamp(ratio, <span class="number">1</span> - epsilon, <span class="number">1</span> + epsilon)</span><br><span class="line"> policy_term = <span class="built_in">min</span>(ratio * adv, clipped_ratio * adv)</span><br><span class="line"> policy_loss += -policy_term <span class="comment"># 最大化优势,所以取负</span></span><br><span class="line"> </span><br><span class="line"> <span class="comment"># KL 散度(序列级)</span></span><br><span class="line"> kl_sum = <span class="number">0.0</span></span><br><span class="line"> <span class="keyword">for</span> t <span class="keyword">in</span> <span class="built_in">range</span>(L):</span><br><span class="line"> token_id = token_ids[b, g, t]</span><br><span class="line"> log_ratio_t = (actor_log_probs[b, g, t, token_id] - </span><br><span class="line"> ref_log_probs[b, g, t, token_id])</span><br><span class="line"> kl_sum += log_ratio_t</span><br><span class="line"> </span><br><span class="line"> kl_loss += kl_sum / L <span class="comment"># 每个 token 的平均 KL</span></span><br><span class="line"> </span><br><span class="line"> total_loss = (policy_loss + beta * kl_loss) / (B * G)</span><br><span class="line"> <span class="keyword">return</span> total_loss</span><br></pre></td></tr></table></figure><h3 id="4-超参数设置"><a href="#4-超参数设置" class="headerlink" title="4. 超参数设置"></a>4. 超参数设置</h3><div class="table-container"><table><thead><tr><th>超参数</th><th>典型值</th><th>范围</th><th>说明</th></tr></thead><tbody><tr><td><strong>$G$ (组大小)</strong></td><td>8-16</td><td>4-32</td><td>更大的组更稳定</td></tr><tr><td><strong>$\epsilon$ (裁剪范围)</strong></td><td>0.2</td><td>0.1-0.3</td><td>限制重要性比率范围</td></tr><tr><td><strong>$\beta$ (KL 权重)</strong></td><td>0.05</td><td>0.01-0.1</td><td>平衡奖励和 KL 约束</td></tr><tr><td><strong>学习率</strong></td><td>5e-7</td><td>1e-7-1e-6</td><td>通常比 SFT 小 100-1000 倍</td></tr><tr><td><strong>批大小</strong></td><td>$B=8, G=16$</td><td>-</td><td>总共 128 个序列</td></tr></tbody></table></div><h4 id="调优建议"><a href="#调优建议" class="headerlink" title="调优建议"></a>调优建议</h4><ol><li><p><strong>组大小 $G$</strong>:</p><ul><li>GSPO 对 $G$ 的敏感度低于 GRPO(因为序列级优势更稳定)</li><li>推荐从 $G=8$ 开始</li><li>如果仍有波动,增加到 $G=16$ 或 $G=32$</li></ul></li><li><p><strong>KL 权重 $\beta$</strong>:</p><ul><li>观察 KL 散度大小:<ul><li>KL > 0.1:增大 $\beta$</li><li>KL < 0.01:减小 $\beta$</li><li>目标:0.01-0.05</li></ul></li><li>GSPO 通常需要更小的 $\beta$(因为更稳定)</li></ul></li><li><p><strong>学习率</strong>:</p><ul><li>GSPO 对学习率不如 GRPO 敏感</li><li>可以用略大的学习率(5e-7)</li><li>监测梯度范数,避免梯度爆炸</li></ul></li></ol><hr><h2 id="GSPO-完整数值示例"><a href="#GSPO-完整数值示例" class="headerlink" title="GSPO 完整数值示例"></a>GSPO 完整数值示例</h2><h3 id="假设条件"><a href="#假设条件" class="headerlink" title="假设条件"></a>假设条件</h3><ul><li><strong>批大小</strong>:$B = 1$(1 个 prompt)</li><li><strong>组大小</strong>:$G = 3$(3 个序列)</li><li><strong>序列长度</strong>:$L_1 = 5, L_2 = 7, L_3 = 6$(长度不同)</li><li><strong>奖励</strong>:$r_1 = 8, r_2 = 6, r_3 = 4$</li><li><strong>平均奖励</strong>:$\bar{r} = 6$</li></ul><h3 id="计算过程"><a href="#计算过程" class="headerlink" title="计算过程"></a>计算过程</h3><h4 id="步骤-1-分组"><a href="#步骤-1-分组" class="headerlink" title="步骤 1: 分组"></a>步骤 1: 分组</h4><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">平均奖励: r_bar = (8 + 6 + 4) / 3 = 6</span><br><span class="line"></span><br><span class="line">好序列组(r_i >= 6):</span><br><span class="line"> - 序列 1: r_1 = 8 ✓</span><br><span class="line"> - 序列 2: r_2 = 6 ✓</span><br><span class="line"> </span><br><span class="line">坏序列组(r_i < 6):</span><br><span class="line"> - 序列 3: r_3 = 4 ✓</span><br><span class="line"></span><br><span class="line">分组结果:G_good = {seq1, seq2}, G_bad = {seq3}</span><br></pre></td></tr></table></figure><h4 id="步骤-2-计算序列级重要性比率"><a href="#步骤-2-计算序列级重要性比率" class="headerlink" title="步骤 2: 计算序列级重要性比率"></a>步骤 2: 计算序列级重要性比率</h4><p>假设:</p><ul><li><p>序列 1(长度 5):每个 token 概率比平均 1.1 倍</p><ul><li>$\log \pi<em>\theta = -0.5$,$\log \pi</em>{\text{ref}} = -0.55$</li><li>序列 log 比:$5 \times 0.05 = 0.25$</li><li>几何平均:$\log s(y_1) = 0.25 / 5 = 0.05 \Rightarrow s(y_1) \approx 1.051$</li></ul></li><li><p>序列 2(长度 7):每个 token 概率比平均 1.05 倍</p><ul><li>序列 log 比:$7 \times 0.0488 = 0.342$</li><li>几何平均:$\log s(y_2) = 0.342 / 7 = 0.049 \Rightarrow s(y_2) \approx 1.050$</li></ul></li><li><p>序列 3(长度 6):每个 token 概率比平均 0.95 倍</p><ul><li>序列 log 比:$6 \times (-0.0513) = -0.308$</li><li>几何平均:$\log s(y_3) = -0.308 / 6 = -0.0513 \Rightarrow s(y_3) \approx 0.950$</li></ul></li></ul><p><strong>关键观察</strong>:尽管序列长度不同(5, 7, 6),经过几何平均后,序列 1 和 2 的 $s$ 值相近(都约 1.05),反映了相同的改进程度。</p><h4 id="步骤-3-计算组平均重要性比率"><a href="#步骤-3-计算组平均重要性比率" class="headerlink" title="步骤 3: 计算组平均重要性比率"></a>步骤 3: 计算组平均重要性比率</h4><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">好序列组平均:</span><br><span class="line"> s_good = (s(y_1) + s(y_2)) / 2 = (1.051 + 1.050) / 2 ≈ 1.0505</span><br><span class="line"></span><br><span class="line">坏序列组平均:</span><br><span class="line"> s_bad = s(y_3) = 0.950</span><br><span class="line"></span><br><span class="line">组优势:</span><br><span class="line"> A = s_good - s_bad = 1.0505 - 0.950 = 0.1005</span><br></pre></td></tr></table></figure><p><strong>解释</strong>:好序列组的重要性比率高于坏序列组,说明新策略目前倾向于生成好序列,优化方向正确。</p><h4 id="步骤-4-计算损失"><a href="#步骤-4-计算损失" class="headerlink" title="步骤 4: 计算损失"></a>步骤 4: 计算损失</h4><p>假设 $\epsilon = 0.2$,$\beta = 0.05$:</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br></pre></td><td class="code"><pre><span class="line">裁剪重要性比率(假设所有 s 值都在 0.8-1.2 范围内):</span><br><span class="line"> clipped_good ≈ 1.0505(无需裁剪)</span><br><span class="line"> clipped_bad ≈ 0.950(无需裁剪)</span><br><span class="line"></span><br><span class="line">策略损失(对好坏序列分别计算):</span><br><span class="line"> - 好序列 1: loss_1 = -min(s(y_1) * A, clip(s) * A)</span><br><span class="line"> = -min(1.051 * 0.1005, 1.051 * 0.1005)</span><br><span class="line"> = -0.1057</span><br><span class="line"> </span><br><span class="line"> - 好序列 2: loss_2 = -min(1.050 * 0.1005, 1.050 * 0.1005)</span><br><span class="line"> = -0.1055</span><br><span class="line"> </span><br><span class="line"> - 坏序列 3: loss_3 = -min(0.950 * (-0.1005), 0.950 * (-0.1005))</span><br><span class="line"> = 0.0955</span><br><span class="line"></span><br><span class="line">平均策略损失:</span><br><span class="line"> L_policy = (0.1057 + 0.1055 + 0.0955) / 3 ≈ 0.1022</span><br><span class="line"></span><br><span class="line">KL 散度(简化,假设每个 token 的 log ratio 平均 0.05):</span><br><span class="line"> L_KL ≈ 0.05</span><br><span class="line"></span><br><span class="line">总损失:</span><br><span class="line"> L_total = L_policy + β * L_KL = 0.1022 + 0.05 * 0.05 = 0.1047</span><br></pre></td></tr></table></figure><p><strong>关键观察</strong>:</p><ul><li>好序列的损失为正(鼓励),坏序列的损失为负(惩罚)</li><li>通过最小化总损失,模型会学习生成更多好序列</li><li>KL 项保证了模型不会偏离参考模型太远</li></ul><hr><h2 id="GRPO-vs-GSPO-的选择"><a href="#GRPO-vs-GSPO-的选择" class="headerlink" title="GRPO vs GSPO 的选择"></a>GRPO vs GSPO 的选择</h2><h3 id="什么情况下使用-GRPO"><a href="#什么情况下使用-GRPO" class="headerlink" title="什么情况下使用 GRPO"></a>什么情况下使用 GRPO</h3><ol><li><strong>短文本任务</strong>:生成文本长度一致,token 级方差不大</li><li><strong>需要精细控制</strong>:想要针对不同位置的 token 进行不同的优化</li><li><strong>实验早期</strong>:GRPO 理论更清晰,调试更容易</li></ol><h3 id="什么情况下使用-GSPO"><a href="#什么情况下使用-GSPO" class="headerlink" title="什么情况下使用 GSPO"></a>什么情况下使用 GSPO</h3><ol><li><strong>长文本任务</strong>:生成长序列或文本长度差异大</li><li><strong>模型规模大</strong>:大模型中 token 级波动更明显</li><li><strong>MoE 架构</strong>:避免专家利用率不均衡</li><li><strong>追求稳定性</strong>:GSPO 训练更稳定,收敛更快</li><li><strong>生产环境</strong>:GSPO 的鲁棒性更好</li></ol><h3 id="实践建议"><a href="#实践建议" class="headerlink" title="实践建议"></a>实践建议</h3><p><strong>从 GRPO 迁移到 GSPO 的步骤</strong>:</p><ol><li><strong>验证 GRPO 的基线</strong>:确保 GRPO 训练稳定</li><li><strong>调整参数</strong>:<ul><li>保持 $\beta$ 不变(或略减小)</li><li>可能需要增加 $G$(因为序列级优势更稳定)</li></ul></li><li><strong>逐步替换</strong>:先在部分数据上尝试 GSPO</li><li><strong>监控关键指标</strong>:<ul><li>KL 散度(应保持在 0.01-0.1)</li><li>奖励均值和方差</li><li>序列长度分布</li></ul></li></ol><hr><h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>GSPO 相比 GRPO 的核心改进:</p><ol><li><strong>序列级优化</strong>:从 token 级提升到序列级,减少方差</li><li><strong>几何平均归一化</strong>:公平处理不同长度的序列</li><li><strong>好坏对比学习</strong>:优化目标更直观、更稳定</li><li><strong>更强的鲁棒性</strong>:特别是在长文本和 MoE 模型上</li></ol><p><strong>关键理解</strong>:</p><ul><li>GRPO 和 GSPO 都不需要 Critic 模型</li><li>两者都使用组内比较来计算优势</li><li>GSPO 的创新在于 <strong>将粒度从 token 级提升到序列级</strong></li></ul><p><strong>性能对比</strong>:</p><ul><li><strong>训练稳定性</strong>:GSPO > GRPO(特别是长文本)</li><li><strong>收敛速度</strong>:GSPO ≥ GRPO</li><li><strong>计算复杂度</strong>:两者相近(GSPO 略高)</li><li><strong>理论完备性</strong>:GRPO(基于 PPO)vs GSPO(创新设计)</li></ul><p><strong>推荐阅读</strong>:</p><ul><li><a href="https://murphypei.github.io/blog/2025/07/llm-grpo.html">GRPO 博文</a>,理解基础概念</li><li><a href="https://murphypei.github.io/blog/2025/05/llm-ppo.html">PPO 博文</a>,理解策略梯度方法</li><li>对比两种算法的优缺点,根据实际场景选择</li></ul>]]></content>
<summary type="html">
<p>GSPO(Group Sequence Policy Optimization,群组序列策略优化)是对 GRPO 的重要改进。通过将优化粒度从 <strong>token 级</strong> 提升到 <strong>序列级</strong>,GSPO 从根本上解决了 GRPO 在处理长文本、MoE 模型时的训练不稳定问题,同时保持了轻量化的优势(无需 Critic 模型)。</p>
</summary>
<category term="LLM" scheme="https://murphypei.github.io/categories/LLM/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="强化学习" scheme="https://murphypei.github.io/tags/%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0/"/>
<category term="GRPO" scheme="https://murphypei.github.io/tags/GRPO/"/>
<category term="GSPO" scheme="https://murphypei.github.io/tags/GSPO/"/>
<category term="序列级优化" scheme="https://murphypei.github.io/tags/%E5%BA%8F%E5%88%97%E7%BA%A7%E4%BC%98%E5%8C%96/"/>
</entry>
<entry>
<title>让 LLM 输出规范 JSON 的方法</title>
<link href="https://murphypei.github.io/blog/2025/09/llm-output-json.html"/>
<id>https://murphypei.github.io/blog/2025/09/llm-output-json.html</id>
<published>2025-09-15T06:20:00.000Z</published>
<updated>2026-01-07T08:41:01.087Z</updated>
<content type="html"><![CDATA[<p>在现代 AI 应用开发中,让大语言模型(LLM)生成结构化的 JSON 数据是一个关键需求。无论是构建 API 服务、数据处理流水线,还是与现有系统集成,结构化输出都是必不可少的。本文将深入探讨多种让 LLM 生成规范 JSON 的方法,从基础技巧到高级工程实践。</p><span id="more"></span><h3 id="为什么需要结构化输出?"><a href="#为什么需要结构化输出?" class="headerlink" title="为什么需要结构化输出?"></a>为什么需要结构化输出?</h3><p>在实际应用中,我们经常需要 LLM 的输出能够被程序直接解析和使用,而不是仅仅作为文本供人类阅读。结构化的 JSON 输出具有以下优势:</p><ul><li><strong>可解析性</strong>:程序可以直接解析和处理 JSON 数据</li><li><strong>类型安全</strong>:明确的字段类型和结构规范</li><li><strong>可验证性</strong>:可以通过 Schema 验证数据完整性</li><li><strong>易集成</strong>:与现有系统和 API 无缝集成</li></ul><h3 id="方法一:提示词-JSON-format"><a href="#方法一:提示词-JSON-format" class="headerlink" title="方法一:提示词 + JSON format"></a>方法一:提示词 + JSON format</h3><p>这是最常用的让模型输出 JSON 格式的方法了。首先就是 prompt 中声明要输出 JSON,配合一些样例 few shot,然后呢,类似 Azure、Gemini 这类的大模型调用接口,都有类似 response_format 可以指定输出 JSON 格式。</p><p>模式如下:<br>JSON 提示词描述 + JSON 样例(Few shot)+ JSON response_format 来约束大模型的 JSON 输出。</p><p>如何用比较好地用提示词描述 JSON 字段呢?网络上比较好的实践是用 Typescript 或者 Yaml 格式描述(LLM生成Json结构化数据的几种方案,个人目前认为最好的方式依然是`TypeScript约束Prompt + Yaml格),当然简单地就直接用列表描述就好。可以参考这篇文章:<a href="https://juejin.cn/post/7325429835387404307">https://juejin.cn/post/7325429835387404307</a></p><p><strong>需要注意的点</strong>:</p><ul><li><ol><li>这个方法不能百分百保证。笔者在 gemini 2.5 pro,指定了 response_format 为 JSON,prompt 给了 few shot,在大型的 JSON 生成的时候,一样会失败。但是 gemini 2.5 pro 指定 json schema,生成成功率大幅度提高。</li></ol></li><li><ol><li>对于没有 json schema 参数的接口,few show 中的 json 样例非常重要。</li></ol></li></ul><p>再说一下这些 API 提供的 JSON response format,本质是一个 Constrained Decoding,即在预测下一字符时,把不符合 JSON 格式的丢掉,基本在高级模型可以非常稳定地输出 JSON 格式,所以还是会有一些极端 case 导致解码失败,或者解码后不是完整的 JSON 格式。</p><h3 id="方法二:Function-call"><a href="#方法二:Function-call" class="headerlink" title="方法二:Function call"></a>方法二:Function call</h3><p>本质和方法一其实是一样的…</p><h3 id="方法三:后处理"><a href="#方法三:后处理" class="headerlink" title="方法三:后处理"></a>方法三:后处理</h3><p>三个臭皮匠顶个诸葛亮,大模型输出的 JSON 不完美,那就做善后处理。笔者最开始用 GPT-4 生成 JSON 的时候,写了很多后处理的函数,包括:去掉 ```、修复转移错误、轻微语法问题等等。这些案例现在网上有很多,可以自行参考。</p><h3 id="方法四:大模型修复"><a href="#方法四:大模型修复" class="headerlink" title="方法四:大模型修复"></a>方法四:大模型修复</h3><p>在方法三的基础上,让大模型自己纠正输出的 JSON,其实比单独 JSON 解码要简单的多,大模型自己也能做好。但是这个方法的前提是,第一个模型输出的 JSON 内容是基本正确的,如果内容是错误的,后面的模型也很难纠正。</p><h3 id="结语"><a href="#结语" class="headerlink" title="结语"></a>结语</h3><p>总结下来就是,一套基本可用的链路就是:prompt + few shot / json schema + response formt + 后处理,这是 API 接口调用常规用法。笔者日常项目体验,目前比较强的模型基本这套链路能处理的生成 JSON 的规模已经很大了。</p>]]></content>
<summary type="html">
<p>在现代 AI 应用开发中,让大语言模型(LLM)生成结构化的 JSON 数据是一个关键需求。无论是构建 API 服务、数据处理流水线,还是与现有系统集成,结构化输出都是必不可少的。本文将深入探讨多种让 LLM 生成规范 JSON 的方法,从基础技巧到高级工程实践。</p>
</summary>
<category term="LLM" scheme="https://murphypei.github.io/categories/LLM/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="JSON" scheme="https://murphypei.github.io/tags/JSON/"/>
<category term="结构化输出" scheme="https://murphypei.github.io/tags/%E7%BB%93%E6%9E%84%E5%8C%96%E8%BE%93%E5%87%BA/"/>
<category term="API" scheme="https://murphypei.github.io/tags/API/"/>
<category term="schema" scheme="https://murphypei.github.io/tags/schema/"/>
</entry>
<entry>
<title>翻译《如何构建多智能体研究系统:Anthropic 的工程实践》</title>
<link href="https://murphypei.github.io/blog/2025/09/how-to-build-multi-agents.html"/>
<id>https://murphypei.github.io/blog/2025/09/how-to-build-multi-agents.html</id>
<published>2025-08-31T22:17:48.000Z</published>
<updated>2026-01-07T08:41:01.087Z</updated>
<content type="html"><![CDATA[<p>最近在看打造 Agent 相关的研究,发现 Anthropic 他们的一篇文章写的特别好,有很多工程实践经验值得参考。虽然没有披露更多细节,但是也指出了很多方向。以下基本是原文翻译。</p><span id="more"></span><p>Claude 现在具备了研究功能,可以在网络、Google Workspace 以及任何集成系统中搜索,完成复杂任务。</p><p>这个多智能体系统从原型到生产的历程,让我们学到了关于系统架构、工具设计和提示工程的重要经验。多智能体系统由多个智能体(大语言模型在循环中自主使用工具)协同工作组成。我们的研究功能涉及一个智能体,它根据用户查询规划研究过程,然后使用工具创建并行智能体,同时搜索信息。具有多个智能体的系统在智能体协调、评估和可靠性方面引入了新的挑战。</p><p>本文分解了对我们有效的原则——我们希望您在构建自己的多智能体系统时会发现这些原则有用。</p><h2 id="多智能体系统的好处"><a href="#多智能体系统的好处" class="headerlink" title="多智能体系统的好处"></a>多智能体系统的好处</h2><p>研究工作涉及开放性问题,很难提前预测所需的步骤。你无法为探索复杂主题硬编码固定路径,因为这个过程本质上是动态的和路径依赖的。当人们进行研究时,他们倾向于根据发现持续更新方法,跟进调查过程中出现的线索。</p><p>这种不可预测性使 AI 智能体特别适合研究任务。研究需要灵活性,能够在调查展开时转向或探索切向连接。模型必须自主运行多轮,基于中间发现做出探索方向的决策。线性的一次性流水线无法处理这些任务。</p><p>搜索的本质是压缩:从庞大语料库中提取洞察。子智能体通过在各自的上下文窗口中并行操作来促进压缩,同时探索问题的不同方面,然后为主研究智能体压缩最重要的 token。每个子智能体还提供了关注点分离——不同的工具、提示和探索轨迹——这减少了路径依赖,实现了彻底、独立的调查。</p><p>一旦智能达到阈值,多智能体系统就成为扩展性能的重要方式。例如,虽然个体人类在过去 10 万年中变得更加智能,但在信息时代,人类社会因为我们的<strong>集体智能</strong>和协调能力而变得<strong>指数级</strong>更有能力。即使是通用智能体在单独操作时也面临限制;智能体群体可以完成更多任务。</p><p>我们的内部评估显示,多智能体研究系统在广度优先查询方面表现尤其出色,这些查询涉及同时追求多个独立方向。我们发现,以 Claude Opus 4 为主智能体、Claude Sonnet 4 为子智能体的多智能体系统在我们的内部研究评估中比单智能体 Claude Opus 4 表现好 90.2%。例如,当被要求识别信息技术 S&P 500 中所有公司的董事会成员时,多智能体系统通过将此分解为子智能体任务找到了正确答案,而单智能体系统在缓慢的顺序搜索中失败了。</p><p>多智能体系统主要起作用,因为它们帮助花费足够的 token 来解决问题。在我们的分析中,三个因素解释了 BrowseComp 评估中 95% 的性能方差(该评估测试浏览智能体定位难以找到信息的能力)。我们发现,token 使用本身解释了 80% 的方差,工具调用次数和模型选择是另外两个解释因素。这一发现验证了我们的架构,该架构将工作分布在具有独立上下文窗口的智能体中,为并行推理增加更多容量。最新的 Claude 模型作为 token 使用的大效率倍增器,因为升级到 Claude Sonnet 4 比在 Claude Sonnet 3.7 上将 token 预算翻倍带来更大的性能提升。多智能体架构有效地扩展了超出单智能体限制的任务的 token 使用。</p><p>有一个缺点:在实践中,这些架构快速消耗 token。在我们的数据中,智能体通常使用比聊天交互多约 4 倍的 token,而多智能体系统使用比聊天多约 15 倍的 token。对于经济可行性,多智能体系统需要任务的价值足够高,能够支付增加的性能成本。此外,一些需要所有智能体共享相同上下文或涉及智能体之间许多依赖关系的领域,目前不太适合多智能体系统。例如,大多数编码任务涉及的真正可并行化任务比研究少,而且 LLM 智能体还不擅长实时协调和委托给其他智能体。我们发现多智能体系统在涉及大量并行化、超出单个上下文窗口的信息以及与众多复杂工具接口的有价值任务中表现出色。</p><h2 id="研究功能的架构概述"><a href="#研究功能的架构概述" class="headerlink" title="研究功能的架构概述"></a>研究功能的架构概述</h2><p>我们的研究系统使用编排器-工作者模式的多智能体架构,其中主智能体协调过程,同时委托给并行操作的专门子智能体。</p><p><img src="/images/posts/agent/multi-agent-architecture.webp" alt="多智能体架构"></p><p><em>多智能体架构实际应用:用户查询通过主智能体流动,主智能体创建专门的子智能体来并行搜索不同方面。</em></p><p>当用户提交查询时,主智能体分析查询,制定策略,并生成子智能体来同时探索不同方面。如上图所示,子智能体作为智能过滤器,迭代使用搜索工具收集信息,在本例中是关于 2025 年的 AI 智能体公司,然后向主智能体返回公司列表,以便编制最终答案。</p><p>使用检索增强生成(RAG)的传统方法使用静态检索。也就是说,它们获取与输入查询最相似的一些块集合,并使用这些块生成响应。相比之下,我们的架构使用多步搜索,动态找到相关信息,适应新发现,并分析结果以制定高质量答案。</p><p><img src="/images/posts/agent/multi-agent-process.webp" alt="多智能体处理流程"></p><p><em>流程图显示了我们多智能体研究系统的完整工作流程。当用户提交查询时,系统创建一个 LeadResearcher 智能体,进入迭代研究过程。LeadResearcher 首先思考方法并将其计划保存到内存中以持久化上下文,因为如果上下文窗口超过 200,000 个 token,它将被截断,保留计划很重要。然后它创建具有特定研究任务的专门子智能体(这里显示了两个,但可以是任意数量)。每个子智能体独立执行网络搜索,使用交错思考评估工具结果,并将发现返回给 LeadResearcher。LeadResearcher 综合这些结果并决定是否需要更多研究——如果需要,它可以创建额外的子智能体或完善其策略。一旦收集到足够的信息,系统退出研究循环并将所有发现传递给 CitationAgent,后者处理文档和研究报告以识别引用的特定位置。这确保所有声明都正确归属于其来源。最终的研究结果,连同引用,然后返回给用户。</em></p><h2 id="研究智能体的提示工程和评估"><a href="#研究智能体的提示工程和评估" class="headerlink" title="研究智能体的提示工程和评估"></a>研究智能体的提示工程和评估</h2><p>多智能体系统与单智能体系统有关键差异,包括协调复杂性的快速增长。早期智能体犯了诸如为简单查询生成 50 个子智能体、无休止地搜索不存在的来源、以及过度更新相互干扰等错误。由于每个智能体都由提示引导,提示工程是我们改善这些行为的主要杠杆。以下是我们学到的一些提示智能体的原则:</p><h3 id="1-像智能体一样思考"><a href="#1-像智能体一样思考" class="headerlink" title="1. 像智能体一样思考"></a>1. 像智能体一样思考</h3><p>要迭代提示,你必须理解它们的效果。为了帮助我们做到这一点,我们使用来自系统的确切提示和工具,通过我们的控制台构建了模拟,然后逐步观察智能体工作。这立即揭示了故障模式:智能体在已经有足够结果时继续工作,使用过于冗长的搜索查询,或选择错误的工具。有效的提示依赖于开发智能体的准确心理模型,这可以使最有影响力的变化变得明显。</p><h3 id="2-教编排器如何委托"><a href="#2-教编排器如何委托" class="headerlink" title="2. 教编排器如何委托"></a>2. 教编排器如何委托</h3><p>在我们的系统中,主智能体将查询分解为子任务并向子智能体描述它们。每个子智能体需要一个目标、输出格式、使用工具和来源的指导,以及明确的任务边界。没有详细的任务描述,智能体会重复工作、留下空白或无法找到必要信息。我们开始允许主智能体给出简单、简短的指令,如”研究半导体短缺”,但发现这些指令通常足够模糊,子智能体会误解任务或执行与其他智能体完全相同的搜索。例如,一个子智能体探索 2021 年汽车芯片危机,而另外两个重复工作调查当前 2025 年供应链,没有有效的分工。</p><h3 id="3-根据查询复杂性调整努力"><a href="#3-根据查询复杂性调整努力" class="headerlink" title="3. 根据查询复杂性调整努力"></a>3. 根据查询复杂性调整努力</h3><p>智能体很难判断不同任务的适当努力,所以我们在提示中嵌入了缩放规则。简单的事实查找只需要 1 个智能体进行 3-10 次工具调用,直接比较可能需要 2-4 个子智能体,每个进行 10-15 次调用,复杂研究可能使用超过 10 个子智能体,有明确划分的责任。这些明确的指导方针帮助主智能体有效分配资源,防止在简单查询上过度投资,这是我们早期版本中的常见故障模式。</p><h3 id="4-工具设计和选择至关重要"><a href="#4-工具设计和选择至关重要" class="headerlink" title="4. 工具设计和选择至关重要"></a>4. 工具设计和选择至关重要</h3><p>智能体-工具接口与人机接口一样重要。使用正确的工具是高效的——通常,这是严格必要的。例如,智能体在网络上搜索包含学术术语的内容可能需要 Google Scholar,而不是通用网络搜索。我们设计工具时明确定义了它们的用途、限制和最佳用例。模糊的工具描述导致误用,从而产生无关结果和浪费的努力。我们还发现,为智能体提供太多工具选择可能令人困惑——策划的工具集比大量通用工具更有效。</p><h3 id="5-建立明确的停止条件"><a href="#5-建立明确的停止条件" class="headerlink" title="5. 建立明确的停止条件"></a>5. 建立明确的停止条件</h3><p>智能体可能会陷入无限循环,不断搜索或细化已经足够的结果。我们设置了明确的完成标准:特定数量的来源、满足的关键信息要求,或达到的时间/调用限制。我们还教智能体识别何时找到了”足够好”的答案,而不是完美的答案,这在开放性研究中经常是现实的目标。</p><p>评估多智能体系统需要新方法。传统指标如 BLEU 分数或困惑度无法捕捉复杂的多步骤研究过程的质量。我们开发了测量准确性、完整性、效率和可用性的定制评估。准确性检查智能体是否找到正确信息,完整性评估是否回答了查询的所有部分,效率测量达到结果所需的时间和资源,可用性评估最终输出对用户的有用性。</p><p>我们还发现人工评估仍然至关重要。虽然自动化指标可以大规模捕捉基本质量,但人类判断对于评估细微差别、创造性和整体实用性是不可替代的。我们的提示帮助解决了这个问题。即使在自动化评估的世界中,手动测试仍然必不可少。</p><p>多智能体系统具有涌现行为,这些行为在没有特定编程的情况下出现。例如,对主智能体的小改变可能不可预测地改变子智能体的行为方式。成功需要理解交互模式,而不仅仅是个体智能体行为。因此,这些智能体的最佳提示不仅仅是严格的指令,而是定义分工、问题解决方法和努力预算的协作框架。正确做到这一点依赖于仔细的提示和工具设计、可靠的启发式方法、可观察性和紧密的反馈循环。</p><h2 id="生产可靠性和工程挑战"><a href="#生产可靠性和工程挑战" class="headerlink" title="生产可靠性和工程挑战"></a>生产可靠性和工程挑战</h2><p>在传统软件中,错误可能破坏功能、降低性能或导致中断。在智能体系统中,微小的变化会级联成大的行为变化,这使得为必须在长时间运行过程中维护状态的复杂智能体编写代码变得非常困难。</p><h3 id="智能体有状态且错误会复合"><a href="#智能体有状态且错误会复合" class="headerlink" title="智能体有状态且错误会复合"></a>智能体有状态且错误会复合</h3><p>智能体可以长时间运行,在许多工具调用中维护状态。这意味着我们需要持久执行代码并处理沿途的错误。没有有效的缓解措施,轻微的系统故障对智能体来说可能是灾难性的。当错误发生时,我们不能只是从头开始重新启动:重新启动对用户来说是昂贵且令人沮丧的。相反,我们构建了可以从智能体发生错误时的位置恢复的系统。我们还使用模型的智能来优雅地处理问题:例如,让智能体知道工具何时失败并让它适应,效果出奇地好。我们将建立在 Claude 基础上的 AI 智能体的适应性与重试逻辑和定期检查点等确定性保障措施相结合。</p><h3 id="调试受益于新方法"><a href="#调试受益于新方法" class="headerlink" title="调试受益于新方法"></a>调试受益于新方法</h3><p>智能体做出动态决策,在运行之间是非确定性的,即使使用相同的提示。这使调试变得更困难。例如,用户会报告智能体”没有找到明显信息”,但我们看不出原因。智能体是使用糟糕的搜索查询吗?选择差的来源吗?遇到工具故障吗?添加完整的生产跟踪让我们诊断智能体失败的原因并系统性地修复问题。除了标准可观察性,我们监控智能体决策模式和交互结构——所有这些都不监控个别对话的内容,以维护用户隐私。这种高级可观察性帮助我们诊断根本原因,发现意外行为,并修复常见故障。</p><h3 id="部署需要仔细协调"><a href="#部署需要仔细协调" class="headerlink" title="部署需要仔细协调"></a>部署需要仔细协调</h3><p>智能体系统是几乎连续运行的提示、工具和执行逻辑的高度有状态网络。这意味着每当我们部署更新时,智能体可能处于其过程中的任何位置。因此,我们需要防止我们善意的代码更改破坏现有智能体。我们不能同时将每个智能体更新到新版本。相反,我们使用彩虹部署来避免干扰运行中的智能体,通过逐渐将流量从旧版本转移到新版本,同时保持两者同时运行。</p><h3 id="同步执行创建瓶颈"><a href="#同步执行创建瓶颈" class="headerlink" title="同步执行创建瓶颈"></a>同步执行创建瓶颈</h3><p>目前,我们的主智能体同步执行子智能体,等待每组子智能体完成后再继续。这简化了协调,但在智能体之间的信息流中创建了瓶颈。例如,主智能体无法引导子智能体,子智能体无法协调,整个系统可能因等待单个子智能体完成搜索而被阻塞。异步执行将实现额外的并行性:智能体并发工作并在需要时创建新的子智能体。但这种异步性在结果协调、状态一致性和跨子智能体错误传播方面增加了挑战。随着模型能够处理更长更复杂的研究任务,我们预期性能增益将证明复杂性是合理的。</p><h2 id="结论"><a href="#结论" class="headerlink" title="结论"></a>结论</h2><p>在构建 AI 智能体时,最后一英里通常成为大部分旅程。在开发机器上工作的代码库需要重要的工程才能成为可靠的生产系统。智能体系统中错误的复合性质意味着传统软件的小问题可能完全使智能体脱轨。一步失败可能导致智能体探索完全不同的轨迹,导致不可预测的结果。由于本文中描述的所有原因,原型和生产之间的差距通常比预期的更宽。</p><p>尽管存在这些挑战,多智能体系统已被证明对开放性研究任务有价值。用户表示 Claude 帮助他们找到了没有考虑过的商业机会,导航复杂的医疗保健选项,解决棘手的技术错误,并通过发现他们单独不会找到的研究连接节省了数天的工作。多智能体研究系统可以通过仔细的工程、全面的测试、细致的提示和工具设计、健壮的操作实践,以及对当前智能体能力有深刻理解的研究、产品和工程团队之间的紧密协作,在规模上可靠运行。我们已经看到这些系统正在改变人们解决复杂问题的方式。</p><p><em>人们今天使用研究功能的最常见方式。主要用例类别是:跨专业领域开发软件系统(10%),开发和优化专业和技术内容(8%),制定业务增长和收入生成策略(8%),协助学术研究和教育材料开发(7%),以及研究和验证关于人员、地点或组织的信息(5%)。</em></p><h2 id="附录"><a href="#附录" class="headerlink" title="附录"></a>附录</h2><p>以下是多智能体系统的一些额外杂项提示。</p><h3 id="改变状态的多轮智能体的终态评估"><a href="#改变状态的多轮智能体的终态评估" class="headerlink" title="改变状态的多轮智能体的终态评估"></a>改变状态的多轮智能体的终态评估</h3><p>评估在多轮对话中修改持久状态的智能体面临独特挑战。与只读研究任务不同,每个动作都可能改变后续步骤的环境,创建传统评估方法难以处理的依赖关系。我们发现专注于终态评估而不是逐轮分析是成功的。不是判断智能体是否遵循了特定过程,而是评估它是否达到了正确的最终状态。这种方法承认智能体可能找到实现相同目标的替代路径,同时仍确保它们提供预期结果。对于复杂工作流程,将评估分解为应该发生特定状态变化的离散检查点,而不是试图验证每个中间步骤。</p><h3 id="长期对话管理"><a href="#长期对话管理" class="headerlink" title="长期对话管理"></a>长期对话管理</h3><p>生产智能体经常参与跨越数百轮的对话,需要仔细的上下文管理策略。随着对话延长,标准上下文窗口变得不足,需要智能的压缩和内存机制。我们实施了智能体总结已完成工作阶段并在继续新任务之前将重要信息存储在外部内存中的模式。当上下文限制接近时,智能体可以通过仔细的交接生成具有干净上下文的新子智能体,同时保持连续性。此外,它们可以从内存中检索存储的上下文,如研究计划,而不是在达到上下文限制时丢失先前的工作。这种分布式方法防止上下文溢出,同时在扩展交互中保持对话连贯性。</p><h3 id="子智能体输出到文件系统以最小化”传话游戏”"><a href="#子智能体输出到文件系统以最小化”传话游戏”" class="headerlink" title="子智能体输出到文件系统以最小化”传话游戏”"></a>子智能体输出到文件系统以最小化”传话游戏”</h3><p>直接子智能体输出可以绕过主协调器处理某些类型的结果,提高保真度和性能。与其要求子智能体通过主智能体沟通一切,不如实施工件系统,专门智能体可以创建独立持久的输出。子智能体调用工具将其工作存储在外部系统中,然后将轻量级引用传回协调器。这防止在多阶段处理期间的信息丢失,并减少通过对话历史复制大输出的 token 开销。该模式特别适用于结构化输出,如代码、报告或数据可视化,其中子智能体的专门提示比通过通用协调器过滤产生更好的结果。</p><p><em>参考文献:</em></p><ul><li><a href="https://www.anthropic.com/engineering/multi-agent-research-system">原文链接</a></li></ul>]]></content>
<summary type="html">
<p>最近在看打造 Agent 相关的研究,发现 Anthropic 他们的一篇文章写的特别好,有很多工程实践经验值得参考。虽然没有披露更多细节,但是也指出了很多方向。以下基本是原文翻译。</p>
</summary>
<category term="Agent" scheme="https://murphypei.github.io/categories/Agent/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="AI" scheme="https://murphypei.github.io/tags/AI/"/>
<category term="Agent" scheme="https://murphypei.github.io/tags/Agent/"/>
<category term="Claude" scheme="https://murphypei.github.io/tags/Claude/"/>
<category term="多智能体" scheme="https://murphypei.github.io/tags/%E5%A4%9A%E6%99%BA%E8%83%BD%E4%BD%93/"/>
<category term="工程实践" scheme="https://murphypei.github.io/tags/%E5%B7%A5%E7%A8%8B%E5%AE%9E%E8%B7%B5/"/>
</entry>
<entry>
<title>RAG 中的检索核心问题</title>
<link href="https://murphypei.github.io/blog/2025/08/rag-retrieval-quality.html"/>
<id>https://murphypei.github.io/blog/2025/08/rag-retrieval-quality.html</id>
<published>2025-08-29T05:35:25.000Z</published>
<updated>2026-01-07T08:41:01.087Z</updated>
<content type="html"><![CDATA[<p>RAG(检索增强生成)系统的核心在于能否准确、高效地检索到与用户查询最相关的文档片段。检索质量的好坏直接决定了最终生成结果的准确性和可靠性。本文将深入探讨 RAG 系统中检索优化的关键策略和最佳实践。</p><span id="more"></span><h2 id="引言"><a href="#引言" class="headerlink" title="引言"></a>引言</h2><p>在当今的 AI 应用中,RAG 已经成为解决大语言模型知识局限性的重要技术方案。然而,构建一个高质量的 RAG 系统远不止是简单地将文档向量化后进行相似度搜索。检索环节的优化往往决定了整个系统的成败。</p><p>要保证 RAG 应用能准确检索到所需的文档,我们需要同时关注<strong>召回率(Recall)</strong>和<strong>准确率(Precision)</strong>:</p><ul><li><strong>召回率</strong>:所有相关文档中,被系统检索出来的比例</li><li><strong>准确率</strong>:被系统检索出来的文档中,真正相关的比例</li></ul><p>在 RAG 中,这两者往往是此消彼长的关系,需要找到一个最优的平衡点。</p><h2 id="检索质量评估框架"><a href="#检索质量评估框架" class="headerlink" title="检索质量评估框架"></a>检索质量评估框架</h2><h3 id="核心指标"><a href="#核心指标" class="headerlink" title="核心指标"></a>核心指标</h3><p>除了召回率和准确率,我们还需要关注以下指标:</p><ul><li><strong>NDCG(Normalized Discounted Cumulative Gain)</strong>:考虑排序位置的相关性指标</li><li><strong>MRR(Mean Reciprocal Rank)</strong>:平均倒数排名,衡量第一个相关结果的位置</li><li><strong>Hit Rate@K</strong>:前 K 个结果中包含相关文档的查询比例</li></ul><h3 id="评估数据集构建"><a href="#评估数据集构建" class="headerlink" title="评估数据集构建"></a>评估数据集构建</h3><p>建立高质量的评估数据集是优化的前提。例如,对于”如何优化深度学习模型的训练速度?”这样的查询,正相关文档应该是讨论 GPU 并行、批处理优化的内容,部分相关的可能涉及模型压缩、量化,而纯理论介绍或其他领域的文档则属于不相关。</p><h2 id="提升召回率的策略"><a href="#提升召回率的策略" class="headerlink" title="提升召回率的策略"></a>提升召回率的策略</h2><p>召回率关注的是”不漏”,即尽可能地找到所有相关的文档。</p><h3 id="1-优化分块(Chunking)策略"><a href="#1-优化分块(Chunking)策略" class="headerlink" title="1. 优化分块(Chunking)策略"></a>1. 优化分块(Chunking)策略</h3><p>分块是 RAG 的基础,而且非常重要。直接影响检索效果:</p><h4 id="分块大小优化"><a href="#分块大小优化" class="headerlink" title="分块大小优化"></a>分块大小优化</h4><ul><li><strong>固定大小分块</strong>:通常 512-1024 tokens 为宜。既不会太短,也不会太长。分块太小可能丢失上下文,分块太大则可能引入无关信息。</li><li><strong>基于语义的动态分块</strong>:尽量按照段落、章节等自然语义边界进行分块,而不是简单的固定字符数。例如,将完整的问答对或表格作为独立的块。</li><li><strong>重叠策略</strong>: 在相邻分块之间设置一定的重叠部分,有助于保留跨块的上下文信息,比如相邻块间保持 20-50 tokens 重叠。</li></ul><h4 id="语义边界保持"><a href="#语义边界保持" class="headerlink" title="语义边界保持"></a>语义边界保持</h4><p>实际分块操作可能是多种分块策略结合使用。比如先基于段落进行动态分块,然后在段落内部,基于固定大一或者句子来分块。基于语义的分块策略应该考虑句子边界和语义相似度,通过句子分割、动态调整块大小、保持语义完整性等方式,确保每个分块都包含完整的语义信息。</p><h4 id="结构化信息处理"><a href="#结构化信息处理" class="headerlink" title="结构化信息处理"></a>结构化信息处理</h4><ul><li><strong>表格数据</strong>:保持表格完整性,添加表头上下文。</li><li><strong>代码片段</strong>:包含函数/类的完整定义。</li><li><strong>列表项</strong>:保持列表的逻辑完整性。</li></ul><h3 id="2-查询增强(Query-Expansion)"><a href="#2-查询增强(Query-Expansion)" class="headerlink" title="2. 查询增强(Query Expansion)"></a>2. 查询增强(Query Expansion)</h3><p>通过扩展查询来增加召回相关文档的机会:</p><h4 id="多角度查询生成"><a href="#多角度查询生成" class="headerlink" title="多角度查询生成"></a>多角度查询生成</h4><p>使用 LLM 生成多个查询变体,可以通过以下方式实现:使用同义词替换、改变句式结构、添加相关术语、简化表达、详细描述等方法,为原始查询生成多个语义相似但表达方式不同的变体。</p><h4 id="层次化查询策略"><a href="#层次化查询策略" class="headerlink" title="层次化查询策略"></a>层次化查询策略</h4><ul><li><strong>粗粒度检索</strong>:使用概括性词汇扩大检索范围</li><li><strong>细粒度检索</strong>:使用具体术语提高精确度</li><li><strong>多层融合</strong>:综合不同粒度的检索结果</li></ul><h3 id="3-混合检索(Hybrid-Search)"><a href="#3-混合检索(Hybrid-Search)" class="headerlink" title="3. 混合检索(Hybrid Search)"></a>3. 混合检索(Hybrid Search)</h3><p>结合多种检索方法发挥各自优势:</p><h4 id="稀疏检索-密集检索"><a href="#稀疏检索-密集检索" class="headerlink" title="稀疏检索 + 密集检索"></a>稀疏检索 + 密集检索</h4><p>混合检索通过结合 BM25 关键词检索和向量语义检索来实现。首先分别进行两种检索,然后对检索分数进行归一化处理,最后通过加权融合(如向量检索权重 70%,BM25 权重 30%)得到最终排序结果。</p><h4 id="多模态检索"><a href="#多模态检索" class="headerlink" title="多模态检索"></a>多模态检索</h4><ul><li><strong>文本+图像</strong>:同时索引文档中的文字和图表信息</li><li><strong>结构化+非结构化</strong>:结合表格数据和自然语言描述</li><li><strong>元数据增强</strong>:利用时间、作者、类别等元信息</li></ul><h2 id="提升准确率的策略"><a href="#提升准确率的策略" class="headerlink" title="提升准确率的策略"></a>提升准确率的策略</h2><p>准确率关注的是”不瞎”,即召回的文档都是真正需要的。</p><h3 id="1-嵌入模型优化"><a href="#1-嵌入模型优化" class="headerlink" title="1. 嵌入模型优化"></a>1. 嵌入模型优化</h3><p>嵌入模型质量直接决定向量检索的准确性:</p><h4 id="领域适配"><a href="#领域适配" class="headerlink" title="领域适配"></a>领域适配</h4><ul><li><strong>预训练模型选择</strong>:如科学文献使用 SciBERT,法律文档使用 LegalBERT</li><li><strong>微调策略</strong>:在特定领域数据上进行对比学习微调</li><li><strong>多语言支持</strong>:针对中英文混合文档的特殊处理</li></ul><h4 id="嵌入维度优化"><a href="#嵌入维度优化" class="headerlink" title="嵌入维度优化"></a>嵌入维度优化</h4><p>嵌入维度的选择需要在性能和质量之间找平衡:</p><ul><li>128 维:速度快、准确率低、内存占用小</li><li>384 维:速度中等、准确率中等、内存占用中等</li><li>768 维:速度慢、准确率高、内存占用大</li><li>1024 维:速度很慢、准确率很高、内存占用很大</li></ul><h3 id="2-重排序(Reranking)"><a href="#2-重排序(Reranking)" class="headerlink" title="2. 重排序(Reranking)"></a>2. 重排序(Reranking)</h3><p>对初始检索结果进行精细化排序:</p><h4 id="Cross-Encoder-重排序"><a href="#Cross-Encoder-重排序" class="headerlink" title="Cross-Encoder 重排序"></a>Cross-Encoder 重排序</h4><p>使用重排序模型对候选文档进行精细排序的过程包括:构建查询-文档对、使用重排序模型计算相关性分数、根据分数进行排序,最终返回排序后的前 K 个文档。</p><h4 id="多阶段重排序"><a href="#多阶段重排序" class="headerlink" title="多阶段重排序"></a>多阶段重排序</h4><ol><li><strong>粗排</strong>:使用轻量级模型快速筛选Top-100</li><li><strong>精排</strong>:使用复杂模型对Top-20进行精确排序</li><li><strong>多样性调整</strong>:避免结果过于集中在相似文档</li></ol><h3 id="3-智能过滤策略"><a href="#3-智能过滤策略" class="headerlink" title="3. 智能过滤策略"></a>3. 智能过滤策略</h3><h4 id="预过滤机制"><a href="#预过滤机制" class="headerlink" title="预过滤机制"></a>预过滤机制</h4><ul><li><strong>元数据过滤</strong>:根据文档类型、时间范围、权威性筛选</li><li><strong>关键词门槛</strong>:确保文档包含查询的核心术语</li><li><strong>质量评分</strong>:基于文档完整性、可读性的预评分</li></ul><h4 id="后过滤优化"><a href="#后过滤优化" class="headerlink" title="后过滤优化"></a>后过滤优化</h4><p>检索后的文档过滤包括多个层面:相关性阈值过滤(如最低相关性 0.3)、重复内容检测(如相似度阈值 0.8)、内容质量检查等,通过这些过滤机制确保最终返回的文档都符合质量要求。</p><h2 id="高级优化技术"><a href="#高级优化技术" class="headerlink" title="高级优化技术"></a>高级优化技术</h2><h3 id="1-自适应检索策略"><a href="#1-自适应检索策略" class="headerlink" title="1. 自适应检索策略"></a>1. 自适应检索策略</h3><p>根据查询特征动态调整检索策略。首先分析查询的复杂度特征,然后根据不同类型采用相应策略:简单查询重点使用关键词匹配,概念性查询重点进行语义理解,复杂查询则采用混合策略。</p><h3 id="2-查询意图理解"><a href="#2-查询意图理解" class="headerlink" title="2. 查询意图理解"></a>2. 查询意图理解</h3><h4 id="意图分类"><a href="#意图分类" class="headerlink" title="意图分类"></a>意图分类</h4><ul><li><strong>事实查询</strong>:寻找具体信息(who, what, when)</li><li><strong>程序查询</strong>:寻找操作步骤(how to)</li><li><strong>比较查询</strong>:对比不同选项(difference, comparison)</li><li><strong>分析查询</strong>:深入理解(why, analysis)</li></ul><h4 id="针对性优化"><a href="#针对性优化" class="headerlink" title="针对性优化"></a>针对性优化</h4><p>根据不同查询意图采用相应的优化策略:</p><ul><li><strong>事实查询</strong>:embedding 权重 30%,关键词权重 70%,使用事实聚焦的重排序模型</li><li><strong>程序查询</strong>:embedding 权重 60%,关键词权重 40%,使用步骤感知的重排序模型 </li><li><strong>分析查询</strong>:embedding 权重 80%,关键词权重 20%,使用上下文感知的重排序模型</li></ul><h3 id="3-动态索引优化"><a href="#3-动态索引优化" class="headerlink" title="3. 动态索引优化"></a>3. 动态索引优化</h3><h4 id="增量更新策略"><a href="#增量更新策略" class="headerlink" title="增量更新策略"></a>增量更新策略</h4><ul><li><strong>热点文档</strong>:高频访问文档的索引优化</li><li><strong>时效性文档</strong>:新增文档的快速索引</li><li><strong>过期清理</strong>:定期清理不再相关的文档</li></ul><h4 id="索引压缩技术"><a href="#索引压缩技术" class="headerlink" title="索引压缩技术"></a>索引压缩技术</h4><ul><li><strong>向量量化</strong>:使用PQ(Product Quantization)压缩</li><li><strong>稀疏化</strong>:去除低权重的向量维度</li><li><strong>分层索引</strong>:构建粗粒度到细粒度的多层索引</li></ul><h2 id="实践中的平衡之道"><a href="#实践中的平衡之道" class="headerlink" title="实践中的平衡之道"></a>实践中的平衡之道</h2><h3 id="召回优先策略"><a href="#召回优先策略" class="headerlink" title="召回优先策略"></a>召回优先策略</h3><p>在实际应用中,通常采用”召回优先,精度优化”的两阶段策略:</p><ol><li><p><strong>广泛召回阶段</strong></p><ul><li>使用宽松的相似度阈值</li><li>应用多种查询扩展技术</li><li>结合多种检索方法</li></ul></li><li><p><strong>精度优化阶段</strong></p><ul><li>应用重排序模型</li><li>执行多层过滤</li><li>进行结果去重和多样性优化</li></ul></li></ol><h3 id="性能与质量权衡"><a href="#性能与质量权衡" class="headerlink" title="性能与质量权衡"></a>性能与质量权衡</h3><p>不同应用场景需要不同的配置策略:</p><ul><li><strong>实时问答</strong>:延迟预算 <200ms,采用轻量级 embedding+简单重排序,牺牲部分准确率换取响应速度</li><li><strong>深度分析</strong>:延迟预算 <5s,采用高质量 embedding+复杂重排序,容忍较高延迟获得最佳质量</li><li><strong>批量处理</strong>:无延迟限制,采用多模型 ensemble+全面后处理,追求最高质量</li></ul><h3 id="持续优化机制"><a href="#持续优化机制" class="headerlink" title="持续优化机制"></a>持续优化机制</h3><h4 id="A-B-测试框架"><a href="#A-B-测试框架" class="headerlink" title="A/B 测试框架"></a>A/B 测试框架</h4><ul><li><strong>检索策略对比</strong>:不同算法的效果验证</li><li><strong>参数调优</strong>:阈值、权重等超参数优化</li><li><strong>用户体验监控</strong>:基于真实反馈的持续改进</li></ul><h4 id="监控指标体系"><a href="#监控指标体系" class="headerlink" title="监控指标体系"></a>监控指标体系</h4><p>建立全面的监控指标体系:</p><ul><li><strong>检索质量</strong>:NDCG@10、MRR、Hit Rate@5</li><li><strong>系统性能</strong>:平均延迟、P99 延迟、QPS</li><li><strong>用户满意度</strong>:点击率、停留时间、反馈评分</li><li><strong>业务指标</strong>:任务完成率、准确答案比例、用户留存</li></ul><h2 id="未来发展趋势"><a href="#未来发展趋势" class="headerlink" title="未来发展趋势"></a>未来发展趋势</h2><h3 id="1-多模态检索融合"><a href="#1-多模态检索融合" class="headerlink" title="1. 多模态检索融合"></a>1. 多模态检索融合</h3><p>随着多模态大模型的发展,RAG 系统将更好地处理文本、图像、音视频等多种模态的信息检索和融合。</p><h3 id="2-个性化检索优化"><a href="#2-个性化检索优化" class="headerlink" title="2. 个性化检索优化"></a>2. 个性化检索优化</h3><p>基于用户历史行为和偏好,构建个性化的检索模型,提供更精准的个人知识服务。</p><h3 id="3-实时学习能力"><a href="#3-实时学习能力" class="headerlink" title="3. 实时学习能力"></a>3. 实时学习能力</h3><p>检索系统将具备从用户反馈中实时学习的能力,持续优化检索质量。</p><h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>RAG 系统的检索优化是一个系统性工程,需要从分块策略、查询处理、检索算法、重排序等多个维度进行综合优化。关键在于:</p><ol><li><strong>建立完善的评估体系</strong>:确保优化方向正确</li><li><strong>平衡召回率与准确率</strong>:根据应用场景找到最优平衡点</li><li><strong>采用分层优化策略</strong>:粗排+精排的两阶段设计</li><li><strong>持续监控和迭代</strong>:基于真实数据不断优化</li></ol><p>只有通过系统性的优化和持续的迭代改进,才能构建出既能全面检索,又能精准定位的高质量 RAG 应用,为用户提供准确、及时、有价值的信息服务。</p>]]></content>
<summary type="html">
<p>RAG(检索增强生成)系统的核心在于能否准确、高效地检索到与用户查询最相关的文档片段。检索质量的好坏直接决定了最终生成结果的准确性和可靠性。本文将深入探讨 RAG 系统中检索优化的关键策略和最佳实践。</p>
</summary>
<category term="RAG" scheme="https://murphypei.github.io/categories/RAG/"/>
<category term="chunk" scheme="https://murphypei.github.io/tags/chunk/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="RAG" scheme="https://murphypei.github.io/tags/RAG/"/>
<category term="retrieval" scheme="https://murphypei.github.io/tags/retrieval/"/>
<category term="index" scheme="https://murphypei.github.io/tags/index/"/>
<category term="检索优化" scheme="https://murphypei.github.io/tags/%E6%A3%80%E7%B4%A2%E4%BC%98%E5%8C%96/"/>
<category term="质量评估" scheme="https://murphypei.github.io/tags/%E8%B4%A8%E9%87%8F%E8%AF%84%E4%BC%B0/"/>
</entry>
<entry>
<title>RAG 分块策略</title>
<link href="https://murphypei.github.io/blog/2025/08/rag-chunking.html"/>
<id>https://murphypei.github.io/blog/2025/08/rag-chunking.html</id>
<published>2025-08-28T06:35:25.000Z</published>
<updated>2026-01-07T08:41:01.087Z</updated>
<content type="html"><![CDATA[<p>在构建一个高效的检索增强生成(RAG)系统时,一个常常被忽视但至关重要的环节是<strong>分块(Chunking)</strong>。这个过程是将大型文档切分成小块,以便 LLM 可以轻松地检索和理解。如果分块策略不当,即使拥有最先进的语言模型和向量数据库,你的 RAG 系统也可能表现不佳。</p><span id="more"></span><p>本文将深入探讨为什么分块如此重要,以及如何通过几种核心策略来优化它。</p><h2 id="什么是词块"><a href="#什么是词块" class="headerlink" title="什么是词块"></a>什么是词块</h2><p>将词块视为 RAG 系统中知识的原子单元 。每个词块应该:</p><ul><li><strong>语义完整</strong> :包含单独来看有意义的连贯信息。</li><li>语境丰富 :包含足够的周围环境,无需外部参考即可理解。</li><li>最佳尺寸 :足够大以具有意义,足够小以便精确检索。</li><li>边界感知 :尊重句子、段落和章节等自然语言结构。</li></ul><p>这些要求经常相互冲突。语义完整的块可能太大,无法进行精确检索。大小合适的块可能会在边界处丢失关键上下文。这时,不同的分块策略就派上用场了。每种方法都有各自的优缺点,最佳选择取决于您的具体用例、文档类型和检索需求。</p><h2 id="分块策略"><a href="#分块策略" class="headerlink" title="分块策略"></a>分块策略</h2><p>下面由浅入深,介绍不同的分块策略:</p><h3 id="固定大小分块(简单基线)"><a href="#固定大小分块(简单基线)" class="headerlink" title="固定大小分块(简单基线)"></a>固定大小分块(简单基线)</h3><p>固定大小分块将文档拆分为预定大小的块,通常以以下方式测量:</p><ul><li>字符数 (例如,每块 1000 个字符)</li><li>令牌数量 (例如,每块 256 个令牌)</li><li>字数统计 (例如,每段 200 个字)</li></ul><p>重叠参数在这里至关重要。如果没有重叠,您可能会丢失跨越块边界的信息。重叠率为 20% 时,可以确保相邻块之间有一定的连续性。</p><p>固定大小分块策略几乎不适用于任何场景,仅在有限计算资源的情况。固定大小分块适用的场景:</p><ul><li>有统一、简单的文档(博客、文章、小说)</li><li>处理速度至关重要</li><li>存储和计算资源严重有限</li></ul><h3 id="语义分块"><a href="#语义分块" class="headerlink" title="语义分块"></a>语义分块</h3><p>语义分块代表着思维的根本性转变。语义分块不再根据大小限制任意切割文本,而是根据语义和文档结构识别自然断点。语义分块需要更多的计算资源,但能带来更好的结果。</p><p>核心见解:<strong>文档具有固有的语义边界,可以指导分块决策</strong>。</p><p>语义分块的工作原理:</p><ol><li>句子分割 :将文档分解成单个句子</li><li>嵌入生成 :为每个句子创建向量表示</li><li>相似度分析 :测量<strong>相邻句子</strong>之间的语义相似度</li><li>边界检测 :在相似度低于阈值的地方创建块边界</li><li>组块 :将连续相似的句子组合成连贯的组块</li></ol><p>高级语义技术:</p><ul><li>主题建模集成 :使用基于 LDA 或 BERT 的主题模型来识别主题边界</li><li>层次聚类 :在分块之前按语义相似度对句子进行分组</li><li>实体连续性 :确保命名实体及其引用保持在同一块内</li><li>话语标记 :使用语言线索(但是、此外、总之)来识别界限</li></ul><p>语义分块适用于很多场景:</p><ul><li>文档质量差异很大</li><li>上下文保存很重要</li><li>准确性比速度更重要</li></ul><h3 id="分层文档分块"><a href="#分层文档分块" class="headerlink" title="分层文档分块"></a>分层文档分块</h3><p>对于具有嵌套结构的复杂文档(例如技术手册、法律合同或学术论文),分层分块可以保留文档的组织逻辑。这种方法认识到文档不是平面文本,而是具有章节、小节和逻辑层次的结构化知识。</p><p>分层分块的主要优点:</p><ul><li>上下文继承 :每个块都知道它在文档层次结构中的位置</li><li>元数据保存 :维护章节标题、级别和关系</li><li>结构感知检索 :您可以按不同的粒度进行检索(部分与小节)</li><li>交叉引用处理 :对其他部分的引用仍然有意义</li></ul><p>对于 200 页的技术手册,分层分块可能会创建:</p><ol><li>15 个章节级块(高级概述)</li><li>127 个部分级块(详细解释)</li><li>340 个子节级块(具体程序)</li></ol><p>在检索过程中,系统可以将查询与适当的细节级别进行匹配,并提供分层上下文。</p><p>如果出现以下情况,则实施分层分块:</p><ul><li>文件结构清晰(手册、报告、学术论文)</li><li>用户需要不同程度的细节</li><li>交叉引用很常见</li><li>文档格式结构良好</li></ul><h3 id="滑动窗口分块(重叠优化器)"><a href="#滑动窗口分块(重叠优化器)" class="headerlink" title="滑动窗口分块(重叠优化器)"></a>滑动窗口分块(重叠优化器)</h3><p>传统的组块方法会创建离散的、不重叠的组块。但是,如果最重要的信息恰好落在组块边界上,该怎么办呢?滑动窗口分块通过创建重叠块解决了这个问题,确保没有信息被遗漏。</p><p>当 window_size=1000、stride=800 时,可获得 200 个字符的重叠(20%)。此重叠可确保:</p><ol><li>关键句子没有被拆分成多个部分</li><li>上下文在相邻块之间自然流动</li><li>检索有多次机会找到相关信息</li></ol><h3 id="自适应滑动窗口"><a href="#自适应滑动窗口" class="headerlink" title="自适应滑动窗口"></a>自适应滑动窗口</h3><p>高级实现根据内容密度调整窗口大小和步幅:</p><ol><li>常见问题解答文档 :问题和答案受益于重叠的上下文</li><li>法律文件 :精确的语言界限至关重要</li><li>技术程序 :步骤参考先前的操作</li><li>叙事内容 :故事连续性至关重要</li></ol><h3 id="大模型分块"><a href="#大模型分块" class="headerlink" title="大模型分块"></a>大模型分块</h3><p>使用大语言模型进行递归分块: 最新的进展是使用 LLM 本身进行分块决策。</p><h3 id="多模态分块"><a href="#多模态分块" class="headerlink" title="多模态分块"></a>多模态分块</h3><p>对于包含图像、表格和混合内容的文档。</p><h3 id="基于查询模式的动态分块"><a href="#基于查询模式的动态分块" class="headerlink" title="基于查询模式的动态分块"></a>基于查询模式的动态分块</h3><p>最先进的系统根据文档的实际查询方式来调整分块。</p>]]></content>
<summary type="html">
<p>在构建一个高效的检索增强生成(RAG)系统时,一个常常被忽视但至关重要的环节是<strong>分块(Chunking)</strong>。这个过程是将大型文档切分成小块,以便 LLM 可以轻松地检索和理解。如果分块策略不当,即使拥有最先进的语言模型和向量数据库,你的 RAG 系统也可能表现不佳。</p>
</summary>
<category term="RAG" scheme="https://murphypei.github.io/categories/RAG/"/>
<category term="chunk" scheme="https://murphypei.github.io/tags/chunk/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="RAG" scheme="https://murphypei.github.io/tags/RAG/"/>
<category term="retrieval" scheme="https://murphypei.github.io/tags/retrieval/"/>
<category term="index" scheme="https://murphypei.github.io/tags/index/"/>
<category term="分块" scheme="https://murphypei.github.io/tags/%E5%88%86%E5%9D%97/"/>
<category term="检索" scheme="https://murphypei.github.io/tags/%E6%A3%80%E7%B4%A2/"/>
</entry>
<entry>
<title>Agent 几个问题思考</title>
<link href="https://murphypei.github.io/blog/2025/07/agent-fc-mcp.html"/>
<id>https://murphypei.github.io/blog/2025/07/agent-fc-mcp.html</id>
<published>2025-07-28T19:35:25.000Z</published>
<updated>2026-01-07T08:41:01.087Z</updated>
<content type="html"><![CDATA[<p>最近遇到了几个大模型模型算法应用的关键问题,作为记录。</p><span id="more"></span><h3 id="Agent-设计范式"><a href="#Agent-设计范式" class="headerlink" title="Agent 设计范式"></a>Agent 设计范式</h3><p>目前主流的 AI Agent(智能体)设计模式,通常不是单一的,而是基于一些核心思想和框架的组合。这些设计模式旨在赋予大语言模型<strong>自主思考、规划、行动和反思</strong>的能力,以完成更复杂的任务。</p><h4 id="1-ReAct-Reasoning-and-Acting"><a href="#1-ReAct-Reasoning-and-Acting" class="headerlink" title="1. ReAct (Reasoning and Acting)"></a>1. ReAct (Reasoning and Acting)</h4><p>这是目前最流行、最基础也是最核心的一种 Agent 设计模式。它的核心思想是让大语言模型在一个<strong>“思考-行动-观察”</strong>的循环中持续工作,直到任务完成。</p><ul><li><strong>思考(Thought)</strong>:LLM 会根据当前任务和过往观察,产生一段内部思考,比如“我需要找到XX信息,我应该使用什么工具?”</li><li><strong>行动(Action)</strong>:LLM 根据思考,生成一个调用外部工具的指令(即 Function Call)。</li><li><strong>观察(Observation)</strong>:LLM 接收到外部工具返回的结果。</li></ul><p>这个循环会反复进行,直到 LLM 判断任务已完成并生成最终回复。这种模式将”思考”过程显性化,使得模型的决策过程更加透明和可控。</p><h4 id="2-Plan-and-Execute(规划与执行)"><a href="#2-Plan-and-Execute(规划与执行)" class="headerlink" title="2. Plan-and-Execute(规划与执行)"></a>2. Plan-and-Execute(规划与执行)</h4><p>这种模式更侧重于对复杂任务的<strong>结构化分解</strong>。它将一个大任务分为两个阶段:</p><ul><li><strong>规划阶段(Planning)</strong>:LLM 首先会分析用户的请求,并生成一个详细的、分步骤的行动计划。这个计划是固定的,不会在执行过程中轻易改变。</li><li><strong>执行阶段(Execution)</strong>:LLM 严格按照规划好的步骤,一步一步地调用工具和执行任务。</li></ul><p>这种模式的优点是任务流程清晰、稳定,适合需要按部就班完成的复杂任务。但缺点是缺乏灵活性,如果计划中的某个步骤失败,Agent 可能无法自主调整。</p><h4 id="3-Self-Correction(自我纠正-反思)"><a href="#3-Self-Correction(自我纠正-反思)" class="headerlink" title="3. Self-Correction(自我纠正/反思)"></a>3. Self-Correction(自我纠正/反思)</h4><p>这种设计模式的核心是让 Agent 具备<strong>复盘和纠错</strong>的能力。它通常与其他模式结合使用,为 Agent 增加一个”反思”步骤。</p><ul><li><strong>反思(Reflection)</strong>:Agent 在完成一个任务或得到一个结果后,会重新评估这个结果是否正确或达到预期。</li><li><strong>纠正(Correction)</strong>:如果评估结果不理想,Agent 会根据反思结果,修改其原始的思考路径或行动计划,然后再次尝试,直到达到满意的结果。</li></ul><p>这种模式能显著提升 Agent 在复杂问题上的表现,因为它允许 Agent 从错误中学习,避免重复犯错。</p><p>这三种模式并不是相互独立的。一个强大的 Agent 通常会结合使用这些思想:例如,一个 Agent 可能先用 <strong>Plan-and-Execute</strong> 进行任务分解,然后在每个执行步骤中,使用 <strong>ReAct</strong> 循环来调用工具,并最终用 <strong>Self-Correction</strong> 机制来验证和修正结果。</p><h3 id="大模型如何”用”工具?Agent、Function-Call-与-MCP-的进化之路"><a href="#大模型如何”用”工具?Agent、Function-Call-与-MCP-的进化之路" class="headerlink" title="大模型如何”用”工具?Agent、Function Call 与 MCP 的进化之路"></a>大模型如何”用”工具?Agent、Function Call 与 MCP 的进化之路</h3><p>在构建基于大语言模型(LLM)的应用时,一个核心挑战是让 LLM 不只停留在”聊天”,而是真正具备”行动”能力,比如联网搜索、调用 API 或执行代码。这就是 <strong>Agent</strong> 的核心思想:让 LLM 像一个智能体一样,能够根据用户的指令,自主决定是直接回答,还是调用外部工具来获取信息或完成任务。</p><h4 id="1-Agent-的决策过程:LLM-如何知道何时调用工具?"><a href="#1-Agent-的决策过程:LLM-如何知道何时调用工具?" class="headerlink" title="1. Agent 的决策过程:LLM 如何知道何时调用工具?"></a>1. Agent 的决策过程:LLM 如何知道何时调用工具?</h4><p>Agent 的决策机制本质上是 <strong>Prompt Engineering</strong> 的一种高级应用。开发者会设计一个精巧的 <strong>Prompt</strong>,将可用的工具列表、它们的用途和描述一并告诉 LLM。</p><p>例如,当我们问一个 Agent:“今天北京的天气怎么样?”它的思考过程可能如下:</p><ol><li><strong>用户意图分析:</strong> 用户想知道北京的天气。</li><li><strong>工具匹配:</strong> 我有一个可以查询天气的工具(<code>get_weather(location)</code>)。</li><li><strong>决策与执行:</strong> 我需要调用这个工具,并把“北京”作为参数。</li></ol><p>这个思考过程并不神秘,而是通过精心设计的 <strong>Prompt</strong> 来引导 LLM 生成。LLM 会根据输入的指令和工具描述,在输出中”思考”并生成一个结构化的行动指令,然后由外部程序(Agent 框架)去实际执行。</p><h4 id="2-Function-Call-FC-:将决策能力内置到模型中"><a href="#2-Function-Call-FC-:将决策能力内置到模型中" class="headerlink" title="2. Function Call (FC):将决策能力内置到模型中"></a>2. Function Call (FC):将决策能力内置到模型中</h4><p><strong>Function Call (FC)</strong> 是对上述 Agent 决策机制的一种原生优化。它将“思考”和“生成调用指令”的能力直接通过模型训练内置进去。</p><p><strong>FC 的核心是:</strong> 模型能够根据上下文,直接以预先定义好的 <strong>JSON 格式</strong> 生成对外部工具的调用,而不是像传统 Agent 那样需要外部框架去解析 LLM 生成的文本。</p><p>这是一种巨大的进步,因为它使得工具调用更加稳定、高效,并减少了外部解析的复杂性。</p><p><strong>那么,为什么说 FC 存在“MxN”的问题?</strong></p><p>这里有一个常见的误解:很多人以为每增加一个工具,就需要重新训练模型。<strong>这是不正确的。</strong></p><p>FC 的”MxN”问题不在于模型本身,而在于 <strong>工具的描述格式</strong>。每个拥有 FC 能力的 LLM 平台(如 OpenAI, Google, Anthropic)都有自己独特的工具描述 Schema(函数签名)。一个搜索工具,为了能被不同的模型调用,开发者需要为它编写 <strong>M</strong> 份不同的描述文档。同样,一个 Agent 开发者如果想使用 <strong>N</strong> 个工具,并支持 <strong>M</strong> 个不同的 LLM,就需要处理 MxN 个兼容性问题。</p><p>简而言之,<strong>FC 解决了“让模型知道如何调用工具”的问题,但没有解决“工具描述格式不统一”的问题。</strong></p><h4 id="3-MCP:通过抽象层实现真正的解耦"><a href="#3-MCP:通过抽象层实现真正的解耦" class="headerlink" title="3. MCP:通过抽象层实现真正的解耦"></a>3. MCP:通过抽象层实现真正的解耦</h4><p><strong>Multi-tool Coordinator Protocol (MCP)</strong> 正是为了解决 FC 带来的兼容性与扩展性问题而诞生的。</p><p><strong>MCP 的核心思想是:在模型和工具之间增加一个抽象的中间层。</strong></p><p>这个中间层通常由一个 <strong>MCP Server</strong> 组成,其工作流程如下:</p><ol><li><strong>工具注册:</strong> 所有外部工具都按照一套统一的 <strong>MCP 协议</strong> 接入并注册到 MCP Server。</li><li><strong>Agent (LLM) 接入:</strong> 所有的 Agent 只需要学习并支持这套统一的 <strong>MCP 协议</strong>。它们不再需要知道每个工具的具体描述细节,只需向 MCP Server 发送一个统一的请求,来获取可用的工具列表或执行工具调用。</li></ol><p><strong>为什么说 MCP 实现了“M+N”?</strong></p><ul><li><strong>M</strong> 个不同的模型(或 Agent)只需要对接 <strong>1</strong> 个统一的 MCP Server。</li><li><strong>N</strong> 个不同的工具也只需要对接 <strong>1</strong> 个统一的 MCP Server。</li></ul><p>通过这种方式,模型与工具之间不再是复杂的点对点连接,而是通过一个中心化的枢纽进行通信。这不仅解决了 FC 在不同平台间的兼容性问题,更是一种架构上的巨大优化,它让工具的管理、维护和扩展变得前所未有的简单。</p><p><strong>总结来说:</strong></p><ul><li><strong>Agent</strong> 是让 LLM 具备行动能力的 <strong>思想</strong>。</li><li><strong>Function Call</strong> 是将这种思想 <strong>内置到模型中</strong> 的一种能力。</li><li><strong>MCP</strong> 是在此基础上,进一步将 <strong>模型与工具解耦</strong> 的一种 <strong>架构设计</strong>。</li></ul><p>当然,MCP 还需要在调用端封装。在典型的 <strong>MCP (Multi-tool Coordinator Protocol)</strong> 实现中,会区分 <strong>AI Host (AI 宿主)</strong> 和 <strong>MCP Client (MCP 客户端)</strong> 这两个角色,它们通过一种清晰的协作模式来共同完成任务。</p><ol><li><p><strong>AI Host (AI 宿主)</strong></p><ul><li><strong>角色:</strong> AI Host 是整个系统的核心,通常是运行大语言模型(LLM)或 Agent 框架的那部分。它负责接收用户的指令,并进行高级别的决策与推理。</li><li><strong>职责:</strong><ul><li>接收用户输入。</li><li>与 LLM 交互,进行意图分析。</li><li>根据 LLM 的输出,决定是直接生成回复还是需要调用工具。</li><li><strong>注意:</strong> AI Host 不直接与具体的工具交互,它只知道 MCP 协议。</li></ul></li></ul></li><li><p><strong>MCP Client (MCP 客户端)</strong></p><ul><li><strong>角色:</strong> MCP Client 是一个独立的模块或库,作为 AI Host 与 MCP Server 之间的桥梁。它封装了所有与 MCP Server 通信的细节。</li><li><strong>职责:</strong><ul><li>将来自 AI Host 的工具调用请求,按照 <strong>MCP 协议</strong> 进行格式化,并发送给 MCP Server。</li><li>接收 MCP Server 返回的结果,并将其转换回 AI Host 可理解的格式。</li><li>管理与 MCP Server 的连接和会话。</li><li><strong>注意:</strong> MCP Client 也不关心具体的工具是如何实现的,它只负责协议层面的通信。</li></ul></li></ul></li></ol><p>整个 MCP 协作过程可以分解为以下几个步骤:</p><ol><li><strong>用户请求:</strong> 用户向 AI Host 发出指令,例如:“帮我查一下旧金山的天气。”</li><li><strong>AI Host 意图分析:</strong> AI Host 将用户指令发送给内置的 LLM。LLM 基于其 Function Call 或 ReAct 能力,判断出需要调用一个天气查询工具。它会生成一个结构化的调用请求,比如 <code>{"tool_name": "weather_api", "parameters": {"city": "旧金山"}}</code>。</li><li><strong>AI Host 委托:</strong> AI Host 拿到 LLM 生成的调用请求后,并不会自己去执行,而是将其 <strong>委托</strong> 给 <strong>MCP Client</strong>。</li><li><strong>MCP Client 协议转换与发送:</strong> MCP Client 接收到这个请求,将其封装成符合 MCP 协议的格式(例如一个特定的 HTTP POST 请求或 gRPC 调用),然后发送给远端的 <strong>MCP Server</strong>。</li><li><strong>MCP Server 工具协调与执行:</strong> MCP Server 接收到请求后,根据 <code>tool_name</code> 找到对应的工具,将 <code>parameters</code> 传递给该工具并执行。</li><li><strong>结果返回:</strong> 工具执行完毕,结果返回给 MCP Server,MCP Server 再将结果通过 MCP 协议返回给 MCP Client。</li><li><strong>MCP Client 结果转换与回传:</strong> MCP Client 接收到 MCP Server 的响应,将其解析并转换为 AI Host 可理解的格式,然后返回给 AI Host。</li><li><strong>AI Host 回复生成:</strong> AI Host 拿到工具执行结果(例如:”旧金山今天多云,气温 15 摄氏度”),将其作为上下文的一部分再次输入给 LLM,最终由 LLM 生成完整的自然语言回复给用户。</li></ol><p>通过这种方式,AI Host 始终保持”干净”,只关心高级别的推理和决策,而具体的工具调用和协议通信的复杂性则完全由 MCP Client 和 MCP Server 这层抽象来处理,实现了 AI 能力与工具生态的彻底解耦。</p><p>最后,通过一个图说明 FC 和 MCP 工作模式:</p><p><img src="/images/posts/agent/fc_mcp.gif" alt="FC 和 MCP 工作模式"></p><h3 id="JSON-格式化输出"><a href="#JSON-格式化输出" class="headerlink" title="JSON 格式化输出"></a>JSON 格式化输出</h3><p>在 <strong>Function Call (FC)</strong> 模式下,要保证模型稳定输出 JSON 格式,主要依赖于 <strong>模型本身的训练和微调</strong>。</p><h4 id="如何保证模型稳定输出-JSON?"><a href="#如何保证模型稳定输出-JSON?" class="headerlink" title="如何保证模型稳定输出 JSON?"></a>如何保证模型稳定输出 JSON?</h4><p>这并非单纯通过 Prompt Engineering 就能完美解决的问题。核心在于:</p><ol><li><strong>大规模训练:</strong> 在模型的预训练和指令微调阶段,会使用大量的包含 JSON 格式的结构化数据作为训练样本。这些样本告诉模型,当它接收到特定类型的指令(例如,要求它调用某个工具)时,应该输出一个遵循特定 JSON Schema 的结果。</li><li><strong>特殊的解码策略:</strong> 一些模型在解码时会采用特殊的约束,比如 <strong>JSON Schema 约束解码</strong>。这意味着,模型在生成每一个 token 时,都会检查其是否符合预先定义的 JSON 格式规则。如果生成的 token 会导致 JSON 格式无效,模型会将其“回溯”并尝试生成另一个 token,直到生成完整且正确的 JSON。这种方法极大地提高了输出的稳定性和正确性。</li><li><strong>Prompt 工程辅助:</strong> 尽管核心能力来自模型本身,但好的 Prompt 仍然至关重要。例如,在 Prompt 中清晰地描述工具的函数签名,并明确要求模型“请以 JSON 格式输出调用结果”,可以进一步引导模型输出期望的格式。</li></ol><p>比如这个 FC 输出的格式:</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">{</span><br><span class="line"> "name": "get_current_weather",</span><br><span class="line"> "arguments": {</span><br><span class="line"> "location": "旧金山",</span><br><span class="line"> "unit": "celsius"</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>其中的每个 key 和 value 都要非常精准的匹配到调用的工具才可以。</p><h4 id="MCP-是否也要求模型输出固定格式?"><a href="#MCP-是否也要求模型输出固定格式?" class="headerlink" title="MCP 是否也要求模型输出固定格式?"></a>MCP 是否也要求模型输出固定格式?</h4><p><strong>是的,但要求的是更简单、更统一的格式。</strong></p><p><strong>MCP</strong> 的核心理念是 <strong>解耦</strong>。它将工具的复杂性和多样性从模型端剥离,因此模型不需要了解每个工具具体的 JSON Schema。模型唯一需要知道的,是与 <strong>MCP Client</strong> 交互的 <strong>统一协议</strong>。</p><p>这个统一协议的格式通常非常简单,比如一个包含 <code>tool_name</code> 和 <code>parameters</code> 的 JSON 对象。模型需要做的,只是稳定输出这个简单的、所有工具都通用的 JSON 格式,然后由 <strong>MCP Client</strong> 去处理后续的协议转换和工具调用。</p><p>这正是 <strong>MCP</strong> 的优势所在:它降低了对模型的要求。模型不需要针对每一个新工具去学习其独特的调用格式,它只需要掌握一套通用的、简单的输出格式即可。这使得模型可以更专注于其核心的自然语言理解和意图识别,而将复杂的<br>工具协调任务交给 <strong>MCP Client</strong> 和 <strong>MCP Server</strong> 去完成。</p><p>下面是一个模型输出的 MCP 格式的调用:</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">{</span><br><span class="line"> "tool_name": "weather_api",</span><br><span class="line"> "parameters": {</span><br><span class="line"> "city": "旧金山",</span><br><span class="line"> "temp_unit": "celsius"</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>这个看起来比较类似,但实际上模型并不知道 weather_api 到底是什么,也不知道对于一个 city,应该传入的参数是什么(city?location?),对于温度单位,也类似。</p><p>FC (Function Call) 模式下,模型的难度体现在 “量” 上。<strong>模型需要记忆并理解每一个工具独特的 JSON Schema,而且要非常精确</strong>。如果你的系统有 100 个不同的工具,每个工具的参数都不一样,那模型就需要稳定地输出 100 种不同结构的 JSON。这就像让一个人记住 100 个完全不同的表格格式,并根据指令填写。</p><p>而 MCP (Multi-tool Coordinator Protocol) 模式下,模型的难度只体现在 “质” 上。它只需要学会一种 统一且简单的 JSON 格式,即 {“tool_name”: “…”, “parameters”: {…}}。无论有多少个工具,这个输出格式始终不变。这就像让一个人永远只填写一种固定格式的表格,然后把表格交给一个“总机”去处理后续的细节。</p><p>MCP Client 拿到的模型输出,也就是通用且简单的 JSON 格式(例如:{“tool_name”: “…”, “parameters”: {…}}),负责将其解析、封装、调用 Server、接收响应、解析响应,回传给 AI Host。</p>]]></content>
<summary type="html">
<p>最近遇到了几个大模型模型算法应用的关键问题,作为记录。</p>
</summary>
<category term="Agent" scheme="https://murphypei.github.io/categories/Agent/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="agent" scheme="https://murphypei.github.io/tags/agent/"/>
<category term="rag" scheme="https://murphypei.github.io/tags/rag/"/>
<category term="mcp" scheme="https://murphypei.github.io/tags/mcp/"/>
</entry>
<entry>
<title>LLM 训练:ZeRO 技术详解</title>
<link href="https://murphypei.github.io/blog/2025/07/llm-zero.html"/>
<id>https://murphypei.github.io/blog/2025/07/llm-zero.html</id>
<published>2025-07-22T22:08:12.000Z</published>
<updated>2026-01-07T08:41:01.086Z</updated>
<content type="html"><![CDATA[<p>在大语言模型(LLM)训练中,显存不足是一个普遍存在的问题。随着模型规模的不断增长,单个 GPU 的显存容量成为了训练大规模模型的主要瓶颈。DeepSpeed ZeRO(Zero Redundancy Optimizer)技术通过创新的数据分片策略,有效解决了这一问题,使得我们能够训练远超单卡显存上限的超大规模模型。</p><span id="more"></span><h2 id="引言"><a href="#引言" class="headerlink" title="引言"></a>引言</h2><p>随着大语言模型规模的快速增长,显存需求呈指数级增长。传统的分布式训练方法虽然能够利用多 GPU 进行训练,但每个 GPU 仍然需要存储完整的模型参数、梯度和优化器状态,这严重限制了可训练的模型规模。</p><p>ZeRO 技术通过<strong>数据并行</strong>与<strong>内存优化</strong>的结合,将模型训练中的大块数据(优化器状态、梯度和模型参数)分散到不同的 GPU 上,而非在每个 GPU 上都完整存储一份,从而显著降低了每个 GPU 的显存需求。</p><h2 id="ZeRO-技术原理"><a href="#ZeRO-技术原理" class="headerlink" title="ZeRO 技术原理"></a>ZeRO 技术原理</h2><h3 id="传统数据并行的问题"><a href="#传统数据并行的问题" class="headerlink" title="传统数据并行的问题"></a>传统数据并行的问题</h3><p>在传统的数据并行训练中,每个 GPU 都需要存储:</p><ol><li><strong>模型参数</strong>:完整的模型权重</li><li><strong>梯度</strong>:完整的梯度信息</li><li><strong>优化器状态</strong>:如 Adam 优化器的动量、方差等状态</li></ol><p>对于大规模模型,这些数据占用的显存非常庞大。例如,一个 175B 参数的模型使用 Adam 优化器时,仅优化器状态就需要约 700GB 显存(每个参数需要 4 个 float32 值)。</p><h3 id="ZeRO-的核心思想"><a href="#ZeRO-的核心思想" class="headerlink" title="ZeRO 的核心思想"></a>ZeRO 的核心思想</h3><p>ZeRO 的核心思想是<strong>消除冗余存储</strong>,通过分片技术将原本每个 GPU 都需要存储的完整数据分散到多个 GPU 上,实现显存的线性扩展。</p><p><strong>关键洞察</strong>:</p><ul><li>在数据并行中,不同 GPU 上的模型参数是相同的</li><li>梯度在反向传播后需要进行 All-Reduce 操作</li><li>优化器状态与参数一一对应</li></ul><p>基于这些观察,ZeRO 提出了分阶段的内存优化策略。</p><h2 id="ZeRO-的三个阶段"><a href="#ZeRO-的三个阶段" class="headerlink" title="ZeRO 的三个阶段"></a>ZeRO 的三个阶段</h2><h3 id="ZeRO-Stage-1:优化器状态分片(Optimizer-State-Sharding)"><a href="#ZeRO-Stage-1:优化器状态分片(Optimizer-State-Sharding)" class="headerlink" title="ZeRO-Stage 1:优化器状态分片(Optimizer State Sharding)"></a>ZeRO-Stage 1:优化器状态分片(Optimizer State Sharding)</h3><p><strong>原理</strong>:<br>将优化器状态分片存储在不同的 GPU 上,每个 GPU 只存储部分优化器状态。</p><p><strong>具体做法</strong>:</p><ul><li>假设有 $N$ 个 GPU,模型参数为 $P$</li><li>将优化器状态分成 $N$ 个分片,每个 GPU 存储 $P/N$ 个参数对应的优化器状态</li><li>在参数更新时,每个 GPU 只更新自己负责的那部分参数</li></ul><p><strong>内存节省</strong>:</p><ul><li>优化器状态内存减少 $N$ 倍</li><li>对于 Adam 优化器,每个参数需要 4 个 float32 值,节省效果显著</li></ul><h3 id="ZeRO-Stage-2:梯度分片(Gradient-Sharding)"><a href="#ZeRO-Stage-2:梯度分片(Gradient-Sharding)" class="headerlink" title="ZeRO-Stage 2:梯度分片(Gradient Sharding)"></a>ZeRO-Stage 2:梯度分片(Gradient Sharding)</h3><p><strong>原理</strong>:<br>在 Stage 1 的基础上,进一步将梯度分片存储。</p><p><strong>具体做法</strong>:</p><ul><li>每个 GPU 只计算和存储部分梯度</li><li>在反向传播结束时,通过 All-Reduce 操作收集完整的梯度</li><li>然后每个 GPU 只更新自己负责的参数部分</li></ul><p><strong>内存节省</strong>:</p><ul><li>梯度内存减少 $N$ 倍</li><li>与 Stage 1 结合,总内存节省更加显著</li></ul><h3 id="ZeRO-Stage-3:参数分片(Parameter-Sharding)"><a href="#ZeRO-Stage-3:参数分片(Parameter-Sharding)" class="headerlink" title="ZeRO-Stage 3:参数分片(Parameter Sharding)"></a>ZeRO-Stage 3:参数分片(Parameter Sharding)</h3><p><strong>原理</strong>:<br>在 Stage 1 和 Stage 2 的基础上,进一步将模型参数分片存储。</p><p><strong>具体做法</strong>:</p><ul><li>模型参数也被分片存储在不同的 GPU 上</li><li>在训练过程中,当需要某个层的所有参数时,通过 All-Gather 操作将所需参数动态地收集到当前 GPU</li><li>这意味着在任何给定时间点,每个 GPU 上只完整存在模型参数的一部分</li></ul><p><strong>内存节省</strong>:</p><ul><li>模型参数内存减少 $N$ 倍</li><li>实现了最大程度的内存优化</li></ul><h2 id="ZeRO-的具体实现"><a href="#ZeRO-的具体实现" class="headerlink" title="ZeRO 的具体实现"></a>ZeRO 的具体实现</h2><h3 id="通信模式"><a href="#通信模式" class="headerlink" title="通信模式"></a>通信模式</h3><p>ZeRO 使用了两种主要的通信模式:</p><ol><li><p><strong>All-Gather</strong>:用于参数收集</p><ul><li>当需要某个层的完整参数时,从所有 GPU 收集该层的参数分片</li><li>通信开销:$O(P)$,其中 $P$ 是参数数量</li></ul></li><li><p><strong>All-Reduce</strong>:用于梯度聚合</p><ul><li>在反向传播后,聚合所有 GPU 上的梯度分片</li><li>通信开销:$O(P)$</li></ul></li></ol><h3 id="内存管理策略"><a href="#内存管理策略" class="headerlink" title="内存管理策略"></a>内存管理策略</h3><p><strong>按需加载机制</strong>:</p><ul><li>参数只在需要时才加载到 GPU 显存</li><li>使用完毕后立即释放,避免长期占用显存</li></ul><p><strong>分片存储策略</strong>:</p><ul><li>优化器状态:静态分片,训练过程中保持不变</li><li>梯度:动态分片,每次反向传播后重新分配</li><li>参数:动态分片,根据计算需求动态加载</li></ul><h3 id="计算流程"><a href="#计算流程" class="headerlink" title="计算流程"></a>计算流程</h3><p><strong>前向传播</strong>:</p><ol><li>通过 All-Gather 收集当前层需要的参数</li><li>执行前向计算</li><li>释放不需要的参数</li></ol><p><strong>反向传播</strong>:</p><ol><li>通过 All-Gather 收集当前层需要的参数</li><li>计算梯度</li><li>将梯度分片存储</li><li>释放参数</li></ol><p><strong>参数更新</strong>:</p><ol><li>通过 All-Reduce 聚合所有梯度分片</li><li>每个 GPU 更新自己负责的参数部分</li><li>更新对应的优化器状态</li></ol><h2 id="ZeRO-的变体技术"><a href="#ZeRO-的变体技术" class="headerlink" title="ZeRO 的变体技术"></a>ZeRO 的变体技术</h2><h3 id="ZeRO-Offload"><a href="#ZeRO-Offload" class="headerlink" title="ZeRO-Offload"></a>ZeRO-Offload</h3><p><strong>原理</strong>:<br>对于模型训练中一些对性能不那么敏感,但内存占用大的部分(如优化器状态、甚至梯度),将其从 GPU 显存转移到 CPU 内存或硬盘(NVMe SSD)。</p><p><strong>具体做法</strong>:</p><ul><li>优化器状态存储在 CPU 内存中</li><li>梯度可以存储在 CPU 内存或 NVMe SSD 中</li><li>在需要时通过 PCIe 总线传输数据</li></ul><p><strong>优势</strong>:</p><ul><li>进一步减少 GPU 显存需求</li><li>能够训练更大的模型</li><li>成本相对较低</li></ul><p><strong>劣势</strong>:</p><ul><li>增加了 CPU-GPU 数据传输开销</li><li>训练速度可能有所下降</li></ul><h3 id="ZeRO-FSDP(Fully-Sharded-Data-Parallelism)"><a href="#ZeRO-FSDP(Fully-Sharded-Data-Parallelism)" class="headerlink" title="ZeRO-FSDP(Fully Sharded Data Parallelism)"></a>ZeRO-FSDP(Fully Sharded Data Parallelism)</h3><p><strong>原理</strong>:<br>ZeRO-FSDP 是 ZeRO-Stage 3 的完整实现,实现了优化器状态、梯度和模型参数的全面分片。</p><p><strong>特点</strong>:</p><ul><li>最大程度的内存优化</li><li>支持任意大小的模型训练</li><li>通信开销相对较高</li></ul>]]></content>
<summary type="html">
<p>在大语言模型(LLM)训练中,显存不足是一个普遍存在的问题。随着模型规模的不断增长,单个 GPU 的显存容量成为了训练大规模模型的主要瓶颈。DeepSpeed ZeRO(Zero Redundancy Optimizer)技术通过创新的数据分片策略,有效解决了这一问题,使得我们能够训练远超单卡显存上限的超大规模模型。</p>
</summary>
<category term="LLM" scheme="https://murphypei.github.io/categories/LLM/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="training" scheme="https://murphypei.github.io/tags/training/"/>
<category term="zero" scheme="https://murphypei.github.io/tags/zero/"/>
<category term="deepspeed" scheme="https://murphypei.github.io/tags/deepspeed/"/>
<category term="memory" scheme="https://murphypei.github.io/tags/memory/"/>
</entry>
<entry>
<title>LLM 训练:GRPO 算法详解</title>
<link href="https://murphypei.github.io/blog/2025/07/llm-grpo.html"/>
<id>https://murphypei.github.io/blog/2025/07/llm-grpo.html</id>
<published>2025-07-22T17:43:43.000Z</published>
<updated>2026-01-13T03:49:48.689Z</updated>
<content type="html"><![CDATA[<p>在<a href="https://murphypei.github.io/blog/2025/05/llm-ppo.html">之前的 PPO 细节探讨</a>中,我们详细介绍了 PPO 算法和 GAE 优势估计。今天我们来深入探讨 GRPO(Group Relative Policy Optimization)算法,这是 PPO 在大语言模型训练中的一个重要改进版本。</p><span id="more"></span><h2 id="GRPO-算法概述"><a href="#GRPO-算法概述" class="headerlink" title="GRPO 算法概述"></a>GRPO 算法概述</h2><p>在 LLM 的 RLHF 训练中,PPO 虽然表现良好,但仍存在一些可以改进的地方:</p><ol><li><strong>GAE 的复杂性</strong>:PPO 使用 GAE 计算每个 token 的优势,涉及多个超参数($\lambda$、$\gamma$)的调优</li><li><strong>价值函数的需求</strong>:PPO 需要训练一个 Critic 模型来估计状态价值,增加计算和内存成本</li><li><strong>训练不稳定</strong>:在某些场景下,GAE 估计的方差仍然较大,导致训练波动</li></ol><p>GRPO 通过一个更直接的方法来解决这些问题:<strong>将序列级的奖励直接分配到 token 级别,然后通过组内相对比较来计算优势</strong>。这种方法在实践中被证明更加稳定高效。</p><p><strong>GRPO 的核心创新</strong>:</p><ol><li><strong>简化优势计算</strong>:不使用 GAE,而是通过组内相对奖励直接计算优势</li><li><strong>减少计算负担</strong>:虽然仍需要参考模型,但优势计算逻辑更简洁</li><li><strong>提高训练稳定性</strong>:通过序列奖励的分解和组内比较,减少方差</li><li><strong>Token 级别的精细控制</strong>:在每个 token 级别进行重要性采样和裁剪</li></ol><h3 id="GRPO-的动机"><a href="#GRPO-的动机" class="headerlink" title="GRPO 的动机"></a>GRPO 的动机</h3><p>GRPO 相比 PPO 的改进方向:</p><div class="table-container"><table><thead><tr><th>问题</th><th>PPO 解法</th><th>GRPO 改进</th></tr></thead><tbody><tr><td><strong>优势计算</strong></td><td>需要 Critic 估计价值,使用 GAE</td><td>不需要 Critic,直接使用组内相对比较</td></tr><tr><td><strong>方差</strong></td><td>GAE 通过多步累积,仍有较大方差</td><td>通过组内比较,方差更小</td></tr><tr><td><strong>超参数</strong></td><td>$\lambda, \gamma$ 需要调优</td><td>更少的超参数</td></tr><tr><td><strong>计算成本</strong></td><td>需要训练两个模型</td><td>只需要一个模型+参考模型</td></tr></tbody></table></div><h3 id="GRPO-的三个核心模型"><a href="#GRPO-的三个核心模型" class="headerlink" title="GRPO 的三个核心模型"></a>GRPO 的三个核心模型</h3><p>与 PPO 不同,GRPO 只需要三个模型:</p><div class="table-container"><table><thead><tr><th>模型</th><th>作用</th><th>参数更新</th><th>备注</th></tr></thead><tbody><tr><td><strong>Actor</strong></td><td>生成回答的策略模型</td><td>✓ 需要</td><td>最终要部署的模型</td></tr><tr><td><strong>Reference</strong></td><td>参考模型,计算 KL 约束</td><td>✗ 不需要</td><td>通常是 SFT 模型的初始状态</td></tr><tr><td><strong>Reward Model</strong></td><td>评分函数,评估整体质量</td><td>✗ 不需要</td><td>在 RLHF 第一阶段提前训练好</td></tr></tbody></table></div><p><strong>关键区别</strong>:GRPO 相比 PPO <strong>不需要 Critic 模型</strong>,这大大减少了训练的计算和内存成本。</p><hr><h2 id="GRPO-的算法原理"><a href="#GRPO-的算法原理" class="headerlink" title="GRPO 的算法原理"></a>GRPO 的算法原理</h2><h3 id="1-序列奖励的分解"><a href="#1-序列奖励的分解" class="headerlink" title="1. 序列奖励的分解"></a>1. 序列奖励的分解</h3><p>GRPO 的核心创新是如何将<strong>序列级的奖励</strong>分配到<strong>token 级别</strong>。</p><h4 id="奖励分配机制"><a href="#奖励分配机制" class="headerlink" title="奖励分配机制"></a>奖励分配机制</h4><p>对于长度为 $T$ 的生成序列,Reward Model 给出一个标量奖励 $r_{\text{RM}}$。GRPO 将这个奖励分配到每个 token:</p><script type="math/tex; mode=display">r_t = \begin{cases}0, & t = 0, 1, \ldots, T-1 \\r_{\text{RM}}, & t = T\end{cases}</script><p>即:<strong>只有最后一个 token(序列终止位置)获得奖励,中间所有 token 的奖励为 0</strong>。</p><p>这与 PPO 中的 KL 惩罚方式不同(PPO 中每个 token 都有 <script type="math/tex">-\beta \log \frac{\pi_\theta}{\pi_{\text{ref}}}</script> 的惩罚)。</p><h4 id="为什么这样分配"><a href="#为什么这样分配" class="headerlink" title="为什么这样分配"></a>为什么这样分配</h4><ul><li><strong>简洁性</strong>:不需要为每个 token 计算即时奖励和 KL 惩罚的组合</li><li><strong>合理性</strong>:反映了现实情况——我们只在完整回答后评分</li><li><strong>直观性</strong>:整个序列的质量用一个数字衡量,然后均等分配(或按距离衰减)</li></ul><h3 id="2-Token-级别的重要性采样与裁剪"><a href="#2-Token-级别的重要性采样与裁剪" class="headerlink" title="2. Token 级别的重要性采样与裁剪"></a>2. Token 级别的重要性采样与裁剪</h3><p>这是 GRPO 与 PPO 最重要的区别之一。</p><h4 id="PPO-的做法(token-级别)"><a href="#PPO-的做法(token-级别)" class="headerlink" title="PPO 的做法(token 级别)"></a>PPO 的做法(token 级别)</h4><p>PPO 对每个 token 计算重要性采样比率:</p><script type="math/tex; mode=display">r_t^{\text{PPO}}(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}</script><p>然后进行裁剪和加权:</p><script type="math/tex; mode=display">L^{\text{CLIP}} = \mathbb{E}_t \left[\min(r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t)\right]</script><p><strong>特点</strong>:重要性采样是 <strong>token 级别</strong>的,优势 $A_t$ 也是 <strong>token 级别</strong>的。</p><h4 id="GRPO-的做法(仍是-token-级别,但优势不同)"><a href="#GRPO-的做法(仍是-token-级别,但优势不同)" class="headerlink" title="GRPO 的做法(仍是 token 级别,但优势不同)"></a>GRPO 的做法(仍是 token 级别,但优势不同)</h4><p>GRPO 同样对每个 token 计算重要性采样比率:</p><script type="math/tex; mode=display">r_t^{\text{GRPO}}(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{ref}}(a_t|s_t)}</script><p>然后进行裁剪和加权:</p><script type="math/tex; mode=display">L_{\text{seq}} = \frac{1}{T} \sum_{t=1}^{T} \min(r_t(\theta) \tilde{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \tilde{A}_t)</script><p>其中 $\tilde{A}_t$ 是 token 级别的归一化优势(见下一节)。</p><p><strong>重要</strong>:GRPO 仍然是 <strong>token 级别</strong>的重要性采样和裁剪,但优势的计算方式不同。</p><h3 id="3-组内相对优势计算"><a href="#3-组内相对优势计算" class="headerlink" title="3. 组内相对优势计算"></a>3. 组内相对优势计算</h3><p>这是 GRPO 最核心的创新。</p><h4 id="序列优势(第一步)"><a href="#序列优势(第一步)" class="headerlink" title="序列优势(第一步)"></a>序列优势(第一步)</h4><p>对于生成的每个序列 $i$,根据其奖励计算序列级优势:</p><script type="math/tex; mode=display">A_i^{\text{seq}} = r_i - \bar{r}</script><p>其中:</p><ul><li>$r_i$:第 $i$ 个序列的 RM 奖励</li><li>$\bar{r} = \frac{1}{G} \sum_{j=1}^{G} r_j$:组内(所有 $G$ 个序列)的平均奖励</li></ul><p><strong>含义</strong>:这个序列相对于组内平均水平的相对优势。</p><h4 id="Token-级优势(第二步)"><a href="#Token-级优势(第二步)" class="headerlink" title="Token 级优势(第二步)"></a>Token 级优势(第二步)</h4><p>GRPO 将序列优势分配给序列中的每个 token。最简单的做法是<strong>均等分配</strong>:</p><script type="math/tex; mode=display">\tilde{A}_t^{(i)} = \frac{A_i^{\text{seq}}}{T_i}</script><p>其中 $T_i$ 是第 $i$ 个序列的长度。</p><p><strong>也可以使用折扣分配</strong>(越靠近末端权重越高):</p><script type="math/tex; mode=display">\tilde{A}_t^{(i)} = \frac{\gamma^{T_i-t} A_i^{\text{seq}}}{T_i}</script><p>其中 $\gamma \in [0,1]$ 是折扣因子(通常 0.99)。折扣分配的好处是给予序列末端(生成过程的最后)更高的权重,这更符合直觉:序列的最后部分对最终质量影响更大。</p><h4 id="优势归一化"><a href="#优势归一化" class="headerlink" title="优势归一化"></a>优势归一化</h4><p>为了稳定训练,对组内的优势进行归一化:</p><script type="math/tex; mode=display">\hat{A}_t^{(i)} = \frac{\tilde{A}_t^{(i)} - \mu_A}{\sigma_A + \epsilon}</script><p>其中:</p><ul><li>$\mu_A = \frac{1}{G \cdot T} \sum_i \sum_t \tilde{A}_t^{(i)}$:所有 token 优势的平均值</li><li>$\sigma_A$:所有 token 优势的标准差</li><li>$\epsilon$:小常数防止除零</li></ul><h3 id="4-完整的损失函数"><a href="#4-完整的损失函数" class="headerlink" title="4. 完整的损失函数"></a>4. 完整的损失函数</h3><p>GRPO 的损失函数包含两部分:</p><script type="math/tex; mode=display">L_{\text{GRPO}} = L_{\text{Policy}} + \beta \cdot L_{\text{KL}}</script><h4 id="策略损失"><a href="#策略损失" class="headerlink" title="策略损失"></a>策略损失</h4><script type="math/tex; mode=display">L_{\text{Policy}} = -\mathbb{E}_{i,t} \left[\log \pi_\theta(a_t|s_t) \hat{A}_t^{(i)}\right]</script><p>或者使用 PPO 风格的裁剪:</p><script type="math/tex; mode=display">L_{\text{Policy}}^{\text{CLIP}} = \mathbb{E}_{i,t} \left[\min(r_t(\theta) \hat{A}_t^{(i)}, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t^{(i)})\right]</script><h4 id="KL-约束损失"><a href="#KL-约束损失" class="headerlink" title="KL 约束损失"></a>KL 约束损失</h4><script type="math/tex; mode=display">L_{\text{KL}} = \mathbb{E}_{i,t} \left[D_{\text{KL}}(\pi_\theta(\cdot|s_t) \| \pi_{\text{ref}}(\cdot|s_t))\right]</script><p><strong>注意</strong>:GRPO 仍然保留 KL 约束,防止策略偏离参考模型太远。</p><hr><h2 id="GRPO-与-PPO-的详细对比"><a href="#GRPO-与-PPO-的详细对比" class="headerlink" title="GRPO 与 PPO 的详细对比"></a>GRPO 与 PPO 的详细对比</h2><h3 id="模型架构对比"><a href="#模型架构对比" class="headerlink" title="模型架构对比"></a>模型架构对比</h3><div class="table-container"><table><thead><tr><th>维度</th><th>PPO</th><th>GRPO</th></tr></thead><tbody><tr><td><strong>Actor</strong></td><td>✓ 需要训练</td><td>✓ 需要训练</td></tr><tr><td><strong>Critic</strong></td><td>✓ 需要训练</td><td>✗ <strong>不需要</strong></td></tr><tr><td><strong>Reference</strong></td><td>✓ 需要冻结</td><td>✓ 需要冻结</td></tr><tr><td><strong>Reward Model</strong></td><td>✓ 需要冻结</td><td>✓ 需要冻结</td></tr><tr><td><strong>总体模型数</strong></td><td>4 个</td><td>3 个</td></tr><tr><td><strong>可训练参数</strong></td><td>Actor + Critic</td><td><strong>仅 Actor</strong></td></tr></tbody></table></div><h3 id="优势计算对比"><a href="#优势计算对比" class="headerlink" title="优势计算对比"></a>优势计算对比</h3><div class="table-container"><table><thead><tr><th>维度</th><th>PPO</th><th>GRPO</th></tr></thead><tbody><tr><td><strong>信息来源</strong></td><td>单个序列的 token 奖励 + Critic 价值估计</td><td>组内序列的相对奖励</td></tr><tr><td><strong>计算方式</strong></td><td>GAE(加权累加多步 TD 误差)</td><td>组内平均 - 单序列奖励</td></tr><tr><td><strong>需要模型</strong></td><td>Critic 模型</td><td><strong>无需额外模型</strong></td></tr><tr><td><strong>超参数</strong></td><td>$\lambda, \gamma, \beta_{\text{vf}}$ 等</td><td>组大小 $G$、折扣 $\gamma$、权重 $\beta$</td></tr><tr><td><strong>方差</strong></td><td>中等(GAE 加权后)</td><td><strong>较低</strong>(组内比较)</td></tr><tr><td><strong>偏差</strong></td><td>低(无模型依赖)</td><td>中等(依赖同组其他序列)</td></tr></tbody></table></div><h3 id="重要性采样对比"><a href="#重要性采样对比" class="headerlink" title="重要性采样对比"></a>重要性采样对比</h3><div class="table-container"><table><thead><tr><th>维度</th><th>PPO</th><th>GRPO</th></tr></thead><tbody><tr><td><strong>采样粒度</strong></td><td>Token 级别</td><td><strong>Token 级别</strong>(与 PPO 相同)</td></tr><tr><td><strong>参考模型</strong></td><td>$\pi<em>{\theta</em>{\text{old}}}$(旧策略)</td><td>$\pi_{\text{ref}}$(参考模型)</td></tr><tr><td><strong>裁剪范围</strong></td><td>Token 级别</td><td><strong>Token 级别</strong>(与 PPO 相同)</td></tr><tr><td><strong>损失聚合</strong></td><td>先平均 token 再平均序列</td><td><strong>先平均 token 再平均序列</strong>(相同)</td></tr></tbody></table></div><p><strong>关键相同点</strong>:GRPO 和 PPO 都在 token 级别进行重要性采样和裁剪,区别在于优势的计算来源。</p><h3 id="优缺点对比"><a href="#优缺点对比" class="headerlink" title="优缺点对比"></a>优缺点对比</h3><h4 id="PPO-的优势"><a href="#PPO-的优势" class="headerlink" title="PPO 的优势"></a>PPO 的优势</h4><ol><li><strong>原理清晰</strong>:基于价值函数的强化学习理论</li><li><strong>广泛应用</strong>:在各种 RL 任务中都有良好表现</li><li><strong>灵活性强</strong>:可以适应各种奖励结构</li><li><strong>理论支持</strong>:GAE 有深厚的理论基础</li></ol><h4 id="PPO-的缺陷"><a href="#PPO-的缺陷" class="headerlink" title="PPO 的缺陷"></a>PPO 的缺陷</h4><ol><li><strong>模型多</strong>:需要训练两个大模型(Actor + Critic)</li><li><strong>Critic 不准确</strong>:Critic 的价值估计可能不准确,导致优势估计有偏差</li><li><strong>超参数多</strong>:GAE 的 $\lambda, \gamma$ 需要调优</li><li><strong>计算成本高</strong>:训练 Critic 占用显存和计算力</li></ol><h4 id="GRPO-的优势"><a href="#GRPO-的优势" class="headerlink" title="GRPO 的优势"></a>GRPO 的优势</h4><ol><li><strong>模型少</strong>:无需额外训练 Critic,只需 Actor</li><li><strong>方差小</strong>:组内比较天然降低方差,不依赖模型准确性</li><li><strong>简洁直接</strong>:优势计算逻辑简单明了</li><li><strong>计算成本低</strong>:节省 Critic 的计算和内存</li></ol><h4 id="GRPO-的缺陷"><a href="#GRPO-的缺陷" class="headerlink" title="GRPO 的缺陷"></a>GRPO 的缺陷</h4><ol><li><strong>需要分组</strong>:必须采样一个组的序列,组大小 $G$ 需要调优</li><li><strong>依赖组质量</strong>:如果组内所有序列都很差,相对优势可能无法有效指导</li><li><strong>序列奖励稀疏</strong>:不能充分利用中间步骤的反馈信息</li><li><strong>较少应用</strong>:目前应用不如 PPO 广泛</li></ol><hr><h2 id="GRPO-的实现细节"><a href="#GRPO-的实现细节" class="headerlink" title="GRPO 的实现细节"></a>GRPO 的实现细节</h2><h3 id="1-分组采样"><a href="#1-分组采样" class="headerlink" title="1. 分组采样"></a>1. 分组采样</h3><p>GRPO 的关键是在每次更新时采样一个<strong>组</strong>的序列。</p><h4 id="采样策略"><a href="#采样策略" class="headerlink" title="采样策略"></a>采样策略</h4><ol><li><strong>输入</strong>:一个 batch 的 prompts,共 $B$ 个</li><li><strong>采样</strong>:对每个 prompt,采样 $G$ 个序列(即 $G$ 个不同的回答)</li><li><strong>结果</strong>:得到 $B \times G$ 个序列</li></ol><h4 id="组大小选择"><a href="#组大小选择" class="headerlink" title="组大小选择"></a>组大小选择</h4><ul><li><strong>$G = 4$</strong>:计算效率高,但相对优势估计可能不够准确</li><li><strong>$G = 8 \sim 16$</strong>:平衡效率和准确性(推荐)</li><li><strong>$G = 32$</strong>:优势估计准确,但计算成本大</li></ul><p><strong>实践建议</strong>:根据显存和计算资源选择,通常 $G = 8$ 或 $G = 16$。</p><h3 id="2-奖励获取与归一化"><a href="#2-奖励获取与归一化" class="headerlink" title="2. 奖励获取与归一化"></a>2. 奖励获取与归一化</h3><h4 id="奖励获取流程"><a href="#奖励获取流程" class="headerlink" title="奖励获取流程"></a>奖励获取流程</h4><p>对每个序列:</p><ol><li><strong>Actor 生成</strong>:Actor 生成完整序列</li><li><strong>RM 评分</strong>:将序列输入 RM,得到标量奖励 $r_{\text{RM}}$</li><li><strong>记录</strong>:保存奖励值</li></ol><h4 id="奖励归一化"><a href="#奖励归一化" class="headerlink" title="奖励归一化"></a>奖励归一化</h4><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 对每个 prompt 的 G 个序列的奖励进行归一化</span></span><br><span class="line"><span class="keyword">for</span> prompt_idx <span class="keyword">in</span> <span class="built_in">range</span>(B):</span><br><span class="line"> rewards = [r[prompt_idx, :] <span class="keyword">for</span> r <span class="keyword">in</span> reward_scores] <span class="comment"># G 个奖励</span></span><br><span class="line"> mu = np.mean(rewards)</span><br><span class="line"> sigma = np.std(rewards)</span><br><span class="line"> normalized_rewards = (rewards - mu) / (sigma + <span class="number">1e-8</span>)</span><br></pre></td></tr></table></figure><p><strong>为什么归一化</strong>:</p><ul><li>不同的 prompt 可能有不同的奖励尺度</li><li>归一化后的优势在 0 附近,梯度更稳定</li></ul><h3 id="3-Token-优势计算"><a href="#3-Token-优势计算" class="headerlink" title="3. Token 优势计算"></a>3. Token 优势计算</h3><h4 id="算法步骤"><a href="#算法步骤" class="headerlink" title="算法步骤"></a>算法步骤</h4><p>对于第 $i$ 个序列(长度 $T_i$,奖励 $r_i$):</p><p><strong>步骤 1</strong>:计算序列优势</p><script type="math/tex; mode=display">A_i^{\text{seq}} = r_i - \bar{r}</script><p><strong>步骤 2</strong>:分配到每个 token(选择一种)</p><p><em>方式 A:均等分配</em></p><script type="math/tex; mode=display">\tilde{A}_t^{(i)} = \frac{A_i^{\text{seq}}}{T_i}</script><p><em>方式 B:折扣分配</em></p><script type="math/tex; mode=display">\tilde{A}_t^{(i)} = \frac{\gamma^{T_i-t} A_i^{\text{seq}}}{\sum_{k=1}^{T_i} \gamma^{T_i-k}}</script><p><strong>步骤 3</strong>:全局归一化</p><script type="math/tex; mode=display">\hat{A}_t^{(i)} = \frac{\tilde{A}_t^{(i)} - \mu_A}{\sigma_A + \epsilon}</script><h4 id="Python-伪代码"><a href="#Python-伪代码" class="headerlink" title="Python 伪代码"></a>Python 伪代码</h4><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">compute_advantages</span>(<span class="params">rewards, sequence_lengths, discount=<span class="number">0.99</span></span>):</span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> Args:</span></span><br><span class="line"><span class="string"> rewards: shape (B, G) - B 个 prompts,每个 G 个序列的奖励</span></span><br><span class="line"><span class="string"> sequence_lengths: shape (B, G) - 每个序列的长度</span></span><br><span class="line"><span class="string"> discount: 折扣因子</span></span><br><span class="line"><span class="string"> </span></span><br><span class="line"><span class="string"> Returns:</span></span><br><span class="line"><span class="string"> advantages: shape (B, G, T_max) - 每个 token 的优势</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> B, G, T_max = rewards.shape[<span class="number">0</span>], rewards.shape[<span class="number">1</span>], sequence_lengths.<span class="built_in">max</span>()</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 全局平均奖励</span></span><br><span class="line"> global_mean = rewards.mean()</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 序列优势</span></span><br><span class="line"> seq_advantages = rewards - global_mean <span class="comment"># shape (B, G)</span></span><br><span class="line"> </span><br><span class="line"> <span class="comment"># token 级优势(折扣分配)</span></span><br><span class="line"> token_advantages = np.zeros((B, G, T_max))</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">for</span> b <span class="keyword">in</span> <span class="built_in">range</span>(B):</span><br><span class="line"> <span class="keyword">for</span> g <span class="keyword">in</span> <span class="built_in">range</span>(G):</span><br><span class="line"> T = sequence_lengths[b, g]</span><br><span class="line"> A_seq = seq_advantages[b, g]</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 计算折扣和</span></span><br><span class="line"> discount_sum = <span class="built_in">sum</span>(discount**(T-<span class="number">1</span>-t) <span class="keyword">for</span> t <span class="keyword">in</span> <span class="built_in">range</span>(T))</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 分配到每个 token</span></span><br><span class="line"> <span class="keyword">for</span> t <span class="keyword">in</span> <span class="built_in">range</span>(T):</span><br><span class="line"> token_advantages[b, g, t] = (</span><br><span class="line"> discount**(T-<span class="number">1</span>-t) * A_seq / discount_sum</span><br><span class="line"> )</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 全局归一化</span></span><br><span class="line"> mean_adv = token_advantages.mean()</span><br><span class="line"> std_adv = token_advantages.std()</span><br><span class="line"> normalized_advantages = (token_advantages - mean_adv) / (std_adv + <span class="number">1e-8</span>)</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> normalized_advantages</span><br></pre></td></tr></table></figure><h3 id="4-损失计算与反向传播"><a href="#4-损失计算与反向传播" class="headerlink" title="4. 损失计算与反向传播"></a>4. 损失计算与反向传播</h3><h4 id="Token-级别的损失计算"><a href="#Token-级别的损失计算" class="headerlink" title="Token 级别的损失计算"></a>Token 级别的损失计算</h4><p>对于第 $i$ 个序列的第 $t$ 个 token:</p><p><strong>重要性采样比率</strong>:</p><script type="math/tex; mode=display">r_t^{(i)} = \frac{\pi_\theta(a_t^{(i)}|s_t^{(i)})}{\pi_{\text{ref}}(a_t^{(i)}|s_t^{(i)})}</script><p><strong>PPO 风格的裁剪损失</strong>:</p><script type="math/tex; mode=display">\ell_t^{(i)} = \min\left(r_t^{(i)} \hat{A}_t^{(i)}, \text{clip}(r_t^{(i)}, 1-\epsilon, 1+\epsilon) \hat{A}_t^{(i)}\right)</script><p>其中 $\epsilon = 0.2$(裁剪范围)。</p><h4 id="KL-约束"><a href="#KL-约束" class="headerlink" title="KL 约束"></a>KL 约束</h4><script type="math/tex; mode=display">\ell_t^{\text{KL},(i)} = D_{\text{KL}}(\pi_\theta(\cdot|s_t^{(i)}) \| \pi_{\text{ref}}(\cdot|s_t^{(i)}))</script><p><strong>计算方式</strong>(对于语言模型):</p><script type="math/tex; mode=display">D_{\text{KL}} = \sum_v \pi_\theta(v|s) \log \frac{\pi_\theta(v|s)}{\pi_{\text{ref}}(v|s)}</script><p>在实际实现中,通常只计算被生成 token 处的 KL:</p><script type="math/tex; mode=display">\ell_t^{\text{KL},(i)} \approx \log \pi_\theta(a_t^{(i)}|s_t^{(i)}) - \log \pi_{\text{ref}}(a_t^{(i)}|s_t^{(i)})</script><h4 id="总体损失"><a href="#总体损失" class="headerlink" title="总体损失"></a>总体损失</h4><p><strong>单个序列的损失</strong>(先 token 后序列平均):</p><script type="math/tex; mode=display">L_i = \frac{1}{T_i} \sum_{t=1}^{T_i} \left[\ell_t^{(i)} + \beta \ell_t^{\text{KL},(i)}\right]</script><p><strong>整个组的损失</strong>(序列平均):</p><script type="math/tex; mode=display">L = \frac{1}{B \cdot G} \sum_{i=1}^{B \times G} L_i</script><h4 id="Python-伪代码-1"><a href="#Python-伪代码-1" class="headerlink" title="Python 伪代码"></a>Python 伪代码</h4><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">compute_loss</span>(<span class="params">logits, reference_logits, advantages, token_ids, beta=<span class="number">0.05</span>, epsilon=<span class="number">0.2</span></span>):</span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> Args:</span></span><br><span class="line"><span class="string"> logits: shape (B, G, T, V) - Actor 模型输出</span></span><br><span class="line"><span class="string"> reference_logits: shape (B, G, T, V) - Reference 模型输出</span></span><br><span class="line"><span class="string"> advantages: shape (B, G, T) - 归一化优势</span></span><br><span class="line"><span class="string"> token_ids: shape (B, G, T) - 生成的 token IDs</span></span><br><span class="line"><span class="string"> beta: KL 权重</span></span><br><span class="line"><span class="string"> epsilon: 裁剪范围</span></span><br><span class="line"><span class="string"> </span></span><br><span class="line"><span class="string"> Returns:</span></span><br><span class="line"><span class="string"> loss: 标量损失</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> B, G, T, V = logits.shape</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 计算 log 概率</span></span><br><span class="line"> log_probs = log_softmax(logits, dim=-<span class="number">1</span>) <span class="comment"># shape (B, G, T, V)</span></span><br><span class="line"> ref_log_probs = log_softmax(reference_logits, dim=-<span class="number">1</span>)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 选择生成 token 的概率</span></span><br><span class="line"> chosen_log_probs = log_probs.gather(-<span class="number">1</span>, token_ids.unsqueeze(-<span class="number">1</span>)).squeeze(-<span class="number">1</span>)</span><br><span class="line"> chosen_ref_log_probs = ref_log_probs.gather(-<span class="number">1</span>, token_ids.unsqueeze(-<span class="number">1</span>)).squeeze(-<span class="number">1</span>)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 重要性采样比率</span></span><br><span class="line"> log_ratio = chosen_log_probs - chosen_ref_log_probs <span class="comment"># log(pi/pi_ref)</span></span><br><span class="line"> ratio = exp(log_ratio) <span class="comment"># pi/pi_ref</span></span><br><span class="line"> </span><br><span class="line"> <span class="comment"># PPO 裁剪损失</span></span><br><span class="line"> clipped_ratio = clamp(ratio, <span class="number">1</span> - epsilon, <span class="number">1</span> + epsilon)</span><br><span class="line"> policy_loss = -<span class="built_in">min</span>(ratio * advantages, clipped_ratio * advantages)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># KL 散度(简化版,只计算生成的 token)</span></span><br><span class="line"> kl_loss = log_ratio <span class="comment"># 近似</span></span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 总损失</span></span><br><span class="line"> total_loss = policy_loss + beta * kl_loss</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 求平均</span></span><br><span class="line"> <span class="keyword">return</span> total_loss.mean()</span><br></pre></td></tr></table></figure><h3 id="5-超参数设置"><a href="#5-超参数设置" class="headerlink" title="5. 超参数设置"></a>5. 超参数设置</h3><h4 id="关键超参数"><a href="#关键超参数" class="headerlink" title="关键超参数"></a>关键超参数</h4><div class="table-container"><table><thead><tr><th>超参数</th><th>典型值</th><th>范围</th><th>说明</th></tr></thead><tbody><tr><td><strong>$G$ (组大小)</strong></td><td>8-16</td><td>4-32</td><td>更大的组更稳定但计算成本高</td></tr><tr><td><strong>$\epsilon$ (裁剪范围)</strong></td><td>0.2</td><td>0.1-0.3</td><td>限制重要性采样比率的范围</td></tr><tr><td><strong>$\beta$ (KL 权重)</strong></td><td>0.05</td><td>0.01-0.2</td><td>平衡奖励追求和分布约束</td></tr><tr><td><strong>$\gamma$ (折扣因子)</strong></td><td>0.99</td><td>0.95-1.0</td><td>控制到末端 token 的权重衰减</td></tr><tr><td><strong>学习率</strong></td><td>5e-7</td><td>1e-7-1e-6</td><td>较小的学习率,通常比 SFT 小 100 倍</td></tr><tr><td><strong>批大小</strong></td><td>$B=8, G=16$</td><td>-</td><td>总共 $B \times G$ 个序列</td></tr></tbody></table></div><h4 id="调优建议"><a href="#调优建议" class="headerlink" title="调优建议"></a>调优建议</h4><ol><li><p><strong>$G$ 的选择</strong>:</p><ul><li>显存充足:使用 $G = 16$ 或更大</li><li>显存紧张:使用 $G = 4$ 或 $G = 8$</li><li>避免 $G$ 过小($< 4$),否则相对优势不可靠</li></ul></li><li><p><strong>$\beta$ 的调优</strong>:</p><ul><li>若 KL 散度过大:增大 $\beta$</li><li>若训练不动(模型变化太慢):减小 $\beta$</li><li>推荐从 $0.05$ 开始,根据 KL 散度监测值调整</li></ul></li><li><p><strong>学习率的选择</strong>:</p><ul><li>GRPO 的学习率通常比 SFT 小 100-1000 倍</li><li>从 5e-7 开始,根据梯度范数调整</li><li>避免学习率过大,否则训练不稳定</li></ul></li><li><p><strong>折扣因子 $\gamma$</strong>:</p><ul><li>$\gamma = 0.99$:末端 token 权重约为前端的 $0.37$ 倍</li><li>$\gamma = 1.0$:均等分配(最简单)</li><li>推荐从 $\gamma = 1.0$ 开始,必要时调整</li></ul></li></ol><hr><h2 id="GRPO-的完整训练流程"><a href="#GRPO-的完整训练流程" class="headerlink" title="GRPO 的完整训练流程"></a>GRPO 的完整训练流程</h2><h3 id="算法伪代码"><a href="#算法伪代码" class="headerlink" title="算法伪代码"></a>算法伪代码</h3><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br></pre></td><td class="code"><pre><span class="line">算法:GRPO 训练</span><br><span class="line"></span><br><span class="line">初始化:</span><br><span class="line"> - Actor 模型 π_θ(可以从 SFT 初始化)</span><br><span class="line"> - Reference 模型 π_ref(冻结,使用 SFT 模型或 Actor 的检查点)</span><br><span class="line"> - Reward Model R(冻结,提前训练)</span><br><span class="line"></span><br><span class="line">主循环:</span><br><span class="line"> for epoch = 1 to num_epochs:</span><br><span class="line"> for batch_idx = 1 to num_batches:</span><br><span class="line"> </span><br><span class="line"> # 步骤 1: 采样</span><br><span class="line"> prompts = get_batch() # B 个 prompts</span><br><span class="line"> sequences = []</span><br><span class="line"> log_probs = []</span><br><span class="line"> </span><br><span class="line"> for g = 1 to G:</span><br><span class="line"> # 使用 Actor 生成 G 个回答</span><br><span class="line"> seq_g, logprobs_g = generate_from_actor(prompts, π_θ)</span><br><span class="line"> sequences.append(seq_g)</span><br><span class="line"> log_probs.append(logprobs_g)</span><br><span class="line"> </span><br><span class="line"> # 步骤 2: 获取奖励</span><br><span class="line"> rewards = reward_model.score(sequences) # shape (B, G)</span><br><span class="line"> </span><br><span class="line"> # 步骤 3: 计算优势</span><br><span class="line"> advantages = compute_advantages(rewards, sequences) # shape (B, G, T)</span><br><span class="line"> </span><br><span class="line"> # 步骤 4: 前向传播</span><br><span class="line"> actor_outputs = actor(sequences) # 获取 Actor 的 logits</span><br><span class="line"> ref_outputs = reference_model(sequences) # 获取 Reference 的 logits</span><br><span class="line"> </span><br><span class="line"> # 步骤 5: 计算损失</span><br><span class="line"> loss = 0</span><br><span class="line"> for each token:</span><br><span class="line"> ratio = exp(log_pi_theta - log_pi_ref)</span><br><span class="line"> clipped_ratio = clip(ratio, 1-ε, 1+ε)</span><br><span class="line"> policy_loss += min(ratio * A, clipped_ratio * A)</span><br><span class="line"> kl_loss += (log_pi_theta - log_pi_ref)</span><br><span class="line"> </span><br><span class="line"> total_loss = policy_loss + β * kl_loss</span><br><span class="line"> </span><br><span class="line"> # 步骤 6: 反向传播和更新</span><br><span class="line"> optimizer.zero_grad()</span><br><span class="line"> total_loss.backward()</span><br><span class="line"> optimizer.step()</span><br><span class="line"> </span><br><span class="line"> # 步骤 7: 监测</span><br><span class="line"> log(loss, kl_div, reward_mean, reward_std)</span><br></pre></td></tr></table></figure><h3 id="训练流程图"><a href="#训练流程图" class="headerlink" title="训练流程图"></a>训练流程图</h3><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br></pre></td><td class="code"><pre><span class="line">┌─────────────────────────────────────────────────────┐</span><br><span class="line">│ 输入:Batch of Prompts │</span><br><span class="line">└─────────────────────┬───────────────────────────────┘</span><br><span class="line"> ↓</span><br><span class="line"> ┌─────────────────────┐</span><br><span class="line"> │ Actor 采样 G 个序列 │ (并行,每个 prompt)</span><br><span class="line"> └──────────┬──────────┘</span><br><span class="line"> ↓</span><br><span class="line"> ┌────────────────────────────────┐</span><br><span class="line"> │ Reward Model 评分每个序列 │</span><br><span class="line"> │ 得到 (B, G) 的奖励矩阵 │</span><br><span class="line"> └────────────┬───────────────────┘</span><br><span class="line"> ↓</span><br><span class="line"> ┌────────────────────────────────┐</span><br><span class="line"> │ 计算序列优势(组内相对比较) │</span><br><span class="line"> │ A_i^seq = r_i - mean(r) │</span><br><span class="line"> └────────────┬───────────────────┘</span><br><span class="line"> ↓</span><br><span class="line"> ┌────────────────────────────────┐</span><br><span class="line"> │ 分配到 Token 级优势 │</span><br><span class="line"> │ A_t = discount^(T-t) * A_seq │</span><br><span class="line"> └────────────┬───────────────────┘</span><br><span class="line"> ↓</span><br><span class="line"> ┌────────────────────────────────┐</span><br><span class="line"> │ 全局归一化优势 │</span><br><span class="line"> │ A_hat = (A - mean) / std │</span><br><span class="line"> └────────────┬───────────────────┘</span><br><span class="line"> ↓</span><br><span class="line"> ┌────────────────────────────────┐</span><br><span class="line"> │ 计算 PPO 损失(token 级别) │</span><br><span class="line"> │ L = min(r*A, clip(r)*A) │</span><br><span class="line"> │ + β * KL(π_θ || π_ref) │</span><br><span class="line"> └────────────┬───────────────────┘</span><br><span class="line"> ↓</span><br><span class="line"> ┌────────────────────────────────┐</span><br><span class="line"> │ 反向传播和梯度更新 │</span><br><span class="line"> │ θ ← θ - α ∇L │</span><br><span class="line"> └────────────┬───────────────────┘</span><br><span class="line"> ↓</span><br><span class="line"> ┌──────────────────┐</span><br><span class="line"> │ 输出:更新的 Actor │</span><br><span class="line"> └──────────────────┘</span><br></pre></td></tr></table></figure><hr><h2 id="GRPO-完整数值示例"><a href="#GRPO-完整数值示例" class="headerlink" title="GRPO 完整数值示例"></a>GRPO 完整数值示例</h2><p>让我们用一个具体例子演示整个计算过程。</p><h3 id="假设条件"><a href="#假设条件" class="headerlink" title="假设条件"></a>假设条件</h3><ul><li><strong>批大小</strong>:$B = 1$(1 个 prompt)</li><li><strong>组大小</strong>:$G = 2$(2 个序列)</li><li><strong>序列长度</strong>:$T_1 = 3, T_2 = 4$(两个序列长度不同)</li><li><strong>奖励</strong>:$r_1 = 0.8, r_2 = 0.6$(第一个序列更好)</li><li><strong>折扣因子</strong>:$\gamma = 0.99$</li><li><strong>裁剪范围</strong>:$\epsilon = 0.2$</li><li><strong>KL 权重</strong>:$\beta = 0.05$</li></ul><h3 id="计算过程"><a href="#计算过程" class="headerlink" title="计算过程"></a>计算过程</h3><h4 id="步骤-1-计算序列优势"><a href="#步骤-1-计算序列优势" class="headerlink" title="步骤 1: 计算序列优势"></a>步骤 1: 计算序列优势</h4><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">序列 1: r_1 = 0.8, T_1 = 3</span><br><span class="line">序列 2: r_2 = 0.6, T_2 = 4</span><br><span class="line"></span><br><span class="line">平均奖励: r_bar = (0.8 + 0.6) / 2 = 0.7</span><br><span class="line"></span><br><span class="line">序列优势:</span><br><span class="line"> A_1^seq = 0.8 - 0.7 = 0.1</span><br><span class="line"> A_2^seq = 0.6 - 0.7 = -0.1</span><br></pre></td></tr></table></figure><p><strong>解释</strong>:序列 1 相对于平均水平好一点,序列 2 相对于平均水平差一点。</p><h4 id="步骤-2-分配到-Token-级(使用折扣分配)"><a href="#步骤-2-分配到-Token-级(使用折扣分配)" class="headerlink" title="步骤 2: 分配到 Token 级(使用折扣分配)"></a>步骤 2: 分配到 Token 级(使用折扣分配)</h4><p><strong>折扣分配公式</strong>:</p><script type="math/tex; mode=display">\tilde{A}_t^{(i)} = \frac{\gamma^{T_i-t} A_i^{\text{seq}}}{\sum_{k=1}^{T_i} \gamma^{T_i-k}}</script><p>其中分母是折扣权重的和。</p><p><strong>序列 1 计算</strong>($T_1 = 3$):<br><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">折扣权重: γ^(3-1)=0.9801, γ^(3-2)=0.99, γ^(3-3)=1.0</span><br><span class="line">权重和: 0.9801 + 0.99 + 1.0 = 2.9701</span><br><span class="line"></span><br><span class="line">Token 1: A_1^(1) = (0.9801 * 0.1) / 2.9701 = 0.0330</span><br><span class="line">Token 2: A_2^(1) = (0.99 * 0.1) / 2.9701 = 0.0333</span><br><span class="line">Token 3: A_3^(1) = (1.0 * 0.1) / 2.9701 = 0.0337</span><br></pre></td></tr></table></figure></p><p><strong>序列 2 计算</strong>($T_2 = 4$):<br><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">折扣权重: γ^(4-1)=0.9703, γ^(4-2)=0.9801, γ^(4-3)=0.99, γ^(4-4)=1.0</span><br><span class="line">权重和: 0.9703 + 0.9801 + 0.99 + 1.0 = 3.9404</span><br><span class="line"></span><br><span class="line">Token 1: A_1^(2) = (0.9703 * (-0.1)) / 3.9404 = -0.0246</span><br><span class="line">Token 2: A_2^(2) = (0.9801 * (-0.1)) / 3.9404 = -0.0249</span><br><span class="line">Token 3: A_3^(2) = (0.99 * (-0.1)) / 3.9404 = -0.0251</span><br><span class="line">Token 4: A_4^(2) = (1.0 * (-0.1)) / 3.9404 = -0.0254</span><br></pre></td></tr></table></figure></p><p><strong>观察</strong>:</p><ul><li>序列 1 的最后 token (0.0337) 权重高于最前 token (0.0330)</li><li>序列 2 的所有优势都是负的(因为序列质量较差)</li></ul><h4 id="步骤-3-全局归一化"><a href="#步骤-3-全局归一化" class="headerlink" title="步骤 3: 全局归一化"></a>步骤 3: 全局归一化</h4><p>所有 token 的优势值:<br><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">序列 1: [0.0330, 0.0333, 0.0337]</span><br><span class="line">序列 2: [-0.0246, -0.0249, -0.0251, -0.0254]</span><br><span class="line"></span><br><span class="line">合并: [0.0330, 0.0333, 0.0337, -0.0246, -0.0249, -0.0251, -0.0254]</span><br><span class="line"></span><br><span class="line">均值: μ_A = (0.0330 + 0.0333 + 0.0337 - 0.0246 - 0.0249 - 0.0251 - 0.0254) / 7</span><br><span class="line"> = 0 / 7 = 0 (近似)</span><br><span class="line"></span><br><span class="line">方差: σ_A ≈ 0.0291</span><br></pre></td></tr></table></figure></p><p>归一化后(减去均值、除以标准差):<br><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">序列 1: [0.1134, 0.1145, 0.1159] (正值,鼓励)</span><br><span class="line">序列 2: [-0.0846, -0.0856, -0.0863, -0.0873] (负值,惩罚)</span><br></pre></td></tr></table></figure></p><h4 id="步骤-4-计算-PPO-损失"><a href="#步骤-4-计算-PPO-损失" class="headerlink" title="步骤 4: 计算 PPO 损失"></a>步骤 4: 计算 PPO 损失</h4><p>假设 Actor 和 Reference 模型在序列 1 的第 1 个 token 处:</p><ul><li>Actor log_prob: -1.5</li><li>Reference log_prob: -1.6</li><li>生成的 token ID</li></ul><p><strong>计算过程</strong>:<br><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">log_ratio = -1.5 - (-1.6) = 0.1</span><br><span class="line">ratio = exp(0.1) ≈ 1.105</span><br><span class="line"></span><br><span class="line">优势: A_hat_1^(1) = 0.1134</span><br><span class="line"></span><br><span class="line">无裁剪: ratio * A_hat = 1.105 * 0.1134 ≈ 0.1253</span><br><span class="line">有裁剪: clip(ratio, 0.8, 1.2) = 1.105</span><br><span class="line"> clipped * A_hat = 1.105 * 0.1134 ≈ 0.1253</span><br><span class="line"></span><br><span class="line">PPO loss: min(0.1253, 0.1253) = 0.1253</span><br></pre></td></tr></table></figure></p><h4 id="步骤-5-KL-散度"><a href="#步骤-5-KL-散度" class="headerlink" title="步骤 5: KL 散度"></a>步骤 5: KL 散度</h4><p>假设 KL 散度计算(简化):<br><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">KL_1^(1) = log_ratio = 0.1</span><br><span class="line"></span><br><span class="line">总 KL 损失 (简化,仅示意):</span><br><span class="line">L_KL = β * log_ratio_sum / num_tokens</span><br><span class="line"> = 0.05 * (所有 log_ratio 之和) / 7</span><br></pre></td></tr></table></figure></p><h4 id="最终损失"><a href="#最终损失" class="headerlink" title="最终损失"></a>最终损失</h4><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">序列 1 损失: L_1 = (L_policy_1 + β * L_KL_1) / 3</span><br><span class="line">序列 2 损失: L_2 = (L_policy_2 + β * L_KL_2) / 4</span><br><span class="line"></span><br><span class="line">总损失: L = (L_1 + L_2) / 2</span><br></pre></td></tr></table></figure><p><strong>关键观察</strong>:</p><ul><li>序列 1 的优势为正,鼓励增加这些 token 的概率</li><li>序列 2 的优势为负,抑制这些 token 的概率</li><li>Token 级别的优势差异(末端 > 前端)反映了对最终质量的贡献</li></ul><hr><h2 id="GRPO-的收敛性与稳定性"><a href="#GRPO-的收敛性与稳定性" class="headerlink" title="GRPO 的收敛性与稳定性"></a>GRPO 的收敛性与稳定性</h2><h3 id="方差分析"><a href="#方差分析" class="headerlink" title="方差分析"></a>方差分析</h3><p>GRPO 相比 PPO 的方差优势来自于:</p><ol><li><p><strong>组内相对比较</strong>:</p><ul><li>PPO 使用单个序列的价值函数差分 $\delta_t$,仅与该序列相关</li><li>GRPO 使用组内所有序列的平均,$E[\text{Var}(A_t)]$ 更小</li></ul></li><li><p><strong>不依赖模型准确性</strong>:</p><ul><li>PPO 的 GAE 依赖 Critic 的准确性</li><li>GRPO 直接使用真实奖励,无模型误差</li></ul></li><li><p><strong>折扣分配的效果</strong>:</p><ul><li>给予末端更高权重,减少信号的不确定性</li></ul></li></ol><h3 id="收敛性证明"><a href="#收敛性证明" class="headerlink" title="收敛性证明"></a>收敛性证明</h3><p>GRPO 的收敛性可以通过以下论证:</p><ol><li><strong>策略改进</strong>:每次更新时,$\mathbb{E}[r_t A_t] \geq 0$ 当 $A_t > 0$</li><li><strong>KL 约束限制</strong>:防止策略离开信任域(trust region)</li><li><strong>优势无偏估计</strong>:组内比较提供无偏的相对信号</li></ol><p><strong>定理(简化版)</strong>:在合理的假设下,GRPO 的策略序列收敛到满足 KL 约束的最优策略。</p><hr><h2 id="GRPO-在实践中的表现"><a href="#GRPO-在实践中的表现" class="headerlink" title="GRPO 在实践中的表现"></a>GRPO 在实践中的表现</h2><h3 id="DeepSeek-的实验结果"><a href="#DeepSeek-的实验结果" class="headerlink" title="DeepSeek 的实验结果"></a>DeepSeek 的实验结果</h3><p>根据 DeepSeek 的研究报告,GRPO 在以下指标上表现出色:</p><h4 id="数学推理任务(GSM8K-MATH)"><a href="#数学推理任务(GSM8K-MATH)" class="headerlink" title="数学推理任务(GSM8K/MATH)"></a>数学推理任务(GSM8K/MATH)</h4><ul><li><strong>收敛速度</strong>:更快(通常 10-50% 的步数)</li><li><strong>最终效果</strong>:相当或优于 PPO(特别是在数学问题上)</li><li><strong>训练稳定性</strong>:更稳定,奖励波动更小</li></ul><h4 id="通用能力(基准测试)"><a href="#通用能力(基准测试)" class="headerlink" title="通用能力(基准测试)"></a>通用能力(基准测试)</h4><ul><li><strong>不会显著退化</strong>:通过 KL 约束保持基础能力</li><li><strong>任务特定性强</strong>:在目标任务上显著改进,日常闲聊改进不大</li><li><strong>内存节省</strong>:相比 PPO 节省 20-30% 的显存(无需 Critic)</li></ul><h3 id="什么情况下-GRPO-更好"><a href="#什么情况下-GRPO-更好" class="headerlink" title="什么情况下 GRPO 更好"></a>什么情况下 GRPO 更好</h3><ol><li><strong>有明确的 RM 评分</strong>:任务有定量的奖励信号(数学、代码、写作评分等)</li><li><strong>大规模模型</strong>:模型越大,Critic 的省省成本越显著</li><li><strong>单轮生成</strong>:适合生成单个长序列的场景(如数学推导、代码)</li></ol><h3 id="GRPO-可能不如-PPO-的情况"><a href="#GRPO-可能不如-PPO-的情况" class="headerlink" title="GRPO 可能不如 PPO 的情况"></a>GRPO 可能不如 PPO 的情况</h3><ol><li><strong>奖励信号不清晰</strong>:RM 评分方差大、不稳定</li><li><strong>组大小受限</strong>:显存不足,无法采样足够大的组</li><li><strong>多步交互任务</strong>:需要按步骤获得反馈的任务</li></ol><hr><h2 id="常见问题与解答"><a href="#常见问题与解答" class="headerlink" title="常见问题与解答"></a>常见问题与解答</h2><h3 id="Q1-GRPO-是否需要参考模型?"><a href="#Q1-GRPO-是否需要参考模型?" class="headerlink" title="Q1: GRPO 是否需要参考模型?"></a>Q1: GRPO 是否需要参考模型?</h3><p><strong>A</strong>: 是的。GRPO 需要参考模型来计算 KL 约束,防止策略漂移。虽然 GRPO 不需要 Critic,但仍需要 Reference 模型(通常是 SFT 模型的冻结副本)。</p><h3 id="Q2-GRPO-的组大小-G-应该怎么选?"><a href="#Q2-GRPO-的组大小-G-应该怎么选?" class="headerlink" title="Q2: GRPO 的组大小 G 应该怎么选?"></a>Q2: GRPO 的组大小 G 应该怎么选?</h3><p><strong>A</strong>: </p><ul><li>最小值:$G \geq 4$(否则相对优势不可靠)</li><li>推荐值:$G = 8 \sim 16$(平衡准确性和效率)</li><li>最大值:受显存限制,通常 $G \leq 32$</li></ul><p>根据经验,$G$ 增加一倍通常能稍微改进效果但计算成本翻倍。</p><h3 id="Q3-GRPO-会过度拟合组内序列吗?"><a href="#Q3-GRPO-会过度拟合组内序列吗?" class="headerlink" title="Q3: GRPO 会过度拟合组内序列吗?"></a>Q3: GRPO 会过度拟合组内序列吗?</h3><p><strong>A</strong>: 有可能。如果组大小 $G$ 太小,相对优势的估计可能不稳定。解决方案:</p><ol><li>增加 $G$ 的大小</li><li>跨 batch 计算优势(不仅限于单个 batch 内的序列)</li><li>加入额外的正则化</li></ol><h3 id="Q4-GRPO-的-KL-散度约束怎么调?"><a href="#Q4-GRPO-的-KL-散度约束怎么调?" class="headerlink" title="Q4: GRPO 的 KL 散度约束怎么调?"></a>Q4: GRPO 的 KL 散度约束怎么调?</h3><p><strong>A</strong>: 观察 KL 散度的值:</p><ul><li>若 KL > 0.1:$\beta$ 太小,增大 $\beta$</li><li>若 KL < 0.01:$\beta$ 太大,减小 $\beta$</li><li>目标范围:0.01-0.05</li></ul><h3 id="Q5-GRPO-vs-PPO,我应该选哪个?"><a href="#Q5-GRPO-vs-PPO,我应该选哪个?" class="headerlink" title="Q5: GRPO vs PPO,我应该选哪个?"></a>Q5: GRPO vs PPO,我应该选哪个?</h3><p><strong>A</strong>: 根据以下决策树:</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">计算资源充足?</span><br><span class="line"> ├─ 是 → PPO(通用性强,理论支持好)</span><br><span class="line"> └─ 否 → GRPO(节省 Critic 显存)</span><br><span class="line"></span><br><span class="line">任务类型?</span><br><span class="line"> ├─ 数学/代码 → GRPO 可能更好</span><br><span class="line"> ├─ 通用闲聊 → PPO 或 GRPO 都可</span><br><span class="line"> └─ 多步交互 → PPO 更安全</span><br><span class="line"></span><br><span class="line">显存约束是否严格?</span><br><span class="line"> ├─ 是 → GRPO</span><br><span class="line"> └─ 否 → PPO 或 GRPO</span><br></pre></td></tr></table></figure><hr><h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>GRPO 代表了 LLM RLHF 训练的一个重要方向,其核心贡献包括:</p><ol><li><strong>简化优势计算</strong>:从复杂的 GAE 到简单的组内相对比较</li><li><strong>减少计算负担</strong>:不需要额外的 Critic 模型</li><li><strong>提高稳定性</strong>:组内比较天然降低方差</li><li><strong>实用高效</strong>:在大规模模型上表现更好</li></ol><p><strong>关键理解</strong>:</p><ul><li>GRPO 仍然是 token 级别的重要性采样和裁剪(与 PPO 相同)</li><li>GRPO 的创新在于优势的计算方式(组内相对而非 GAE)</li><li>GRPO 和 PPO 是互补的方法,根据具体场景选择</li></ul><p><strong>推荐阅读</strong>:</p><ul><li><a href="https://arxiv.org/abs/2402.10240">Efficient Stage-wise Pretraining for Large Language Models (GRPO 原论文)</a></li><li><a href="https://murphypei.github.io/blog/2024/07/llm-rlhf-ppo.html">配合 PPO 博文</a> 理解两种算法的异同</li></ul>]]></content>
<summary type="html">
<p>在<a href="https://murphypei.github.io/blog/2025/05/llm-ppo.html">之前的 PPO 细节探讨</a>中,我们详细介绍了 PPO 算法和 GAE 优势估计。今天我们来深入探讨 GRPO(Group Relative Policy Optimization)算法,这是 PPO 在大语言模型训练中的一个重要改进版本。</p>
</summary>
<category term="LLM" scheme="https://murphypei.github.io/categories/LLM/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="强化学习" scheme="https://murphypei.github.io/tags/%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0/"/>
<category term="rlhf" scheme="https://murphypei.github.io/tags/rlhf/"/>
<category term="ppo" scheme="https://murphypei.github.io/tags/ppo/"/>
<category term="grpo" scheme="https://murphypei.github.io/tags/grpo/"/>
<category term="advantage" scheme="https://murphypei.github.io/tags/advantage/"/>
</entry>
<entry>
<title>LLM 推理: KV Cache 原理与优化</title>
<link href="https://murphypei.github.io/blog/2025/07/kv-cache.html"/>
<id>https://murphypei.github.io/blog/2025/07/kv-cache.html</id>
<published>2025-07-01T03:56:52.000Z</published>
<updated>2026-01-07T08:41:01.086Z</updated>
<content type="html"><![CDATA[<p>继续梳理 LLM 知识,这次写 KV Cache。KV Cache 是大语言模型推理过程中的重要优化技术,能够显著减少计算量,提高推理速度。本文将从 Attention 计算原理出发,详细推导 KV Cache 的数学等价性,并分析其优化效果。</p><span id="more"></span><h2 id="引言"><a href="#引言" class="headerlink" title="引言"></a>引言</h2><p>在大语言模型的推理过程中,生成式推理(Generative Inference)是一个自回归过程,模型需要逐个生成token。在这个过程中,大量的计算被重复执行,特别是Attention机制中的Key和Value矩阵计算。KV Cache技术通过缓存这些中间结果,避免了重复计算,从而显著提高了推理效率。</p><p>本文将详细介绍KV Cache的工作原理,从Attention计算的数学原理出发,推导其等价性,并分析其在实际应用中的优化效果。</p><h2 id="Attention机制回顾"><a href="#Attention机制回顾" class="headerlink" title="Attention机制回顾"></a>Attention机制回顾</h2><h3 id="标准Attention计算"><a href="#标准Attention计算" class="headerlink" title="标准Attention计算"></a>标准Attention计算</h3><p>在Transformer的Attention机制中,对于输入序列 $X = [x_1, x_2, …, x_n]$,Attention的计算过程如下:</p><p><strong>1. 线性变换</strong></p><script type="math/tex; mode=display">Q = XW_Q, \quad K = XW_K, \quad V = XW_V</script><p>其中:</p><ul><li>$W_Q, W_K, W_V$ 是查询、键、值的权重矩阵</li><li>$Q, K, V$ 分别是查询、键、值的矩阵表示</li></ul><p><strong>2. Attention计算</strong></p><script type="math/tex; mode=display">\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V</script><p>其中 $d_k$ 是键向量的维度。</p><p><strong>3. 分步展开</strong><br>对于第 $i$ 个位置的输出,可以表示为:</p><script type="math/tex; mode=display">O_i = \sum_{j=1}^{n} \alpha_{ij} v_j</script><p>其中:</p><script type="math/tex; mode=display">\alpha_{ij} = \frac{\exp\left(\frac{q_i^T k_j}{\sqrt{d_k}}\right)}{\sum_{l=1}^{n} \exp\left(\frac{q_i^T k_l}{\sqrt{d_k}}\right)}</script><h3 id="自回归生成过程"><a href="#自回归生成过程" class="headerlink" title="自回归生成过程"></a>自回归生成过程</h3><p>在生成式推理中,模型逐个生成token。假设当前已经生成了 $t$ 个token,要生成第 $t+1$ 个token:</p><p><strong>输入序列</strong>:$X_{1:t} = [x_1, x_2, …, x_t]$</p><p><strong>计算过程</strong>:</p><ol><li>计算 $Q<em>{1:t}, K</em>{1:t}, V_{1:t}$</li><li>计算Attention输出</li><li>生成下一个token $x_{t+1}$</li><li>重复上述过程</li></ol><p><strong>问题</strong>:每次生成新token时,都需要重新计算整个序列的 $K$ 和 $V$ 矩阵,这导致了大量的重复计算。</p><h2 id="KV-Cache的核心思想"><a href="#KV-Cache的核心思想" class="headerlink" title="KV Cache的核心思想"></a>KV Cache的核心思想</h2><h3 id="基本概念"><a href="#基本概念" class="headerlink" title="基本概念"></a>基本概念</h3><p>KV Cache的核心思想是:<strong>缓存已经计算过的Key和Value矩阵,避免重复计算</strong>。</p><p><strong>缓存内容</strong>:</p><ul><li>$K_{cache} = [K_1, K_2, …, K_t]$:已生成token的Key矩阵</li><li>$V_{cache} = [V_1, V_2, …, V_t]$:已生成token的Value矩阵</li></ul><p><strong>增量更新</strong>:</p><ul><li>生成新token $x<em>{t+1}$ 时,只计算 $K</em>{t+1}$ 和 $V_{t+1}$</li><li>将新的Key和Value追加到缓存中</li><li>使用完整的缓存进行Attention计算</li></ul><h3 id="数学等价性推导"><a href="#数学等价性推导" class="headerlink" title="数学等价性推导"></a>数学等价性推导</h3><h4 id="1-标准计算的数学表示"><a href="#1-标准计算的数学表示" class="headerlink" title="1. 标准计算的数学表示"></a>1. 标准计算的数学表示</h4><p>在标准计算中,生成第 $t+1$ 个token时:</p><p><strong>输入</strong>:$X<em>{1:t+1} = [x_1, x_2, …, x_t, x</em>{t+1}]$</p><p><strong>计算过程</strong>:</p><script type="math/tex; mode=display">Q_{1:t+1} = X_{1:t+1}W_Q \\K_{1:t+1} = X_{1:t+1}W_K \\V_{1:t+1} = X_{1:t+1}W_V</script><p><strong>Attention输出</strong>:</p><script type="math/tex; mode=display">O_{t+1} = \sum_{j=1}^{t+1} \alpha_{(t+1)j} v_j</script><p>其中:</p><script type="math/tex; mode=display">\alpha_{(t+1)j} = \frac{\exp\left(\frac{q_{t+1}^T k_j}{\sqrt{d_k}}\right)}{\sum_{l=1}^{t+1} \exp\left(\frac{q_{t+1}^T k_l}{\sqrt{d_k}}\right)}</script><h4 id="2-KV-Cache的计算表示"><a href="#2-KV-Cache的计算表示" class="headerlink" title="2. KV Cache的计算表示"></a>2. KV Cache的计算表示</h4><p>在KV Cache中,生成第 $t+1$ 个token时:</p><p><strong>缓存状态</strong>:</p><ul><li>$K_{cache} = [K_1, K_2, …, K_t]$</li><li>$V_{cache} = [V_1, V_2, …, V_t]$</li></ul><p><strong>增量计算</strong>:</p><script type="math/tex; mode=display">q_{t+1} = x_{t+1}W_Q \\k_{t+1} = x_{t+1}W_K \\v_{t+1} = x_{t+1}W_V</script><p><strong>更新缓存</strong>:</p><script type="math/tex; mode=display">K_{cache}^{new} = [K_{cache}, k_{t+1}] = [K_1, K_2, ..., K_t, K_{t+1}] \\V_{cache}^{new} = [V_{cache}, v_{t+1}] = [V_1, V_2, ..., V_t, V_{t+1}]</script><p><strong>Attention计算</strong>:</p><script type="math/tex; mode=display">O_{t+1} = \sum_{j=1}^{t+1} \alpha_{(t+1)j} v_j</script><p>其中:</p><script type="math/tex; mode=display">\alpha_{(t+1)j} = \frac{\exp\left(\frac{q_{t+1}^T k_j}{\sqrt{d_k}}\right)}{\sum_{l=1}^{t+1} \exp\left(\frac{q_{t+1}^T k_l}{\sqrt{d_k}}\right)}</script><blockquote><p>这里注意重点,$O<em>{t+1}$,只和 $\alpha</em>{(t+1)j}$ 以及 $v<em>{i:t+1}$ 有关。而 $\alpha</em>{(t+1)j}$ 只和 $q<em>{t+1}$ 以及 $k</em>{i:t+1}$ 有关,这也是为何需要 KV 缓存,而不需要 Q 缓存的原因。这是 Attention 计算的核心,也是实现 KV cache 的关键。</p></blockquote><h4 id="3-等价性证明"><a href="#3-等价性证明" class="headerlink" title="3. 等价性证明"></a>3. 等价性证明</h4><p><strong>矩阵运算的线性性质</strong>:</p><p>对于线性变换 $K = XW_K$,由于矩阵乘法的线性性质:</p><script type="math/tex; mode=display">K_{1:t+1} = X_{1:t+1}W_K = [X_{1:t}, x_{t+1}]W_K = [X_{1:t}W_K, x_{t+1}W_K] = [K_{1:t}, K_{t+1}]</script><p>同理:</p><script type="math/tex; mode=display">V_{1:t+1} = [V_{1:t}, V_{t+1}]</script><p><strong>Attention计算的等价性</strong>:</p><p>在标准计算中:</p><script type="math/tex; mode=display">\text{Attention}(Q_{1:t+1}, K_{1:t+1}, V_{1:t+1}) = \text{softmax}\left(\frac{Q_{1:t+1}K_{1:t+1}^T}{\sqrt{d_k}}\right)V_{1:t+1}</script><p>在KV Cache中:</p><script type="math/tex; mode=display">\text{Attention}(q_{t+1}, [K_{cache}, k_{t+1}], [V_{cache}, v_{t+1}]) = \text{softmax}\left(\frac{q_{t+1}[K_{cache}, k_{t+1}]^T}{\sqrt{d_k}}\right)[V_{cache}, v_{t+1}]</script><p>由于:</p><ul><li>$[K<em>{cache}, k</em>{t+1}] = K_{1:t+1}$</li><li>$[V<em>{cache}, v</em>{t+1}] = V_{1:t+1}$</li><li>$q<em>{t+1}$ 是 $Q</em>{1:t+1}$ 的最后一行</li></ul><p>因此,两种计算方式在数学上完全等价。</p><h3 id="计算复杂度分析"><a href="#计算复杂度分析" class="headerlink" title="计算复杂度分析"></a>计算复杂度分析</h3><h4 id="1-标准计算复杂度"><a href="#1-标准计算复杂度" class="headerlink" title="1. 标准计算复杂度"></a>1. 标准计算复杂度</h4><p><strong>第 $t+1$ 步的计算量</strong>:</p><ul><li>线性变换:$O((t+1) \times d_{model} \times d_k)$</li><li>Attention计算:$O((t+1)^2 \times d_k)$</li><li>总复杂度:$O((t+1) \times d_{model} \times d_k + (t+1)^2 \times d_k)$</li></ul><p><strong>累积计算量</strong>(生成 $n$ 个token):</p><script type="math/tex; mode=display">\sum_{t=1}^{n} O(t \times d_{model} \times d_k + t^2 \times d_k) = O(n^2 \times d_{model} \times d_k + n^3 \times d_k)</script><h4 id="2-KV-Cache计算复杂度"><a href="#2-KV-Cache计算复杂度" class="headerlink" title="2. KV Cache计算复杂度"></a>2. KV Cache计算复杂度</h4><p><strong>第 $t+1$ 步的计算量</strong>:</p><ul><li>线性变换:$O(d_{model} \times d_k)$(只计算新token)</li><li>Attention计算:$O((t+1)^2 \times d_k)$</li><li>总复杂度:$O(d_{model} \times d_k + (t+1)^2 \times d_k)$</li></ul><p><strong>累积计算量</strong>(生成 $n$ 个token):</p><script type="math/tex; mode=display">\sum_{t=1}^{n} O(d_{model} \times d_k + t^2 \times d_k) = O(n \times d_{model} \times d_k + n^3 \times d_k)</script><h4 id="3-优化效果"><a href="#3-优化效果" class="headerlink" title="3. 优化效果"></a>3. 优化效果</h4><p><strong>计算量减少</strong>:</p><ul><li>线性变换部分:从 $O(n^2 \times d<em>{model} \times d_k)$ 减少到 $O(n \times d</em>{model} \times d_k)$</li><li>减少比例:$O(n)$ 倍</li></ul><p><strong>实际效果</strong>:</p><ul><li>对于长序列生成,计算量减少显著</li><li>特别是在生成较长文本时,优化效果明显</li></ul><h2 id="KV-Cache的实现细节"><a href="#KV-Cache的实现细节" class="headerlink" title="KV Cache的实现细节"></a>KV Cache的实现细节</h2><h3 id="内存管理"><a href="#内存管理" class="headerlink" title="内存管理"></a>内存管理</h3><h4 id="1-缓存结构"><a href="#1-缓存结构" class="headerlink" title="1. 缓存结构"></a>1. 缓存结构</h4><p><strong>缓存格式</strong>:<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 缓存结构示例</span></span><br><span class="line">kv_cache = {</span><br><span class="line"> <span class="string">'key'</span>: torch.zeros(seq_len, num_layers, num_heads, head_dim),</span><br><span class="line"> <span class="string">'value'</span>: torch.zeros(seq_len, num_layers, num_heads, head_dim)</span><br><span class="line">}</span><br></pre></td></tr></table></figure></p><p><strong>内存布局</strong>:</p><ul><li>按层(layer)组织</li><li>每层包含多个注意力头(attention heads)</li><li>支持动态扩展</li></ul><h4 id="2-内存优化策略"><a href="#2-内存优化策略" class="headerlink" title="2. 内存优化策略"></a>2. 内存优化策略</h4><p><strong>预分配策略</strong>:</p><ul><li>根据最大序列长度预分配内存</li><li>避免频繁的内存重新分配</li></ul><p><strong>内存复用</strong>:</p><ul><li>在推理过程中复用缓存空间</li><li>减少内存碎片</li></ul><h3 id="增量更新机制"><a href="#增量更新机制" class="headerlink" title="增量更新机制"></a>增量更新机制</h3><h4 id="1-缓存更新"><a href="#1-缓存更新" class="headerlink" title="1. 缓存更新"></a>1. 缓存更新</h4><p><strong>更新流程</strong>:</p><ol><li>计算新token的Key和Value</li><li>将新的Key和Value追加到缓存</li><li>更新缓存的有效长度</li></ol><p><strong>代码示例</strong>:<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">update_kv_cache</span>(<span class="params">kv_cache, new_k, new_v, layer_idx</span>):</span><br><span class="line"> <span class="comment"># 追加新的Key和Value到缓存</span></span><br><span class="line"> kv_cache[<span class="string">'key'</span>][layer_idx] = torch.cat([kv_cache[<span class="string">'key'</span>][layer_idx], new_k], dim=<span class="number">0</span>)</span><br><span class="line"> kv_cache[<span class="string">'value'</span>][layer_idx] = torch.cat([kv_cache[<span class="string">'value'</span>][layer_idx], new_v], dim=<span class="number">0</span>)</span><br></pre></td></tr></table></figure></p><h4 id="2-注意力计算"><a href="#2-注意力计算" class="headerlink" title="2. 注意力计算"></a>2. 注意力计算</h4><p><strong>使用缓存的Attention计算</strong>:<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">attention_with_cache</span>(<span class="params">query, kv_cache, layer_idx</span>):</span><br><span class="line"> <span class="comment"># 获取缓存的Key和Value</span></span><br><span class="line"> cached_k = kv_cache[<span class="string">'key'</span>][layer_idx]</span><br><span class="line"> cached_v = kv_cache[<span class="string">'value'</span>][layer_idx]</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 计算注意力分数</span></span><br><span class="line"> scores = torch.matmul(query, cached_k.transpose(-<span class="number">2</span>, -<span class="number">1</span>)) / math.sqrt(d_k)</span><br><span class="line"> attention_weights = torch.softmax(scores, dim=-<span class="number">1</span>)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 计算输出</span></span><br><span class="line"> output = torch.matmul(attention_weights, cached_v)</span><br><span class="line"> <span class="keyword">return</span> output</span><br></pre></td></tr></table></figure></p><h3 id="多头注意力处理"><a href="#多头注意力处理" class="headerlink" title="多头注意力处理"></a>多头注意力处理</h3><h4 id="1-多头并行计算"><a href="#1-多头并行计算" class="headerlink" title="1. 多头并行计算"></a>1. 多头并行计算</h4><p><strong>缓存组织</strong>:</p><ul><li>每个注意力头独立缓存Key和Value</li><li>支持并行计算</li></ul><p><strong>计算优化</strong>:<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">multi_head_attention_with_cache</span>(<span class="params">query, kv_cache, layer_idx</span>):</span><br><span class="line"> batch_size, num_heads, seq_len, head_dim = query.shape</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 并行计算所有注意力头</span></span><br><span class="line"> outputs = []</span><br><span class="line"> <span class="keyword">for</span> head_idx <span class="keyword">in</span> <span class="built_in">range</span>(num_heads):</span><br><span class="line"> head_query = query[:, head_idx, :, :]</span><br><span class="line"> head_k = kv_cache[<span class="string">'key'</span>][layer_idx][:, head_idx, :, :]</span><br><span class="line"> head_v = kv_cache[<span class="string">'value'</span>][layer_idx][:, head_idx, :, :]</span><br><span class="line"> </span><br><span class="line"> head_output = attention_with_cache(head_query, head_k, head_v)</span><br><span class="line"> outputs.append(head_output)</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> torch.cat(outputs, dim=<span class="number">1</span>)</span><br></pre></td></tr></table></figure></p><h4 id="2-内存布局优化"><a href="#2-内存布局优化" class="headerlink" title="2. 内存布局优化"></a>2. 内存布局优化</h4><p><strong>连续内存布局</strong>:</p><ul><li>将多头数据存储在连续内存中</li><li>提高缓存命中率</li></ul><p><strong>批处理优化</strong>:</p><ul><li>支持批量处理多个序列</li><li>减少内存访问开销</li></ul>]]></content>
<summary type="html">
<p>继续梳理 LLM 知识,这次写 KV Cache。KV Cache 是大语言模型推理过程中的重要优化技术,能够显著减少计算量,提高推理速度。本文将从 Attention 计算原理出发,详细推导 KV Cache 的数学等价性,并分析其优化效果。</p>
</summary>
<category term="LLM" scheme="https://murphypei.github.io/categories/LLM/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="KV Cache" scheme="https://murphypei.github.io/tags/KV-Cache/"/>
<category term="Attention" scheme="https://murphypei.github.io/tags/Attention/"/>
<category term="推理优化" scheme="https://murphypei.github.io/tags/%E6%8E%A8%E7%90%86%E4%BC%98%E5%8C%96/"/>
</entry>
<entry>
<title>LLM:RAG 中的文本检索技术</title>
<link href="https://murphypei.github.io/blog/2025/06/text-retriveval.html"/>
<id>https://murphypei.github.io/blog/2025/06/text-retriveval.html</id>
<published>2025-06-29T19:19:02.000Z</published>
<updated>2026-01-07T08:41:01.086Z</updated>
<content type="html"><![CDATA[<p>继续准备 LLM 面试知识,这次写文本检索技术。文本检索是 RAG(检索增强生成)系统的核心组件,也是面试中经常被问到的问题。本文将详细介绍稠密向量检索、稀疏向量检索、BM25算法以及混合检索策略,帮助理解现代文本检索系统的技术原理。</p><span id="more"></span><h2 id="引言"><a href="#引言" class="headerlink" title="引言"></a>引言</h2><p>在当今的信息检索领域,随着人工智能和自然语言处理技术的发展,文本检索技术已经从传统的基于关键词匹配的方法,发展到了基于深度学习的语义检索方法。文本检索是RAG(Retrieval-Augmented Generation)系统的核心组件,它负责从大规模文档集合中检索出与用户查询最相关的文档片段,为后续的生成模型提供上下文信息。</p><p>本文将详细介绍三种主要的文本检索方法:<strong>稠密向量检索(Dense Retrieval)</strong>、<strong>稀疏向量检索(Sparse Retrieval)</strong>和<strong>BM25算法</strong>,以及它们的混合使用策略。</p><h2 id="稠密向量检索(Dense-Retrieval)"><a href="#稠密向量检索(Dense-Retrieval)" class="headerlink" title="稠密向量检索(Dense Retrieval)"></a>稠密向量检索(Dense Retrieval)</h2><h3 id="稠密向量检索的基本原理"><a href="#稠密向量检索的基本原理" class="headerlink" title="稠密向量检索的基本原理"></a>稠密向量检索的基本原理</h3><p>稠密向量检索是一种基于深度学习的检索方法,它通过将文本转换为高维空间中的连续向量表示,然后使用向量相似度(如余弦相似度、欧氏距离)来检索相关文档。</p><p><strong>核心思想</strong>:</p><ul><li>将查询和文档都映射到同一个高维向量空间</li><li>通过计算向量间的相似度来衡量相关性</li><li>能够捕捉文本的深层语义信息</li></ul><h3 id="技术实现"><a href="#技术实现" class="headerlink" title="技术实现"></a>技术实现</h3><h4 id="1-文本编码"><a href="#1-文本编码" class="headerlink" title="1. 文本编码"></a>1. 文本编码</h4><p>稠密向量检索通常使用预训练的语言模型(如BERT、T5、Sentence-BERT等)对文本进行编码:</p><script type="math/tex; mode=display">\mathbf{q} = \text{Encoder}(query) \\\mathbf{d} = \text{Encoder}(document)</script><p>其中:</p><ul><li>$\mathbf{q}$ 是查询的向量表示</li><li>$\mathbf{d}$ 是文档的向量表示</li><li>$\text{Encoder}$ 是预训练的语言模型</li></ul><h4 id="2-相似度计算"><a href="#2-相似度计算" class="headerlink" title="2. 相似度计算"></a>2. 相似度计算</h4><p>常用的相似度计算方法包括:</p><p><strong>余弦相似度</strong>:</p><script type="math/tex; mode=display">\text{sim}(\mathbf{q}, \mathbf{d}) = \frac{\mathbf{q} \cdot \mathbf{d}}{|\mathbf{q}| \cdot |\mathbf{d}|}</script><p><strong>点积相似度</strong>:</p><script type="math/tex; mode=display">\text{sim}(\mathbf{q}, \mathbf{d}) = \mathbf{q} \cdot \mathbf{d}</script><p><strong>欧氏距离</strong>:</p><script type="math/tex; mode=display">\text{dist}(\mathbf{q}, \mathbf{d}) = \sqrt{\sum_{i=1}^{n} (q_i - d_i)^2}</script><h4 id="3-索引和检索"><a href="#3-索引和检索" class="headerlink" title="3. 索引和检索"></a>3. 索引和检索</h4><p><strong>向量索引</strong>:</p><ul><li>使用FAISS、Annoy、HNSW等向量索引库</li><li>支持高效的近似最近邻搜索</li><li>能够处理百万级别的向量检索</li></ul><p><strong>检索流程</strong>:</p><ol><li>将查询编码为向量</li><li>在向量索引中搜索最相似的文档向量</li><li>返回相似度最高的文档</li></ol><h3 id="稠密向量检索的优势和局限性"><a href="#稠密向量检索的优势和局限性" class="headerlink" title="稠密向量检索的优势和局限性"></a>稠密向量检索的优势和局限性</h3><p><strong>优势</strong>:</p><ol><li><strong>语义理解能力强</strong>:能够理解查询和文档的深层语义</li><li><strong>处理同义词和近义词</strong>:即使词汇不完全匹配,也能找到相关文档</li><li><strong>支持复杂查询</strong>:能够处理自然语言形式的查询</li></ol><p><strong>局限性</strong>:</p><ol><li><strong>计算成本高</strong>:需要深度学习模型进行编码</li><li><strong>索引规模限制</strong>:在大规模数据集上可能遇到性能瓶颈</li><li><strong>对训练数据敏感</strong>:检索效果依赖于编码模型的训练质量</li></ol><h3 id="应用场景"><a href="#应用场景" class="headerlink" title="应用场景"></a>应用场景</h3><ul><li><strong>智能问答系统</strong>:从知识库中检索相关答案</li><li><strong>推荐系统</strong>:基于内容相似性推荐相关文档</li><li><strong>语义搜索</strong>:理解用户意图的搜索引擎</li></ul><h2 id="稀疏向量检索(Sparse-Retrieval)"><a href="#稀疏向量检索(Sparse-Retrieval)" class="headerlink" title="稀疏向量检索(Sparse Retrieval)"></a>稀疏向量检索(Sparse Retrieval)</h2><h3 id="稀疏向量检索的基本原理"><a href="#稀疏向量检索的基本原理" class="headerlink" title="稀疏向量检索的基本原理"></a>稀疏向量检索的基本原理</h3><p>稀疏向量检索是基于传统信息检索模型的方法,它使用词袋模型(Bag of Words)将文本表示为稀疏向量,并通过计算词频-逆文档频率(TF-IDF)来评估文档与查询的相关性。</p><p><strong>核心思想</strong>:</p><ul><li>将文本表示为高维稀疏向量</li><li>每个维度对应词汇表中的一个词</li><li>通过统计方法计算词的重要性</li></ul><h3 id="技术实现-1"><a href="#技术实现-1" class="headerlink" title="技术实现"></a>技术实现</h3><h4 id="1-TF-IDF计算"><a href="#1-TF-IDF计算" class="headerlink" title="1. TF-IDF计算"></a>1. TF-IDF计算</h4><p><strong>词频(Term Frequency, TF)</strong>:</p><script type="math/tex; mode=display">\text{TF}(t, d) = \frac{\text{词 } t \text{ 在文档 } d \text{ 中的出现次数}}{\text{文档 } d \text{ 的总词数}}</script><p><strong>逆文档频率(Inverse Document Frequency, IDF)</strong>:</p><script type="math/tex; mode=display">\text{IDF}(t) = \log \frac{\text{总文档数}}{\text{包含词 } t \text{ 的文档数}}</script><p><strong>TF-IDF权重</strong>:</p><script type="math/tex; mode=display">\text{TF-IDF}(t, d) = \text{TF}(t, d) \times \text{IDF}(t)</script><h4 id="2-向量构建"><a href="#2-向量构建" class="headerlink" title="2. 向量构建"></a>2. 向量构建</h4><p>文档向量 $\mathbf{d}$ 的构建:</p><script type="math/tex; mode=display">\mathbf{d} = [\text{TF-IDF}(t_1, d), \text{TF-IDF}(t_2, d), ..., \text{TF-IDF}(t_n, d)]</script><p>其中 $t_1, t_2, …, t_n$ 是词汇表中的所有词。</p><h4 id="3-相似度计算"><a href="#3-相似度计算" class="headerlink" title="3. 相似度计算"></a>3. 相似度计算</h4><p><strong>余弦相似度</strong>:</p><script type="math/tex; mode=display">\text{sim}(\mathbf{q}, \mathbf{d}) = \frac{\mathbf{q} \cdot \mathbf{d}}{|\mathbf{q}| \cdot |\mathbf{d}|}</script><p><strong>点积相似度</strong>:</p><script type="math/tex; mode=display">\text{sim}(\mathbf{q}, \mathbf{d}) = \mathbf{q} \cdot \mathbf{d}</script><h3 id="稀疏向量检索的优势和局限性"><a href="#稀疏向量检索的优势和局限性" class="headerlink" title="稀疏向量检索的优势和局限性"></a>稀疏向量检索的优势和局限性</h3><p><strong>优势</strong>:</p><ol><li><strong>计算效率高</strong>:基于统计方法,计算速度快</li><li><strong>可解释性强</strong>:能够明确知道哪些词贡献了相关性</li><li><strong>处理大规模数据</strong>:能够高效处理大规模文档集合</li></ol><p><strong>局限性</strong>:</p><ol><li><strong>语义理解能力弱</strong>:无法处理同义词和近义词</li><li><strong>词汇匹配限制</strong>:需要查询词在文档中实际出现</li><li><strong>无法处理语义相似性</strong>:无法理解词汇的深层含义</li></ol><h3 id="应用场景-1"><a href="#应用场景-1" class="headerlink" title="应用场景"></a>应用场景</h3><ul><li><strong>传统搜索引擎</strong>:基于关键词的网页搜索</li><li><strong>文档检索系统</strong>:从文档库中检索相关文档</li><li><strong>信息过滤</strong>:基于关键词的信息过滤</li></ul><h2 id="BM25算法"><a href="#BM25算法" class="headerlink" title="BM25算法"></a>BM25算法</h2><h3 id="BM25算法的基本原理"><a href="#BM25算法的基本原理" class="headerlink" title="BM25算法的基本原理"></a>BM25算法的基本原理</h3><p>BM25(Best Matching 25)是一种经典的信息检索算法,它是TF-IDF算法的改进版,通过引入词频(TF)和文档频率(DF)的函数来计算文档与查询的相关性得分。</p><p><strong>核心思想</strong>:</p><ul><li>在TF-IDF基础上引入文档长度归一化</li><li>使用词频饱和函数处理高频词</li><li>通过参数调整优化检索效果</li></ul><h3 id="数学公式"><a href="#数学公式" class="headerlink" title="数学公式"></a>数学公式</h3><p>BM25算法的核心公式:</p><script type="math/tex; mode=display">\text{BM25}(Q, D) = \sum_{i=1}^{n} \text{IDF}(q_i) \cdot \frac{f(q_i, D) \cdot (k_1 + 1)}{f(q_i, D) + k_1 \cdot (1 - b + b \cdot \frac{|D|}{\text{avgdl}})}</script><p>其中:</p><ul><li>$Q$ 是查询,包含词 $q_1, q_2, …, q_n$</li><li>$D$ 是文档</li><li>$f(q_i, D)$ 是词 $q_i$ 在文档 $D$ 中的词频</li><li>$|D|$ 是文档 $D$ 的长度</li><li>$\text{avgdl}$ 是文档集合的平均长度</li><li>$k_1$ 和 $b$ 是调节参数</li></ul><p><strong>IDF计算</strong>:</p><script type="math/tex; mode=display">\text{IDF}(q_i) = \log \frac{N - n(q_i) + 0.5}{n(q_i) + 0.5}</script><p>其中:</p><ul><li>$N$ 是文档集合的总文档数</li><li>$n(q_i)$ 是包含词 $q_i$ 的文档数</li></ul><h3 id="参数调节"><a href="#参数调节" class="headerlink" title="参数调节"></a>参数调节</h3><p><strong>$k_1$ 参数</strong>:</p><ul><li>控制词频饱和程度</li><li>通常设置为 1.2-2.0</li><li>值越大,词频的影响越线性</li></ul><p><strong>$b$ 参数</strong>:</p><ul><li>控制文档长度归一化程度</li><li>取值范围为 0-1</li><li>$b=0$ 表示不进行长度归一化</li><li>$b=1$ 表示完全归一化</li></ul><h3 id="BM25算法的优势"><a href="#BM25算法的优势" class="headerlink" title="BM25算法的优势"></a>BM25算法的优势</h3><ol><li><strong>理论基础扎实</strong>:基于概率检索模型</li><li><strong>参数可调节</strong>:能够适应不同的数据集和需求</li><li><strong>计算效率高</strong>:基于统计方法,计算速度快</li><li><strong>效果稳定</strong>:在许多基准测试中表现优异</li></ol><h3 id="应用场景-2"><a href="#应用场景-2" class="headerlink" title="应用场景"></a>应用场景</h3><ul><li><strong>搜索引擎</strong>:Google、Bing等搜索引擎的核心算法</li><li><strong>文档检索</strong>:企业文档管理系统</li><li><strong>学术搜索</strong>:学术论文检索系统</li></ul><h2 id="混合检索策略"><a href="#混合检索策略" class="headerlink" title="混合检索策略"></a>混合检索策略</h2><h3 id="混合检索的基本原理"><a href="#混合检索的基本原理" class="headerlink" title="混合检索的基本原理"></a>混合检索的基本原理</h3><p>混合检索结合了稠密向量检索、稀疏向量检索和BM25算法的优势,通过多路召回和结果融合来提高检索系统的整体性能。</p><p><strong>核心思想</strong>:</p><ul><li>使用多种检索方法并行检索</li><li>通过融合算法合并检索结果</li><li>平衡准确性和召回率</li></ul><h3 id="混合检索的优势"><a href="#混合检索的优势" class="headerlink" title="混合检索的优势"></a>混合检索的优势</h3><ol><li><strong>互补性</strong>:不同方法各有优势,相互补充</li><li><strong>提高准确性</strong>:通过多路召回提高检索准确性</li><li><strong>提升召回率</strong>:增加检索结果的覆盖面</li><li><strong>适应性</strong>:能够适应不同的查询类型和场景</li></ol><h3 id="实现方法"><a href="#实现方法" class="headerlink" title="实现方法"></a>实现方法</h3><h4 id="1-多路召回"><a href="#1-多路召回" class="headerlink" title="1. 多路召回"></a>1. 多路召回</h4><p><strong>稠密向量检索</strong>:</p><ul><li>使用语义相似性进行检索</li><li>适合处理语义相关的查询</li></ul><p><strong>稀疏向量检索</strong>:</p><ul><li>使用关键词匹配进行检索</li><li>适合处理精确匹配的查询</li></ul><p><strong>BM25检索</strong>:</p><ul><li>使用传统信息检索方法</li><li>适合处理结构化查询</li></ul><h4 id="2-结果融合"><a href="#2-结果融合" class="headerlink" title="2. 结果融合"></a>2. 结果融合</h4><p><strong>RRF(Reciprocal Rank Fusion)</strong>:</p><script type="math/tex; mode=display">\text{RRF}(d) = \sum_{i=1}^{n} \frac{1}{k + \text{rank}_i(d)}</script><p>其中:</p><ul><li>$\text{rank}_i(d)$ 是文档 $d$ 在第 $i$ 个检索方法中的排名</li><li>$k$ 是调节参数,通常设置为 60</li></ul><p><strong>加权融合</strong>:</p><script type="math/tex; mode=display">\text{Score}(d) = \sum_{i=1}^{n} w_i \cdot \text{score}_i(d)</script><p>其中:</p><ul><li>$w_i$ 是第 $i$ 个检索方法的权重</li><li>$\text{score}_i(d)$ 是文档 $d$ 在第 $i$ 个检索方法中的得分</li></ul><h4 id="3-动态权重调整"><a href="#3-动态权重调整" class="headerlink" title="3. 动态权重调整"></a>3. 动态权重调整</h4><p>根据查询类型动态调整不同检索方法的权重:</p><ul><li><strong>语义查询</strong>:增加稠密向量检索的权重</li><li><strong>关键词查询</strong>:增加稀疏向量检索和BM25的权重</li><li><strong>混合查询</strong>:平衡各种方法的权重</li></ul><h3 id="系统架构"><a href="#系统架构" class="headerlink" title="系统架构"></a>系统架构</h3><p><strong>查询处理层</strong>:</p><ul><li>查询解析和预处理</li><li>查询类型识别</li><li>参数选择</li></ul><p><strong>检索层</strong>:</p><ul><li>多路并行检索</li><li>结果初步排序</li><li>去重和合并</li></ul><p><strong>融合层</strong>:</p><ul><li>结果融合算法</li><li>最终排序</li><li>结果返回</li></ul><h2 id="性能评估指标"><a href="#性能评估指标" class="headerlink" title="性能评估指标"></a>性能评估指标</h2><h3 id="检索性能指标"><a href="#检索性能指标" class="headerlink" title="检索性能指标"></a>检索性能指标</h3><p><strong>准确率(Precision)</strong>:</p><script type="math/tex; mode=display">\text{Precision} = \frac{\text{相关文档数}}{\text{检索文档数}}</script><p><strong>召回率(Recall)</strong>:</p><script type="math/tex; mode=display">\text{Recall} = \frac{\text{相关文档数}}{\text{总相关文档数}}</script><p><strong>F1分数</strong>:</p><script type="math/tex; mode=display">\text{F1} = \frac{2 \times \text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}}</script><p><strong>平均精度(MAP)</strong>:</p><script type="math/tex; mode=display">\text{MAP} = \frac{1}{|Q|} \sum_{q \in Q} \text{AP}(q)</script><p>其中 $\text{AP}(q)$ 是查询 $q$ 的平均精度。</p><h3 id="效率指标"><a href="#效率指标" class="headerlink" title="效率指标"></a>效率指标</h3><p><strong>检索延迟</strong>:从查询到返回结果的时间<br><strong>吞吐量</strong>:单位时间内处理的查询数<br><strong>索引大小</strong>:索引占用的存储空间</p><h2 id="实际应用中的优化策略"><a href="#实际应用中的优化策略" class="headerlink" title="实际应用中的优化策略"></a>实际应用中的优化策略</h2><h3 id="索引优化"><a href="#索引优化" class="headerlink" title="索引优化"></a>索引优化</h3><p><strong>倒排索引</strong>:</p><ul><li>为每个词建立文档列表</li><li>支持快速的关键词查找</li><li>优化存储和查询效率</li></ul><p><strong>向量索引</strong>:</p><ul><li>使用HNSW、IVF等算法</li><li>支持高效的近似最近邻搜索</li><li>平衡精度和速度</li></ul><h3 id="查询优化"><a href="#查询优化" class="headerlink" title="查询优化"></a>查询优化</h3><p><strong>查询扩展</strong>:</p><ul><li>使用同义词扩展查询</li><li>基于用户反馈调整查询</li><li>利用查询日志优化</li></ul><p><strong>查询重写</strong>:</p><ul><li>将自然语言查询转换为结构化查询</li><li>使用查询模板提高效率</li><li>基于历史查询进行优化</li></ul><h3 id="缓存策略"><a href="#缓存策略" class="headerlink" title="缓存策略"></a>缓存策略</h3><p><strong>结果缓存</strong>:</p><ul><li>缓存热门查询的结果</li><li>使用LRU等策略管理缓存</li><li>提高响应速度</li></ul><p><strong>索引缓存</strong>:</p><ul><li>将常用索引加载到内存</li><li>使用分层缓存策略</li><li>优化内存使用</li></ul>]]></content>
<summary type="html">
<p>继续准备 LLM 面试知识,这次写文本检索技术。文本检索是 RAG(检索增强生成)系统的核心组件,也是面试中经常被问到的问题。本文将详细介绍稠密向量检索、稀疏向量检索、BM25算法以及混合检索策略,帮助理解现代文本检索系统的技术原理。</p>
</summary>
<category term="LLM" scheme="https://murphypei.github.io/categories/LLM/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="RAG" scheme="https://murphypei.github.io/tags/RAG/"/>
<category term="文本检索" scheme="https://murphypei.github.io/tags/%E6%96%87%E6%9C%AC%E6%A3%80%E7%B4%A2/"/>
<category term="向量检索" scheme="https://murphypei.github.io/tags/%E5%90%91%E9%87%8F%E6%A3%80%E7%B4%A2/"/>
<category term="BM25" scheme="https://murphypei.github.io/tags/BM25/"/>
</entry>
<entry>
<title>LLM 幻觉与重复问题</title>
<link href="https://murphypei.github.io/blog/2025/06/llm-hallucination-repetition.html"/>
<id>https://murphypei.github.io/blog/2025/06/llm-hallucination-repetition.html</id>
<published>2025-06-27T00:31:30.000Z</published>
<updated>2026-01-07T08:41:01.086Z</updated>
<content type="html"><![CDATA[<p>LLM 的幻觉和重复问题是 LLM 应用中的核心挑战,也是面试中经常被问到的问题。本文将从底层机理出发,深入分析这两个问题的成因,并探讨有效的解决方案。</p><span id="more"></span><h2 id="引言"><a href="#引言" class="headerlink" title="引言"></a>引言</h2><p>大语言模型(LLM)在近年来取得了巨大的成功,但同时也面临着两个关键问题:<strong>幻觉(Hallucination)</strong>和<strong>重复(Repetition)</strong>。这些问题不仅影响了模型的实用性,也阻碍了其在关键领域的应用。本文将从底层机理出发,深入分析这两个问题的成因,并探讨有效的解决方案。</p><h2 id="幻觉问题(Hallucination)"><a href="#幻觉问题(Hallucination)" class="headerlink" title="幻觉问题(Hallucination)"></a>幻觉问题(Hallucination)</h2><h3 id="幻觉的基本概念"><a href="#幻觉的基本概念" class="headerlink" title="幻觉的基本概念"></a>幻觉的基本概念</h3><p>幻觉是指LLM生成的内容与事实不符,包括:</p><ul><li><strong>事实性幻觉</strong>:生成错误的事实信息。</li><li><strong>逻辑性幻觉</strong>:推理过程存在逻辑错误,前后矛盾不一致。</li><li><strong>指令性幻觉</strong>:指令遵循不一致,比如要求用英语回答,用中文回答。</li></ul><h3 id="幻觉产生的底层机理"><a href="#幻觉产生的底层机理" class="headerlink" title="幻觉产生的底层机理"></a>幻觉产生的底层机理</h3><h4 id="1-训练数据质量问题"><a href="#1-训练数据质量问题" class="headerlink" title="1. 训练数据质量问题"></a>1. 训练数据质量问题</h4><p><strong>数据噪声与错误</strong></p><ul><li>训练数据中本身就包含错误信息</li><li>网络爬取的数据质量参差不齐</li><li>标注错误导致模型学习到错误的知识</li></ul><p><strong>数据分布偏差</strong></p><ul><li>某些领域的数据过于稀少</li><li>时间性信息过时(如2023年之前的数据)</li><li>地域性偏见导致知识覆盖不均</li></ul><h4 id="2-Attention机制的局限性"><a href="#2-Attention机制的局限性" class="headerlink" title="2. Attention机制的局限性"></a>2. Attention机制的局限性</h4><p>根据<a href="https://arxiv.org/html/2504.04600v1">Attention理论</a>,Attention机制本质上是一个2体相互作用系统:</p><script type="math/tex; mode=display">\mathcal{P}(\mathbf{x}) = \mathbf{N}^{(0)}\mathsf{W}_V\mathbf{x}^{\mathrm{T}}</script><p>其中:</p><ul><li>$\mathbf{N}^{(0)}$ 是上下文向量</li><li>$\mathsf{W}_V$ 是Value投影矩阵</li><li>$\mathbf{x}$ 是词汇表中的token</li></ul><p><strong>Attention机制的固有问题</strong>:</p><ol><li><strong>局部性限制</strong>:Attention主要关注局部相关性,难以捕捉全局一致性</li><li><strong>缺乏事实验证</strong>:模型无法验证生成内容的真实性</li><li><strong>过度依赖训练数据</strong>:当遇到训练数据中未覆盖的情况时,容易产生幻觉</li></ol><h4 id="3-训练目标与事实性不匹配"><a href="#3-训练目标与事实性不匹配" class="headerlink" title="3. 训练目标与事实性不匹配"></a>3. 训练目标与事实性不匹配</h4><p><strong>最大似然估计的局限性</strong></p><ul><li>训练目标是最小化预测下一个token的损失</li><li>这个目标并不直接优化事实准确性</li><li>模型可能为了流畅性而牺牲准确性</li></ul><p><strong>缺乏事实性监督</strong></p><ul><li>训练过程中没有明确的事实性约束</li><li>模型无法区分事实性内容和创造性内容</li></ul><h3 id="缓解和消除幻觉的方法"><a href="#缓解和消除幻觉的方法" class="headerlink" title="缓解和消除幻觉的方法"></a>缓解和消除幻觉的方法</h3><h4 id="1-训练数据层面改进"><a href="#1-训练数据层面改进" class="headerlink" title="1. 训练数据层面改进"></a>1. 训练数据层面改进</h4><p><strong>高质量数据收集</strong></p><ul><li>结合多个高质量数据源</li><li>使用事实性强的数据(如维基百科、学术论文)</li><li>建立数据质量评估体系</li></ul><p><strong>数据清洗与验证</strong></p><ul><li>自动检测和移除错误数据</li><li>使用外部知识库验证数据准确性</li><li>建立数据版本控制机制</li></ul><p><strong>知识注入技术</strong></p><ul><li>将结构化知识(如知识图谱)注入训练数据</li><li>使用检索增强生成(RAG)技术</li><li>结合外部知识库进行训练</li></ul><h4 id="2-模型推理的改进"><a href="#2-模型推理的改进" class="headerlink" title="2. 模型推理的改进"></a>2. 模型推理的改进</h4><p><strong>改进的Attention机制</strong></p><ul><li>引入多步推理机制</li><li>使用思维链(Chain-of-Thought)提示</li><li>实现推理过程的显式建模</li></ul><p><strong>事实性约束</strong></p><ul><li>在Attention中加入事实性约束</li><li>使用外部知识库指导注意力分配</li><li>实现事实性验证的端到端训练</li></ul><p><strong>检索增强生成(RAG)</strong></p><ul><li>在生成过程中实时检索相关信息</li><li>使用向量数据库存储知识</li><li>实现检索与生成的联合优化</li></ul><h4 id="3-训练策略的改进"><a href="#3-训练策略的改进" class="headerlink" title="3. 训练策略的改进"></a>3. 训练策略的改进</h4><p><strong>事实性监督</strong></p><ul><li>设计专门的事实性损失函数</li><li>使用外部知识库计算事实性得分</li><li>在训练中平衡流畅性和事实性</li></ul><p><strong>对比学习</strong></p><ul><li>使用对比学习区分事实性和非事实性内容</li><li>训练模型识别和避免幻觉</li></ul><p><strong>强化学习优化</strong></p><ul><li>设计基于事实准确性的奖励函数</li><li>使用PPO等算法优化事实性</li><li>实现事实性与流畅性的平衡</li></ul><h4 id="4-推理阶段的改进"><a href="#4-推理阶段的改进" class="headerlink" title="4. 推理阶段的改进"></a>4. 推理阶段的改进</h4><p><strong>后处理验证</strong></p><ul><li>使用外部工具验证生成内容的真实性</li><li>实现自动的事实性评分</li><li>对低置信度的内容进行标记</li></ul><p><strong>多模型验证</strong></p><ul><li>使用多个模型交叉验证</li><li>实现模型集成提高准确性</li></ul><p><strong>不确定性量化</strong></p><ul><li>为生成内容提供置信度分数</li><li>实现不确定性量化</li><li>帮助用户判断内容的可靠性</li></ul><h2 id="重复问题(Repetition)"><a href="#重复问题(Repetition)" class="headerlink" title="重复问题(Repetition)"></a>重复问题(Repetition)</h2><h3 id="重复问题的基本概念"><a href="#重复问题的基本概念" class="headerlink" title="重复问题的基本概念"></a>重复问题的基本概念</h3><p>重复问题表现为:</p><ul><li><strong>词汇重复</strong>:同一个词或短语反复出现</li><li><strong>结构重复</strong>:相似的句子结构重复使用</li><li><strong>内容重复</strong>:相同的信息多次表达</li></ul><h3 id="重复产生的底层机理"><a href="#重复产生的底层机理" class="headerlink" title="重复产生的底层机理"></a>重复产生的底层机理</h3><h4 id="1-训练数据的重复模式"><a href="#1-训练数据的重复模式" class="headerlink" title="1. 训练数据的重复模式"></a>1. 训练数据的重复模式</h4><p><strong>数据中的重复模式</strong></p><ul><li>训练数据中存在大量重复内容</li><li>某些表达方式在数据中频繁出现</li><li>模型学习到了这些重复模式</li></ul><p><strong>注意力机制的偏好</strong></p><ul><li>模型倾向于关注高频出现的模式</li><li>重复内容往往具有较高的注意力权重</li></ul><h4 id="2-生成策略的影响"><a href="#2-生成策略的影响" class="headerlink" title="2. 生成策略的影响"></a>2. 生成策略的影响</h4><p><strong>贪婪解码的局限性</strong></p><ul><li>每次都选择概率最高的token</li><li>容易陷入局部最优,导致重复</li></ul><p><strong>缺乏多样性约束</strong></p><ul><li>没有明确的多样性目标</li><li>模型倾向于选择”安全”的重复模式</li></ul><h4 id="3-上下文窗口的限制"><a href="#3-上下文窗口的限制" class="headerlink" title="3. 上下文窗口的限制"></a>3. 上下文窗口的限制</h4><p><strong>长距离依赖问题</strong></p><ul><li>模型难以记住之前生成的内容</li><li>在生成长文本时容易重复</li></ul><p><strong>注意力衰减</strong></p><ul><li>随着序列长度增加,注意力权重衰减</li><li>导致模型”忘记”之前的内容</li></ul><h3 id="缓解和消除重复的方法"><a href="#缓解和消除重复的方法" class="headerlink" title="缓解和消除重复的方法"></a>缓解和消除重复的方法</h3><h4 id="1-解码策略的改进"><a href="#1-解码策略的改进" class="headerlink" title="1. 解码策略的改进"></a>1. 解码策略的改进</h4><p><strong>多样性解码</strong></p><p><strong>核采样(Nucleus Sampling)</strong></p><ul><li>只从累积概率达到阈值的token中采样</li><li>避免选择过于保守的token</li><li>在保持质量的同时增加多样性</li></ul><p><strong>温度调节</strong></p><ul><li>使用温度参数控制采样的随机性</li><li>在生成过程中动态调整温度</li><li>平衡创造性和一致性</li></ul><p><strong>重复惩罚</strong></p><ul><li>对重复出现的token进行惩罚</li><li>使用n-gram级别的重复检测</li><li>实现自适应的重复惩罚机制</li></ul><p><strong>长度惩罚</strong></p><ul><li>对过长的重复序列进行惩罚</li><li>鼓励模型生成更简洁的内容</li></ul><h4 id="2-模型架构的改进"><a href="#2-模型架构的改进" class="headerlink" title="2. 模型架构的改进"></a>2. 模型架构的改进</h4><p><strong>改进的注意力机制</strong></p><p><strong>相对位置编码</strong></p><ul><li>使用相对位置编码代替绝对位置编码</li><li>更好地处理长序列</li><li>减少位置相关的重复</li></ul><p><strong>稀疏注意力</strong></p><ul><li>使用稀疏注意力减少计算复杂度</li><li>提高长文本的处理能力</li><li>减少注意力衰减问题</li></ul><p><strong>记忆机制</strong></p><ul><li>使用外部记忆存储重要信息</li><li>实现长期依赖的建模</li><li>减少重复生成相同内容</li></ul><p><strong>分层记忆</strong></p><ul><li>实现短期和长期记忆的分离</li><li>使用不同的记忆机制处理不同时间尺度的信息</li></ul><h4 id="3-训练策略的改进-1"><a href="#3-训练策略的改进-1" class="headerlink" title="3. 训练策略的改进"></a>3. 训练策略的改进</h4><p><strong>多样性训练</strong></p><p><strong>多样性损失</strong></p><ul><li>在训练中加入多样性损失</li><li>鼓励模型生成多样化的内容</li><li>平衡一致性和创造性</li></ul><p><strong>对抗训练</strong></p><ul><li>使用对抗训练提高多样性</li><li>训练判别器识别重复内容</li><li>实现生成器和判别器的博弈</li></ul><p><strong>课程学习</strong></p><ul><li>从简单任务开始,逐步增加复杂度</li><li>在训练过程中引入多样性约束</li><li>实现更好的泛化能力</li></ul><h4 id="4-推理阶段的改进-1"><a href="#4-推理阶段的改进-1" class="headerlink" title="4. 推理阶段的改进"></a>4. 推理阶段的改进</h4><p><strong>动态调整</strong></p><p><strong>自适应解码</strong></p><ul><li>根据上下文动态调整解码策略</li><li>实现智能的重复检测和避免</li><li>使用机器学习优化解码参数</li></ul><p><strong>多候选生成</strong></p><ul><li>生成多个候选序列</li><li>使用多样性指标选择最佳序列</li><li>实现更好的内容质量</li></ul><p><strong>后处理优化</strong></p><ul><li>使用规则或机器学习方法检测重复</li><li>自动移除或改写重复内容</li><li>实现智能的内容优化</li></ul><p><strong>风格一致性</strong></p><ul><li>保持生成内容的风格一致性</li><li>避免风格上的重复</li><li>实现更自然的文本生成</li></ul><h2 id="幻觉与重复问题的关系"><a href="#幻觉与重复问题的关系" class="headerlink" title="幻觉与重复问题的关系"></a>幻觉与重复问题的关系</h2><h3 id="共同根源"><a href="#共同根源" class="headerlink" title="共同根源"></a>共同根源</h3><p><strong>训练数据问题</strong></p><ul><li>数据质量差是幻觉和重复的共同原因</li><li>数据分布不均匀导致模型学习到错误的模式</li></ul><p><strong>Attention机制的局限性</strong></p><ul><li>2体相互作用的限制</li><li>难以处理复杂的全局关系</li></ul><p><strong>训练目标的不完善</strong></p><ul><li>缺乏对事实性和多样性的直接优化</li><li>过度依赖局部最优</li></ul><h3 id="相互影响"><a href="#相互影响" class="headerlink" title="相互影响"></a>相互影响</h3><p><strong>幻觉导致重复</strong></p><ul><li>当模型不确定时,倾向于重复”安全”的内容</li><li>幻觉内容可能被模型认为是正确的,从而重复生成</li></ul><p><strong>重复加剧幻觉</strong></p><ul><li>重复生成错误内容会强化幻觉</li><li>缺乏多样性限制了模型的探索能力</li></ul><h3 id="联合解决方案"><a href="#联合解决方案" class="headerlink" title="联合解决方案"></a>联合解决方案</h3><p><strong>统一的数据策略</strong></p><ul><li>同时提高数据的准确性和多样性</li><li>建立综合的数据质量评估体系</li></ul><p><strong>改进的模型架构</strong></p><ul><li>设计同时解决幻觉和重复的架构</li><li>引入全局一致性和多样性约束</li></ul><p><strong>综合的训练目标</strong></p><ul><li>平衡事实性、流畅性和多样性</li><li>使用多目标优化方法</li></ul><h2 id="未来发展方向"><a href="#未来发展方向" class="headerlink" title="未来发展方向"></a>未来发展方向</h2><h3 id="理论突破"><a href="#理论突破" class="headerlink" title="理论突破"></a>理论突破</h3><p><strong>3体Attention机制</strong><br>根据物理学理论,当前的Attention是2体相互作用,未来可能发展出3体Attention机制,能够更好地处理复杂的关系和依赖。</p><p><strong>量子计算的应用</strong><br>量子计算可能为Attention机制提供新的计算范式,实现更高效的注意力计算。</p><h3 id="技术融合"><a href="#技术融合" class="headerlink" title="技术融合"></a>技术融合</h3><p><strong>多模态融合</strong><br>结合视觉、听觉等多种模态信息,提高模型的理解能力和生成质量。</p><p><strong>知识图谱集成</strong><br>深度集成知识图谱,实现更准确的事实性生成。</p><h3 id="评估体系"><a href="#评估体系" class="headerlink" title="评估体系"></a>评估体系</h3><p><strong>标准化评估</strong><br>建立标准化的幻觉和重复评估体系,为模型改进提供客观指标。</p><p><strong>实时监控</strong><br>实现生成过程的实时监控,及时发现和纠正问题。</p><h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>LLM的幻觉和重复问题是当前AI发展面临的重要挑战。通过深入理解其底层机理,我们可以从数据、模型架构、训练策略和推理优化等多个层面来缓解这些问题。随着技术的不断进步,我们有理由相信这些问题将得到更好的解决,推动LLM技术向更高水平发展。</p><p><strong>关键要点</strong>:</p><ol><li><strong>幻觉问题</strong>:主要由训练数据质量、Attention机制局限性和训练目标不匹配导致</li><li><strong>重复问题</strong>:主要由训练数据重复模式、生成策略局限性和上下文窗口限制导致</li><li><strong>解决方案</strong>:需要从数据、架构、训练和推理多个层面综合改进</li><li><strong>未来方向</strong>:3体Attention、多模态融合、标准化评估体系</li></ol><h2 id="参考文献"><a href="#参考文献" class="headerlink" title="参考文献"></a>参考文献</h2><ol><li><a href="https://arxiv.org/html/2504.04600v1">Capturing AI’s Attention: Physics of Repetition, Hallucination, Bias and Beyond</a></li><li><a href="https://zhuanlan.zhihu.com/p/677935286">大语言模型幻觉问题研究综述</a></li><li><a href="https://www.secrss.com/articles/73856">LLM幻觉问题的深度分析</a></li><li><a href="https://zhuanlan.zhihu.com/p/682647518">大语言模型重复问题解决方案</a></li><li><a href="https://zhuanlan.zhihu.com/p/1897569693658744522">Attention机制的物理基础</a></li></ol>]]></content>
<summary type="html">
<p>LLM 的幻觉和重复问题是 LLM 应用中的核心挑战,也是面试中经常被问到的问题。本文将从底层机理出发,深入分析这两个问题的成因,并探讨有效的解决方案。</p>
</summary>
<category term="LLM" scheme="https://murphypei.github.io/categories/LLM/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="幻觉" scheme="https://murphypei.github.io/tags/%E5%B9%BB%E8%A7%89/"/>
<category term="重复" scheme="https://murphypei.github.io/tags/%E9%87%8D%E5%A4%8D/"/>
<category term="Attention机制" scheme="https://murphypei.github.io/tags/Attention%E6%9C%BA%E5%88%B6/"/>
<category term="模型训练" scheme="https://murphypei.github.io/tags/%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83/"/>
</entry>
<entry>
<title>LLM 训练:DPO 深入与实践</title>
<link href="https://murphypei.github.io/blog/2025/06/llm-dpo.html"/>
<id>https://murphypei.github.io/blog/2025/06/llm-dpo.html</id>
<published>2025-06-24T12:44:51.000Z</published>
<updated>2026-01-13T03:50:34.985Z</updated>
<content type="html"><![CDATA[<p>上一篇详细介绍了 RLHF 训练中的 <a href="https://murphypei.github.io/blog/2025/05/llm-ppo.html">PPO 原理和实现细节</a>。本文聚焦于 Direct Preference Optimization(DPO)在 LLM 对齐训练中的原理与实践。从直观动机与数学推导入手,给出训练流程与实现要点,随后比较 DPO 与基于奖励的 PPO 在适用场景、训练复杂度与稳定性上的异同,最后给出工程建议与常见陷阱。</p><span id="more"></span><h2 id="1-为什么选择-DPO?"><a href="#1-为什么选择-DPO?" class="headerlink" title="1. 为什么选择 DPO?"></a>1. 为什么选择 DPO?</h2><p>DPO 的核心动机是用人类偏好对模型直接建模,避免单独训练一个奖励模型带来的偏差与不稳定性。与传统的 RLHF(先训练 Reward Model,再用 RL 算法优化策略)相比,DPO 把偏好学习直接挪到策略优化目标中,流程更简单、计算更高效。</p><p>适用场景:</p><ul><li>只有偏好数据(pairwise preferences),没有或不想训练单独奖励模型</li><li>希望减少训练复杂度、提高迭代速度的工程化场景</li><li>需要明确保留参考策略能力(SFT 行为)同时提升偏好表现</li></ul><p>但 DPO 也有局限:对偏好数据质量敏感、探索能力弱、对温度参数敏感。</p><h2 id="2-基本设定与直观公式"><a href="#2-基本设定与直观公式" class="headerlink" title="2. 基本设定与直观公式"></a>2. 基本设定与直观公式</h2><p>给定一个提示 $x$,以及两个回答 $y<em>w$(被人类标注为更好)与 $y_l$(被标注为较差),DPO 以偏好对 $(x,y_w,y_l)$ 为训练样本,直接学习策略 $\pi</em>\theta$,并使用一个固定的参考策略 $\pi_{\mathrm{ref}}$(通常为 SFT 模型)来提供基线约束。</p><p>核心思想可以用以下三步理解:</p><ol><li>以 log-prob 差表示相对偏好</li><li>用 sigmoid/log-sigmoid 构造对数似然损失</li><li>通过最大化偏好对的对数似然来优化策略</li></ol><p>形式化地,定义隐式奖励为:</p><script type="math/tex; mode=display">r_\theta(x,y) = \beta \log \frac{\pi_\theta(y\mid x)}{\pi_{\mathrm{ref}}(y\mid x)}</script><p>这里 $\beta>0$ 为温度因子,控制敏感度。</p><p>将 Bradley–Terry 型偏好模型代入,DPO 的训练损失为:</p><script type="math/tex; mode=display">L_{\mathrm{DPO}}(\theta) = -\mathbb{E}_{(x,y_w,y_l)\sim D} \left[\log \sigma\left(r_\theta(x,y_w) - r_\theta(x,y_l)\right)\right],</script><p>其中 $\sigma(z)=1/(1+e^{-z})$。</p><p>直观上,若策略对 $y<em>w$ 的相对提升明显($r</em>\theta(x,y<em>w)\gg r</em>\theta(x,y_l)$),则损失接近 0;若相反则损失很大,驱动参数更新。</p><h2 id="3-数学推导要点与数值实现"><a href="#3-数学推导要点与数值实现" class="headerlink" title="3. 数学推导要点与数值实现"></a>3. 数学推导要点与数值实现</h2><p>1) 从偏好概率出发:Bradley–Terry 假设</p><script type="math/tex; mode=display">P(y_w\succ y_l\mid x) = \frac{\exp(r(x,y_w))}{\exp(r(x,y_w)) + \exp(r(x,y_l))} = \sigma\big(r(x,y_w)-r(x,y_l)\big).</script><p>2) 将奖励替换成策略对数比:$r(x,y)=\beta\log\frac{\pi<em>\theta(y|x)}{\pi</em>{\mathrm{ref}}(y|x)}$,得到上面的 DPO 损失。</p><p>3) 数值稳定与实现细节:</p><ul><li>计算时使用对数概率之差,避免直接 exponent/ratio 运算。对于自回归模型,序列对数概率为 token log-prob 的累加:<script type="math/tex; mode=display">\log\pi_\theta(y\mid x)=\sum_{t=1}^T \log\pi_\theta(y_t\mid x, y_{<t}).</script></li><li>在实现上直接计算 $\Delta = \beta\left(\log\pi<em>\theta(y_w|x)-\log\pi</em>\theta(y<em>l|x) - (\log\pi</em>{\mathrm{ref}}(y<em>w|x)-\log\pi</em>{\mathrm{ref}}(y_l|x))\right)$,再用 <code>log-sigmoid(Δ)</code>。</li><li>为数值稳定建议使用带有 <code>log-sum-exp</code> 或框架自带的 <code>logsigmoid</code>/<code>softplus</code> 函数,避免溢出/下溢。</li></ul><p>4) 梯度与参照策略:</p><p>DPO 只更新 $\pi<em>\theta$ 的参数,$\pi</em>{\mathrm{ref}}$ 保持冻结。梯度项直接来自 $\partial_\theta \log\sigma(\Delta)$,可以通过自动微分在模型上直接计算。</p><h2 id="4-训练流程(工程化建议)"><a href="#4-训练流程(工程化建议)" class="headerlink" title="4. 训练流程(工程化建议)"></a>4. 训练流程(工程化建议)</h2><p>一个典型的 DPO 训练步骤:</p><ol><li>数据准备:收集或构造偏好对 $(x,y_w,y_l)$,做必要的去重与清洗。</li><li>模型初始化:用 SFT 模型初始化 $\pi<em>\theta$,并将其拷贝为 $\pi</em>{\mathrm{ref}}$(并冻结 $\pi_{\mathrm{ref}}$)。</li><li>批次构造:每个 batch 包含若干偏好对,计算两条序列在策略和参考模型下的对数概率。</li><li>损失计算:对每个偏好对计算 $\Delta$,聚合负对数 sigmoid 损失。</li><li>反向传播与优化:仅更新 $\pi_\theta$,可使用 Adam/AdamW。</li><li>周期性评估:在验证偏好集上评估准确率、NLL 等指标;检查生成质量(一致性、流畅性、有害内容)。</li></ol><p>工程要点:</p><ul><li>批量化计算两条序列的 log-prob 可以共享 prefix 的前向计算以节省算力</li><li>使用 mixed-precision 时注意 log-prob 的数值精度</li><li>对长序列可截断或分段计算 log-prob,然后累加</li><li>参考模型占用显存,若显存紧张可在 CPU 上或按小批次推理参考模型</li></ul><h2 id="5-与-PPO-的对比(要点速览)"><a href="#5-与-PPO-的对比(要点速览)" class="headerlink" title="5. 与 PPO 的对比(要点速览)"></a>5. 与 PPO 的对比(要点速览)</h2><p>下面给出一张简表,突出两者在原理与工程上的核心差异:</p><div class="table-container"><table><thead><tr><th style="text-align:right">维度</th><th>DPO</th><th>PPO (基于 RM 的 RLHF)</th></tr></thead><tbody><tr><td style="text-align:right">所需模型</td><td>策略 + 参考</td><td>Actor + Critic + Reward + Reference</td></tr><tr><td style="text-align:right">数据</td><td>偏好对 (pairwise)</td><td>可为评分或偏好;通常需训练 Reward Model</td></tr><tr><td style="text-align:right">优势估计</td><td>无需 GAE/TD(直接用差)</td><td>需要 GAE/TD、密集价值估计</td></tr><tr><td style="text-align:right">训练复杂度</td><td>低</td><td>高(采样、回放、价值训练)</td></tr><tr><td style="text-align:right">采样探索</td><td>相对弱</td><td>可通过环境采样增强探索</td></tr><tr><td style="text-align:right">稳定性</td><td>对数据质量敏感;训练稳定且简单</td><td>裁剪机制提升稳定性;但训练管线复杂</td></tr><tr><td style="text-align:right">适用场景</td><td>只有偏好数据或想降低复杂度</td><td>可训练 RM 或需要更强探索时</td></tr></tbody></table></div><p>实务建议:若你有大量高质量偏好对且想快速迭代,优先尝试 DPO;若需要在任务中强化探索或有复杂回报结构,PPO 依然更灵活。</p><h2 id="6-超参数与调优建议"><a href="#6-超参数与调优建议" class="headerlink" title="6. 超参数与调优建议"></a>6. 超参数与调优建议</h2><ul><li>温度 $\beta$:通常 0.05–0.5。$\beta$ 越大,策略差异导致的奖励放大越明显,但也更容易过拟合偏好噪声。</li><li>学习率:同 SFT 微调类似起点(1e-5—5e-5),根据模型规模和批次大小调整。</li><li>批次大小:尽量保证每个 batch 包含多样的偏好对,避免过拟合单一 prompt。</li><li>参考策略频率:通常保持固定,但也可周期性用当前策略更新参考(慎用,会改变训练目标)。</li></ul><p>调优技巧:先在小规模数据上做超参搜索,关注偏好准确率和生成样本的可读性指标;使用 early stopping 防止过拟合。</p><h2 id="7-常见问题与陷阱"><a href="#7-常见问题与陷阱" class="headerlink" title="7. 常见问题与陷阱"></a>7. 常见问题与陷阱</h2><ol><li>偏好数据偏差:标签者偏好差异会直接反映到模型行为上。需要多标注/去噪/纠偏。</li><li>参考模型选择:若 $\pi_{\mathrm{ref}}$ 能力过弱或过强,都会影响训练效果;常用 SFT 作为平衡选择。</li><li>伪造优化(gaming the metric):模型可能学会优化 log-prob 差的“捷径”而不是语义质量。需要人工抽样检查。</li><li>温度过大:可能引起梯度爆炸或过拟合罕见偏好。</li></ol><h2 id="8-实战示例(伪代码)"><a href="#8-实战示例(伪代码)" class="headerlink" title="8. 实战示例(伪代码)"></a>8. 实战示例(伪代码)</h2><p>以下为训练一个 batch 的核心步骤(伪代码):</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">for (x,y_w,y_l) in batch:</span><br><span class="line"> lp_w_theta = logprob(theta, x, y_w)</span><br><span class="line"> lp_l_theta = logprob(theta, x, y_l)</span><br><span class="line"> lp_w_ref = logprob(ref, x, y_w)</span><br><span class="line"> lp_l_ref = logprob(ref, x, y_l)</span><br><span class="line"> delta = beta * ((lp_w_theta - lp_w_ref) - (lp_l_theta - lp_l_ref))</span><br><span class="line"> loss += -logsigmoid(delta)</span><br><span class="line">loss /= batch_size</span><br><span class="line">loss.backward()</span><br><span class="line">optimizer.step()</span><br></pre></td></tr></table></figure><p>注意:上面 <code>logprob</code> 需要做 token 层面的对数概率累加,并尽量在向量化运算中实现以提高效率。</p><h2 id="9-结论"><a href="#9-结论" class="headerlink" title="9. 结论"></a>9. 结论</h2><p>DPO 是一个工程友好、数据高效的偏好优化方法。当你拥有高质量的偏好对并且希望简化训练管线时,DPO 能显著降低开发与训练复杂度;但在需要更强探索性或复杂回报建模的场景下,PPO 与基于奖励的 RLHF 仍然不可或缺。</p><p>本文给出了 DPO 的数学基础、实现要点与调优建议,并对比了 PPO 的关键差异。下一步可以基于本文的训练流程实现一个小规模试验,并用人工评估(human eval)与自动指标双轨验证训练效果。</p><h2 id="参考文献"><a href="#参考文献" class="headerlink" title="参考文献"></a>参考文献</h2><ol><li>Rafailov, R., et al. “Direct preference optimization: Your language model is secretly a reward model.” arXiv:2305.18290 (2023).</li><li>Schulman, J., et al. “Proximal policy optimization algorithms.” arXiv:1707.06347 (2017).</li></ol>]]></content>
<summary type="html">
<p>上一篇详细介绍了 RLHF 训练中的 <a href="https://murphypei.github.io/blog/2025/05/llm-ppo.html">PPO 原理和实现细节</a>。本文聚焦于 Direct Preference Optimization(DPO)在 LLM 对齐训练中的原理与实践。从直观动机与数学推导入手,给出训练流程与实现要点,随后比较 DPO 与基于奖励的 PPO 在适用场景、训练复杂度与稳定性上的异同,最后给出工程建议与常见陷阱。</p>
</summary>
<category term="LLM" scheme="https://murphypei.github.io/categories/LLM/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="RLHF" scheme="https://murphypei.github.io/tags/RLHF/"/>
<category term="DPO" scheme="https://murphypei.github.io/tags/DPO/"/>
</entry>
<entry>
<title>LLM 训练:PPO 原理和实现细节</title>
<link href="https://murphypei.github.io/blog/2025/05/llm-ppo.html"/>
<id>https://murphypei.github.io/blog/2025/05/llm-ppo.html</id>
<published>2025-05-25T09:15:51.000Z</published>
<updated>2026-01-13T03:50:08.595Z</updated>
<content type="html"><![CDATA[<p>本文详细讲解大模型 RLHF 阶段使用的 PPO(Proximal Policy Optimization)训练原理和实现细节。我们将从强化学习基础概念开始,逐步深入到 PPO 在 LLM 中的具体应用,最后分析 PPO 的损失函数和优势估计计算。</p><span id="more"></span><h2 id="1-强化学习基础"><a href="#1-强化学习基础" class="headerlink" title="1. 强化学习基础"></a>1. 强化学习基础</h2><h3 id="1-1-什么是强化学习"><a href="#1-1-什么是强化学习" class="headerlink" title="1.1 什么是强化学习"></a>1.1 什么是强化学习</h3><p>强化学习(Reinforcement Learning, RL)是机器学习的一个分支,区别于监督学习和无监督学习。其核心特点是:</p><ol><li><strong>无标签学习</strong>:没有预定义的正确答案,而是通过试错和奖励来学习</li><li><strong>与环境交互</strong>:智能体通过与环境的交互来获取反馈信号</li><li><strong>延迟反馈</strong>:奖励可能是延迟的,需要考虑长期累积回报而不仅是即时奖励</li><li><strong>自我改进</strong>:根据奖励信号不断优化决策策略</li></ol><h3 id="1-2-强化学习的核心要素"><a href="#1-2-强化学习的核心要素" class="headerlink" title="1.2 强化学习的核心要素"></a>1.2 强化学习的核心要素</h3><p>强化学习框架包含两个主要实体:<strong>智能体(Agent)</strong>和<strong>环境(Environment)</strong>,两者通过以下要素进行交互:</p><p><img src="/images/posts/llm/ppo/rl.webp" alt=""></p><p><strong>核心概念</strong>:</p><ul><li><strong>状态(State)$s_t$</strong>:智能体在时刻 $t$ 观测到的环境状态</li><li><strong>动作(Action)$a_t$</strong>:智能体选择执行的动作,属于动作空间 $\mathcal{A}$</li><li><strong>奖励(Reward)$r_t$</strong>:智能体执行动作后从环境获得的即时反馈</li><li><strong>策略(Policy)$\pi(a|s)$</strong>:智能体的决策函数,定义在给定状态下选择动作的概率</li><li><strong>轨迹(Trajectory)$\tau$</strong>:状态-动作-奖励序列 $(s_0, a_0, r_0, s_1, a_1, r_1, \ldots)$</li></ul><h3 id="1-3-交互过程"><a href="#1-3-交互过程" class="headerlink" title="1.3 交互过程"></a>1.3 交互过程</h3><p>一个完整的强化学习交互循环如下:</p><ol><li>智能体在状态 $s_t$ 下,根据策略 $\pi$ 选择动作 $a_t$</li><li>环境根据动作 $a<em>t$ 转移到新状态 $s</em>{t+1}$,并返回奖励 $r_t$</li><li>智能体根据获得的 $(s<em>t, a_t, r_t, s</em>{t+1})$ 四元组来更新策略</li><li>重复上述过程,不断优化策略</li></ol><p><strong>根本目标</strong>:找到最优策略 $\pi^*$,使得智能体根据环境状态选择的动作能最大化长期累积奖励。</p><p>数学上,累积奖励定义为:</p><script type="math/tex; mode=display">G_t = \sum_{k=0}^{\infty} \gamma^k r_{t+k}</script><p>其中 $\gamma \in [0,1]$ 是折扣因子,用于平衡即时奖励和未来奖励的重要性。</p><h3 id="1-4-策略梯度定理"><a href="#1-4-策略梯度定理" class="headerlink" title="1.4 策略梯度定理"></a>1.4 策略梯度定理</h3><p>策略梯度定理是强化学习中的核心理论基础,为 PPO 算法提供了数学基础。</p><h4 id="目标函数"><a href="#目标函数" class="headerlink" title="目标函数"></a>目标函数</h4><p>在强化学习中,我们的目标是最大化期望累积奖励。对于策略 $\pi_\theta$(由参数 $\theta$ 参数化),目标函数定义为:</p><script type="math/tex; mode=display">J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} [R(\tau)]</script><p>其中:</p><ul><li>$\tau$ 是轨迹(状态-动作-奖励序列)</li><li>$R(\tau)$ 是轨迹 $\tau$ 的累积奖励</li><li>$\mathbb{E}<em>{\tau \sim \pi</em>\theta}$ 表示在策略 $\pi_\theta$ 下的期望</li></ul><h4 id="策略梯度定理的核心内容"><a href="#策略梯度定理的核心内容" class="headerlink" title="策略梯度定理的核心内容"></a>策略梯度定理的核心内容</h4><p>策略梯度定理告诉我们如何通过梯度上升来优化策略参数 $\theta$,以最大化目标函数 $J(\theta)$。具体地:</p><script type="math/tex; mode=display">\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} [\nabla_\theta \log \pi_\theta(\tau) R(\tau)]</script><p><strong>公式含义</strong>:</p><ul><li>左侧:目标函数 $J(\theta)$ 关于参数 $\theta$ 的梯度,指示参数如何变化才能增加期望奖励</li><li>右侧:策略对数概率的梯度与轨迹奖励的乘积的期望</li></ul><p><strong>直观理解</strong>:</p><ul><li>如果轨迹 $\tau$ 的奖励 $R(\tau)$ 较高,我们就增加这个轨迹的概率</li><li>如果轨迹 $\tau$ 的奖励 $R(\tau)$ 较低,我们就减少这个轨迹的概率</li><li>$\nabla<em>\theta \log \pi</em>\theta(\tau)$ 指示在参数空间中应该朝哪个方向移动</li></ul><h4 id="策略梯度方法的局限性"><a href="#策略梯度方法的局限性" class="headerlink" title="策略梯度方法的局限性"></a>策略梯度方法的局限性</h4><p>虽然策略梯度定理提供了优化框架,但直接应用存在几个严重问题:</p><ol><li><strong>高方差</strong>:直接使用轨迹奖励会导致方差很大,训练不稳定</li><li><strong>样本效率低</strong>:需要大量样本才能准确估计期望,导致采样成本高</li><li><strong>学习信号稀疏</strong>:对于长轨迹,只有轨迹末端的奖励,无法有效指导中间步骤的学习</li><li><strong>更新步长难控</strong>:策略更新幅度可能过大,导致学习不稳定甚至发散</li></ol><p>这些问题激发了 PPO 等改进算法的发展。</p><h3 id="1-5-PPO-算法的核心思想"><a href="#1-5-PPO-算法的核心思想" class="headerlink" title="1.5 PPO 算法的核心思想"></a>1.5 PPO 算法的核心思想</h3><p>PPO(Proximal Policy Optimization)通过引入<strong>重要性采样比率(Importance Sampling Ratio)</strong> 来限制策略更新的幅度,在保证训练稳定性的同时提高样本效率。</p><h4 id="重要性采样比率"><a href="#重要性采样比率" class="headerlink" title="重要性采样比率"></a>重要性采样比率</h4><p>PPO 定义的重要性采样比率为:</p><script type="math/tex; mode=display">r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}</script><p><strong>比率含义</strong>:</p><ul><li>$\pi<em>{\theta</em>{old}}(a_t|s_t)$:旧策略(更新前)在状态 $s_t$ 下选择动作 $a_t$ 的概率</li><li>$\pi_\theta(a_t|s_t)$:新策略(更新后)在状态 $s_t$ 下选择动作 $a_t$ 的概率</li><li>$r_t(\theta)$:新策略相对于旧策略的概率变化比例</li></ul><h4 id="为什么需要限制策略更新幅度"><a href="#为什么需要限制策略更新幅度" class="headerlink" title="为什么需要限制策略更新幅度"></a>为什么需要限制策略更新幅度</h4><ol><li><strong>防止策略崩溃</strong>:过大的更新可能导致某些动作概率变为 0,失去探索能力,或导致策略完全改变而偏离原有轨道</li><li><strong>保证训练稳定性</strong>:过大的更新步长会导致目标函数振荡甚至发散</li><li><strong>避免灾难性遗忘</strong>:过度更新会使新策略完全偏离旧策略,丢失之前学到的有用知识</li></ol><h4 id="PPO-的裁剪机制"><a href="#PPO-的裁剪机制" class="headerlink" title="PPO 的裁剪机制"></a>PPO 的裁剪机制</h4><p>PPO 的主要创新是引入了裁剪机制来限制更新幅度:</p><script type="math/tex; mode=display">L^{CLIP}(\theta) = \mathbb{E}_t \left[\min(r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t)\right]</script><p><strong>裁剪函数定义</strong>:</p><script type="math/tex; mode=display">\text{clip}(x, a, b) = \begin{cases}a & \text{if } x < a \\x & \text{if } a \leq x \leq b \\b & \text{if } x > b\end{cases}</script><p><strong>核心机制</strong>:</p><p>当优势函数 $A_t > 0$(表示这是一个好的动作)时:</p><ul><li>如果 $r_t(\theta) < 1$(新策略概率较小),鼓励增加动作概率,但最多只能增加到 $1+\epsilon$ 倍</li><li>如果 $r_t(\theta) > 1+\epsilon$(新策略概率已经足够大),停止鼓励,避免过度优化</li></ul><p>当优势函数 $A_t < 0$(表示这是一个坏的动作)时:</p><ul><li>如果 $r_t(\theta) > 1$(新策略概率较大),鼓励减少动作概率,但最多只能减少到 $1-\epsilon$ 倍</li><li>如果 $r_t(\theta) < 1-\epsilon$(新策略概率已经足够小),停止惩罚,避免过度惩罚</li></ul><p><strong>参数说明</strong>:</p><ul><li>$A_t$:优势函数(衡量动作相对于平均水平的优势)</li><li>$\epsilon$:裁剪范围,通常设置为 0.2,意味着策略更新幅度被限制在 ±20% 以内</li><li>这个设计确保了策略不会偏离过远,同时充分利用有利的轨迹进行学习</li></ul><h3 id="1-6-价值函数"><a href="#1-6-价值函数" class="headerlink" title="1.6 价值函数"></a>1.6 价值函数</h3><h4 id="为什么需要价值函数"><a href="#为什么需要价值函数" class="headerlink" title="为什么需要价值函数"></a>为什么需要价值函数</h4><p>在策略梯度方法中,我们使用轨迹的累积奖励 $R(\tau)$ 来更新策略。但这个方法存在一个问题:<strong>方差很大</strong>。</p><p>对于一条长轨迹,很多步骤都会获得同样的奖励信号(终点的奖励),这会导致许多中间步骤的梯度估计方差很高,学习效率低。</p><p>为了解决这个问题,我们引入<strong>价值函数(Value Function)</strong> 的概念:价值函数预测从某个状态开始能获得的期望累积奖励,可以作为”基线”来减少方差。</p><h4 id="状态价值函数"><a href="#状态价值函数" class="headerlink" title="状态价值函数"></a>状态价值函数</h4><p>状态价值函数 $V(s_t)$ 定义为:从状态 $s_t$ 开始,遵循策略 $\pi$ 能获得的期望累积奖励。</p><script type="math/tex; mode=display">V^\pi(s_t) = \mathbb{E}_{\pi} \left[\sum_{k=0}^{\infty} \gamma^k r_{t+k} \mid S_t = s_t\right]</script><p><strong>关键特点</strong>:</p><ul><li>仅依赖于状态,与具体的动作无关</li><li>表示该状态的”好坏程度”(价值)</li><li>在 LLM 中,Critic 模型就是用来估计 $V(s_t)$ 的</li></ul><h4 id="动作价值函数"><a href="#动作价值函数" class="headerlink" title="动作价值函数"></a>动作价值函数</h4><p>动作价值函数 $Q(s_t, a_t)$ 定义为:在状态 $s_t$ 下采取动作 $a_t$,然后遵循策略 $\pi$ 能获得的期望累积奖励。</p><script type="math/tex; mode=display">Q^\pi(s_t, a_t) = \mathbb{E}_{\pi} \left[\sum_{k=0}^{\infty} \gamma^k r_{t+k} \mid S_t = s_t, A_t = a_t\right]</script><p><strong>关键特点</strong>:</p><ul><li>依赖于状态和具体的动作</li><li>表示在特定状态下采取特定动作的”好坏程度”</li><li>可以用来评估一个动作相对于其他动作的相对优势</li></ul><h4 id="两种价值函数的关系"><a href="#两种价值函数的关系" class="headerlink" title="两种价值函数的关系"></a>两种价值函数的关系</h4><p>状态价值函数是动作价值函数在所有可能动作上的期望:</p><script type="math/tex; mode=display">V^\pi(s_t) = \mathbb{E}_{a \sim \pi(\cdot|s_t)} [Q^\pi(s_t, a)]</script><p>这个关系说明:某个状态的价值等于在该状态下遵循策略选择所有可能动作的期望价值。</p><h3 id="1-7-优势函数"><a href="#1-7-优势函数" class="headerlink" title="1.7 优势函数"></a>1.7 优势函数</h3><h4 id="概念定义"><a href="#概念定义" class="headerlink" title="概念定义"></a>概念定义</h4><p>优势函数(Advantage Function)衡量了在某个状态下采取某个特定动作相对于该状态平均水平的优势程度。它是 PPO 算法中最重要的概念之一。</p><script type="math/tex; mode=display">A_t(s_t, a_t) = Q(s_t, a_t) - V(s_t)</script><p><strong>核心含义</strong>:</p><script type="math/tex; mode=display">A_t = \text{该动作的价值} - \text{该状态下的平均价值}</script><h4 id="优势函数的直观理解"><a href="#优势函数的直观理解" class="headerlink" title="优势函数的直观理解"></a>优势函数的直观理解</h4><ul><li>如果 $A_t > 0$:表示动作 $a_t$ 比该状态的平均水平更好,应该<strong>增加</strong>这个动作的概率</li><li>如果 $A_t < 0$:表示动作 $a_t$ 比该状态的平均水平更差,应该<strong>减少</strong>这个动作的概率</li><li>如果 $A_t \approx 0$:表示动作 $a_t$ 处于平均水平,既不特别好也不特别差</li></ul><h4 id="优势函数的优势(为什么用它)"><a href="#优势函数的优势(为什么用它)" class="headerlink" title="优势函数的优势(为什么用它)"></a>优势函数的优势(为什么用它)</h4><p>与直接使用奖励相比,优势函数有几个重要优势:</p><ol><li><strong>减少方差</strong>:相比使用原始奖励,优势函数通过相对比较而不是绝对值,大大降低了学习信号的方差,使训练更稳定</li><li><strong>有效的基线</strong>:状态价值函数 $V(s_t)$ 作为基线(baseline),消除了与动作选择无关的状态价值部分,让梯度只关注动作的相对优劣</li><li><strong>更细粒度的反馈</strong>:在同一状态下,优势函数能清晰区分不同动作的相对好坏,提供更精确的学习信号</li><li><strong>加速收敛</strong>:相对比较的学习信号更容易指导策略优化,使训练收敛速度更快</li></ol><h4 id="时序差分估计"><a href="#时序差分估计" class="headerlink" title="时序差分估计"></a>时序差分估计</h4><p>在实践中,我们通常使用<strong>时序差分(Temporal Difference, TD)</strong> 方法来估计优势函数。最简单的 TD(0) 估计为:</p><script type="math/tex; mode=display">A_t^{\text{TD}} = r_t + \gamma V(s_{t+1}) - V(s_t)</script><p>这个公式的含义是:</p><ul><li>$r<em>t + \gamma V(s</em>{t+1})$:实际观测到的短期收益加上对未来的价值估计</li><li>$V(s_t)$:模型原本对该状态价值的预期</li><li>它们的差异就是”surprise”,即实际与预期的偏差</li></ul><h2 id="2-LLM-中的强化学习应用"><a href="#2-LLM-中的强化学习应用" class="headerlink" title="2. LLM 中的强化学习应用"></a>2. LLM 中的强化学习应用</h2><p>现在我们已经理解了强化学习和 PPO 的基本原理,接下来看看如何将这些理论应用到 LLM 的训练中。</p><h3 id="2-1-LLM-的生成过程"><a href="#2-1-LLM-的生成过程" class="headerlink" title="2.1 LLM 的生成过程"></a>2.1 LLM 的生成过程</h3><p>LLM 的推理过程遵循自回归范式:给定 prompt,模型逐个生成 token,直到生成完毕。具体过程如下:</p><ol><li><strong>初始化</strong>:输入 prompt,对应 RL 中的初始状态 $s_0$</li><li><strong>迭代生成</strong>:每个时刻 $t$,模型基于当前上下文(状态 $s_t$)产生一个新 token(动作 $a_t$)</li><li><strong>状态转移</strong>:产生的新 token 被添加到上下文中,形成新的上下文作为下一时刻的状态 $s_{t+1}$</li><li><strong>终止</strong>:当生成 EOS(End-of-Sequence)token 或达到最大长度时,生成过程结束</li></ol><h3 id="2-2-LLM-生成过程中的-RL-映射"><a href="#2-2-LLM-生成过程中的-RL-映射" class="headerlink" title="2.2 LLM 生成过程中的 RL 映射"></a>2.2 LLM 生成过程中的 RL 映射</h3><p><img src="/images/posts/llm/ppo/nlp-rl.webp" alt=""></p><p>让我们将 LLM 生成过程映射到强化学习框架中:</p><p><strong>核心映射</strong>:</p><div class="table-container"><table><thead><tr><th>RL 概念</th><th>LLM 中的含义</th></tr></thead><tbody><tr><td>智能体(Agent)</td><td>语言模型本身</td></tr><tr><td>状态(State)$s_t$</td><td>当前的输入提示和已生成的 token 序列(上下文)</td></tr><tr><td>动作(Action)$a_t$</td><td>在时刻$t$ 生成的 token(从词汇表中选择)</td></tr><tr><td>动作空间</td><td>词汇表(通常包含数万个 token)</td></tr><tr><td>奖励(Reward)$r_t$</td><td>人类偏好信号(通过奖励模型 RM 评估)</td></tr><tr><td>策略(Policy)$\pi$</td><td>模型的概率分布输出$P(a_t \mid s_t)$</td></tr></tbody></table></div><p><strong>关键区别</strong>:</p><ul><li>在 RL 中,每一步都可能获得奖励;在 LLM 中,<strong>奖励通常是稀疏的</strong>,只在序列生成完成时才获得一个标量奖励</li><li>这个差异导致 LLM 中的 RLHF 需要特殊的奖励设计和处理</li></ul><h3 id="2-3-RLHF-训练的必要性"><a href="#2-3-RLHF-训练的必要性" class="headerlink" title="2.3 RLHF 训练的必要性"></a>2.3 RLHF 训练的必要性</h3><p>为什么需要用 RLHF 进行后训练呢?</p><ol><li><p><strong>偏离 SFT 目标</strong>:SFT(Supervised Fine-Tuning)阶段使用人工标注的高质量回答来训练,但这种方法:</p><ul><li>数据成本高昂(需要大量人工标注)</li><li>无法很好地优化人类真实的、多样化的偏好</li></ul></li><li><strong>无法直接优化人类偏好</strong>:SFT 是模仿学习,模型学到的是”如何模仿”而不是”如何做得好”</li><li><strong>奖励函数设计困难</strong>:不同的应用场景对 LLM 的要求不同,需要灵活的奖励机制来指导模型朝着特定目标优化</li><li><strong>充分发挥模型能力</strong>:RLHF 能让已有的语言模型进一步优化,使其生成更符合人类偏好的回答</li></ol><h2 id="3-RLHF-中的模型架构"><a href="#3-RLHF-中的模型架构" class="headerlink" title="3. RLHF 中的模型架构"></a>3. RLHF 中的模型架构</h2><p>在 PPO 阶段的 RLHF 训练中,需要协调多个模型来共同完成优化目标。这一节我们详细介绍这些模型及其作用。</p><h3 id="3-1-四个核心模型"><a href="#3-1-四个核心模型" class="headerlink" title="3.1 四个核心模型"></a>3.1 四个核心模型</h3><p>RLHF 中的 PPO 训练涉及 4 个不同的模型,它们各司其职:</p><p><img src="/images/posts/llm/ppo/rlhf.webp" alt=""></p><div class="table-container"><table><thead><tr><th>模型</th><th>全称</th><th>作用</th><th>参数更新</th></tr></thead><tbody><tr><td>Actor</td><td>Actor Model</td><td>生成回答的策略模型,最终部署的模型</td><td>✓ 需要</td></tr><tr><td>Critic</td><td>Critic Model</td><td>评估状态价值,估计长期收益</td><td>✓ 需要</td></tr><tr><td>Reward Model</td><td>Reward Model</td><td>评分函数,衡量生成回答的整体质量</td><td>✗ 不需要</td></tr><tr><td>Reference</td><td>Reference Model</td><td>参考标准,限制 Actor 的偏离程度</td><td>✗ 不需要</td></tr></tbody></table></div><h3 id="3-2-模型详解"><a href="#3-2-模型详解" class="headerlink" title="3.2 模型详解"></a>3.2 模型详解</h3><h4 id="Actor-Model(策略模型)"><a href="#Actor-Model(策略模型)" class="headerlink" title="Actor Model(策略模型)"></a>Actor Model(策略模型)</h4><p><img src="/images/posts/llm/ppo/actor.webp" alt=""></p><p><strong>职责</strong>:生成符合人类偏好的回答</p><p><strong>特点</strong>:</p><ul><li>基于 SFT 模型微调得到,在 PPO 阶段继续优化</li><li>接收 prompt,生成一系列 token 作为回答</li><li>是 RLHF 训练的直接对象,参数持续更新</li></ul><p><strong>工作流程</strong>:</p><ol><li>给定 prompt 输入</li><li>模型根据学到的策略逐个生成 token</li><li>将”prompt + response”送入损失计算体系进行优化</li><li>通过反向传播更新参数</li></ol><h4 id="Reference-Model(参考模型)"><a href="#Reference-Model(参考模型)" class="headerlink" title="Reference Model(参考模型)"></a>Reference Model(参考模型)</h4><p><img src="/images/posts/llm/ppo/ref.webp" alt=""></p><p><strong>职责</strong>:作为优化的约束,防止 Actor 偏离过大</p><p><strong>特点</strong>:</p><ul><li>通常用 SFT 模型初始化</li><li><strong>参数冻结</strong>,在整个 RLHF 过程中不更新</li><li>代表”原始能力基线”</li></ul><p><strong>核心机制 - KL 散度约束</strong>:</p><p>我们希望训练出来的 Actor 既能达到符合人类偏好的目标,又尽量让输出分布与原始 SFT 模型相近。这通过 <strong>KL(Kullback-Leibler)散度</strong>来衡量:</p><script type="math/tex; mode=display">D_{KL}(\pi_{\theta} \| \pi_{\text{SFT}}) = \sum_x \pi_{\theta}(x) \log \frac{\pi_{\theta}(x)}{\pi_{\text{SFT}}(x)}</script><p><strong>防止模型训歪的直观理解</strong>:</p><ul><li>如果 Actor 更新使得某些输出的概率变化太大,KL 散度会很高,产生较大的惩罚</li><li>这种约束防止模型为了追求高奖励而生成虚假、有害或无意义的内容</li><li>保证模型保留原有的知识和能力</li></ul><h4 id="Critic-Model(价值模型)"><a href="#Critic-Model(价值模型)" class="headerlink" title="Critic Model(价值模型)"></a>Critic Model(价值模型)</h4><p><img src="/images/posts/llm/ppo/critic.webp" alt=""></p><p><strong>职责</strong>:估计状态的长期价值,指导策略优化</p><p><strong>特点</strong>:</p><ul><li>架构与 Actor 相似,基于 SFT 模型</li><li>添加了一个”Value Head”输出层:将最后一个 token 的隐藏状态(通常 4096 维)映射到一个标量值</li><li><strong>参数需要更新</strong>,通过 MSE 损失函数优化</li></ul><p><strong>价值头的计算</strong>:</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">隐藏状态 (d_model,) → 线性层 (1,) → 单一标量价值 V(s_t)</span><br></pre></td></tr></table></figure><p><strong>工作流程</strong>:</p><ol><li>接收相同的 prompt + response 输入</li><li>在每个 token 位置都可以输出一个价值估计 $V(s_t)$</li><li>用于计算时序差分 TD 误差和优势函数</li><li>通过价值函数损失来更新参数</li></ol><h4 id="Reward-Model(奖励模型)"><a href="#Reward-Model(奖励模型)" class="headerlink" title="Reward Model(奖励模型)"></a>Reward Model(奖励模型)</h4><p><img src="/images/posts/llm/ppo/reward.webp" alt=""></p><p><strong>职责</strong>:评估整个生成序列的质量</p><p><strong>特点</strong>:</p><ul><li>在 RLHF 第一阶段(Reward Modeling 阶段)提前训练好</li><li><strong>参数冻结</strong>,PPO 阶段不更新</li><li>提供客观的、绝对的评分标准</li></ul><p><strong>奖励计算</strong>:</p><ul><li>只在序列生成完成时(EOS token)进行<strong>一次</strong>评分</li><li>输出一个标量奖励 $r_{\text{final}}$,代表整个回答的质量</li><li>不对中间 token 进行评分(奖励是稀疏的)</li></ul><p><strong>为什么不更新参数</strong>:</p><ul><li>冻结奖励模型确保评分标准的一致性和客观性</li><li>如果边训练边更新,会导致奖励标准不断变化,模型学习信号混乱</li><li>冻结参数保证了训练过程中有一个稳定的、绝对的评估基准</li></ul><h3 id="3-3-四个模型的关键区别对比"><a href="#3-3-四个模型的关键区别对比" class="headerlink" title="3.3 四个模型的关键区别对比"></a>3.3 四个模型的关键区别对比</h3><p><strong>奖励模型 vs 价值模型</strong>:</p><p>这是初学者最容易混淆的地方。让我们清晰地对比:</p><div class="table-container"><table><thead><tr><th>维度</th><th>Reward Model</th><th>Critic Model</th></tr></thead><tbody><tr><td><strong>评分时机</strong></td><td>序列结束时(EOS),一次性评分</td><td>每个 token 位置都有评分</td></tr><tr><td><strong>评分特性</strong></td><td>稀疏奖励(只在末端)</td><td>密集价值估计(每步都有)</td></tr><tr><td><strong>更新</strong></td><td>冻结不更新</td><td>需要更新参数</td></tr><tr><td><strong>作用</strong></td><td>衡量整体质量</td><td>估计长期收益</td></tr><tr><td><strong>输入</strong></td><td>完整的 prompt + response</td><td>完整的 prompt + response</td></tr><tr><td><strong>评分依据</strong></td><td>人类偏好信号</td><td>TD 误差和实际奖励</td></tr><tr><td><strong>反向传播</strong></td><td>无(冻结参数)</td><td>有(通过 Value Loss)</td></tr></tbody></table></div><p><strong>直观例子</strong>:</p><ul><li>RM:就像最终的考官,在你完整答题后给一个分数</li><li>Critic:就像教练,在你的每一步都给出潜在的长期收益评估</li></ul><h2 id="4-PPO-损失函数与优化"><a href="#4-PPO-损失函数与优化" class="headerlink" title="4. PPO 损失函数与优化"></a>4. PPO 损失函数与优化</h2><p>PPO 的强大之处在于其巧妙的多目标损失函数设计。让我们深入理解 PPO 如何通过损失函数来平衡多个优化目标。</p><h3 id="4-1-总体损失函数"><a href="#4-1-总体损失函数" class="headerlink" title="4.1 总体损失函数"></a>4.1 总体损失函数</h3><p>PPO 的总损失函数由三部分组成:</p><script type="math/tex; mode=display">L_{\text{total}} = L^{\text{CLIP}} - \alpha \cdot L^{\text{KL}} + \beta \cdot L^{\text{VF}}</script><p>其中 $L^{\text{CLIP}}$、$L^{\text{KL}}$、$L^{\text{VF}}$ 分别对应:</p><script type="math/tex; mode=display">L^{\text{CLIP}} = \mathbb{E}_{t} [\cdots], \quad L^{\text{KL}} = \mathbb{E}_t [\cdots], \quad L^{\text{VF}} = \mathbb{E}_t [\cdots]</script><p><strong>重要</strong>:这里的 $\mathbb{E}_t$ 表示对<strong>所有 token</strong> 的期望,即:</p><script type="math/tex; mode=display">\mathbb{E}_t [\cdot] = \frac{1}{T} \sum_{t=1}^{T} [\cdot]</script><p>也就是说:</p><ol><li>首先对单个序列内的所有 token 计算损失并求平均</li><li>然后对 batch 内的所有序列求平均</li></ol><p>每一项都对应一个特定的优化目标:</p><div class="table-container"><table><thead><tr><th>损失项</th><th>来源模型</th><th>目标</th><th>权重</th></tr></thead><tbody><tr><td>$L^{\text{CLIP}}$</td><td>Actor</td><td>优化策略以最大化奖励</td><td>1(基准)</td></tr><tr><td>$L^{\text{KL}}$</td><td>Reference</td><td>限制与参考模型的偏离</td><td>$\alpha$ (通常 0.2)</td></tr><tr><td>$L^{\text{VF}}$</td><td>Critic</td><td>训练价值函数估计器</td><td>$\beta$ (通常 0.5)</td></tr></tbody></table></div><p><strong>直观理解</strong>:</p><ul><li>Actor 想要尽可能追求高奖励</li><li>Reference 在其肩膀上按住,说”别偏离原来的太远”</li><li>Critic 在旁边学习”这个状态值多少钱”</li></ul><h3 id="4-2-策略损失:-L-text-CLIP"><a href="#4-2-策略损失:-L-text-CLIP" class="headerlink" title="4.2 策略损失:$L^{\text{CLIP}}$"></a>4.2 策略损失:$L^{\text{CLIP}}$</h3><p>这是 PPO 的核心创新,我们已经详细讨论过。这里补充一些实现细节。</p><h4 id="定义"><a href="#定义" class="headerlink" title="定义"></a>定义</h4><script type="math/tex; mode=display">L^{\text{CLIP}} = \mathbb{E}_{t} \left[\min(r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t)\right]</script><p>其中:</p><ul><li>$r<em>t(\theta) = \frac{\pi</em>\theta(a<em>t|s_t)}{\pi</em>{\theta_{\text{old}}}(a_t|s_t)}$:<strong>Token 级别</strong>的重要性采样比率(新策略相对于旧策略的概率比)</li><li>$A_t$:优势估计(也是 <strong>token 级别</strong>)</li><li>$\epsilon = 0.2$:裁剪范围</li></ul><p><strong>关键点</strong>:PPO 在<strong>每个 token 位置</strong>都计算一个重要性采样比率和优势,然后对这些 token 进行 min 和 clip 操作。这与只考虑序列级别的方法不同。</p><h4 id="为什么使用-min-操作"><a href="#为什么使用-min-操作" class="headerlink" title="为什么使用 min 操作"></a>为什么使用 min 操作</h4><p>PPO 使用 $\min$ 操作和裁剪的巧妙之处:</p><ol><li><p><strong>无裁剪项 $r_t(\theta) A_t$</strong>:</p><ul><li>当 $A_t > 0$ 时,$r_t$ 越大越好(鼓励增加好动作)</li><li>当 $A_t < 0$ 时,$r_t$ 越小越好(鼓励减少坏动作)</li></ul></li><li><p><strong>有裁剪项 $\text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t$</strong>:</p><ul><li>强制 $r_t$ 的范围,防止更新过大</li></ul></li><li><p><strong>$\min$ 操作的作用</strong>:</p><ul><li>当 $A_t > 0$(好动作)时,选择两者中较小的,防止过度优化</li><li>当 $A_t < 0$(坏动作)时,选择两者中较小的,防止过度惩罚</li></ul></li></ol><p><strong>数值例子</strong>:</p><p>假设 $\epsilon = 0.2$,$A_t = 1.0$(好动作):</p><div class="table-container"><table><thead><tr><th>$r_t(\theta)$</th><th>$r_t A_t$</th><th>$\text{clip}(r_t) A_t$</th><th>$\min$ 结果</th><th>含义</th></tr></thead><tbody><tr><td>1.5</td><td>1.5</td><td>1.2</td><td>1.2</td><td>已经足够好了,停止优化</td></tr><tr><td>1.3</td><td>1.3</td><td>1.2</td><td>1.2</td><td>已经足够好了,停止优化</td></tr><tr><td>1.1</td><td>1.1</td><td>1.1</td><td>1.1</td><td>继续鼓励</td></tr><tr><td>0.9</td><td>0.9</td><td>0.9</td><td>0.9</td><td>有所衰退,仍需改进</td></tr><tr><td>0.5</td><td>0.5</td><td>0.8</td><td>0.5</td><td>已经太差了,停止改进</td></tr></tbody></table></div><p>这个机制保证了:<strong>好的方向上不会过度优化,坏的方向上也不会过度惩罚</strong>。</p><h3 id="4-3-KL-约束损失:-L-text-KL"><a href="#4-3-KL-约束损失:-L-text-KL" class="headerlink" title="4.3 KL 约束损失:$L^{\text{KL}}$"></a>4.3 KL 约束损失:$L^{\text{KL}}$</h3><h4 id="定义-1"><a href="#定义-1" class="headerlink" title="定义"></a>定义</h4><p>KL 散度衡量两个概率分布之间的差异:</p><script type="math/tex; mode=display">L^{\text{KL}} = \mathbb{E}_t \left[D_{\text{KL}}(\pi_{\theta}(\cdot|s_t) \| \pi_{\text{ref}}(\cdot|s_t))\right]</script><p>展开后:</p><script type="math/tex; mode=display">L^{\text{KL}} = \mathbb{E}_t \left[\sum_{a} \pi_{\theta}(a|s_t) \log \frac{\pi_{\theta}(a|s_t)}{\pi_{\text{ref}}(a|s_t)}\right]</script><h4 id="为什么需要-KL-约束"><a href="#为什么需要-KL-约束" class="headerlink" title="为什么需要 KL 约束"></a>为什么需要 KL 约束</h4><ol><li><strong>防止分布崩溃</strong>:如果不限制,Actor 可能会为了追求高奖励,把所有概率都集中到少数几个高奖励的 token,失去多样性</li><li><strong>保持模型能力</strong>:新 Actor 与 SFT 模型相近,保留了原有的知识和一般化能力,不会因过度优化而”忘记”</li><li><strong>实际部署考虑</strong>:过度优化可能导致模型在某些边缘情况下表现很差,或生成有害内容</li><li><strong>训练稳定性</strong>:限制分布变化范围有助于梯度稳定,防止训练发散</li></ol><h4 id="权重超参数-alpha"><a href="#权重超参数-alpha" class="headerlink" title="权重超参数 $\alpha$"></a>权重超参数 $\alpha$</h4><ul><li>典型值:$\alpha = 0.2$</li><li>含义:平衡追求高奖励和保持分布相似性</li><li>调优:<ul><li>如果 $\alpha$ 过小:模型过度优化奖励,生成不自然回答</li><li>如果 $\alpha$ 过大:模型变化不足,无法充分学习到人类偏好</li></ul></li></ul><h3 id="4-4-价值函数损失:-L-text-VF"><a href="#4-4-价值函数损失:-L-text-VF" class="headerlink" title="4.4 价值函数损失:$L^{\text{VF}}$"></a>4.4 价值函数损失:$L^{\text{VF}}$</h3><h4 id="定义-2"><a href="#定义-2" class="headerlink" title="定义"></a>定义</h4><script type="math/tex; mode=display">L^{\text{VF}} = \mathbb{E}_t \left[(V_\phi(s_t) - V_t^{\text{target}})^2\right]</script><p>其中:</p><ul><li>$V_\phi(s_t)$:Critic 模型对状态价值的预测</li><li>$V_t^{\text{target}}$:目标价值(由奖励和下一状态价值计算)</li></ul><h4 id="目标价值的计算"><a href="#目标价值的计算" class="headerlink" title="目标价值的计算"></a>目标价值的计算</h4><p>目标价值通常采用 n-step bootstrap 的方式:</p><script type="math/tex; mode=display">V_t^{\text{target}} = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n})</script><p>最简单的 1-step 版本:</p><script type="math/tex; mode=display">V_t^{\text{target}} = r_t + \gamma V(s_{t+1})</script><h4 id="为什么需要价值函数损失"><a href="#为什么需要价值函数损失" class="headerlink" title="为什么需要价值函数损失"></a>为什么需要价值函数损失</h4><ol><li><strong>训练 Critic</strong>:Critic 需要学会准确评估状态价值,这个损失就是训练信号</li><li><strong>减少方差</strong>:准确的价值估计能更好地作为基线,减少优势函数估计的方差</li><li><strong>稳定优势计算</strong>:优势函数依赖于价值估计,好的 Critic 模型产生更稳定的优势信号</li></ol><h4 id="权重超参数-beta"><a href="#权重超参数-beta" class="headerlink" title="权重超参数 $\beta$"></a>权重超参数 $\beta$</h4><ul><li>典型值:$\beta = 0.5$</li><li>含义:平衡策略优化和价值函数拟合</li><li>调优:<ul><li>如果 $\beta$ 过小:Critic 学习不充分,方差大</li><li>如果 $\beta$ 过大:Critic 过度拟合,策略学习受影响</li></ul></li></ul><h2 id="5-优势函数的计算"><a href="#5-优势函数的计算" class="headerlink" title="5. 优势函数的计算"></a>5. 优势函数的计算</h2><p>优势函数的准确计算对 PPO 的性能至关重要。在 LLM 的 RLHF 场景中,我们需要特别考虑奖励的稀疏性和长序列的特性。</p><h3 id="5-1-Token-级别的即时奖励构造"><a href="#5-1-Token-级别的即时奖励构造" class="headerlink" title="5.1 Token 级别的即时奖励构造"></a>5.1 Token 级别的即时奖励构造</h3><p>在 LLM 中,奖励有一个特殊的特点:<strong>只有序列末端有实际奖励(来自 RM),中间步骤没有直接的奖励信号</strong>。</p><h4 id="奖励的构造方式"><a href="#奖励的构造方式" class="headerlink" title="奖励的构造方式"></a>奖励的构造方式</h4><p>对于生成序列中的第 $t$ 个 token,我们构造其即时奖励为:</p><script type="math/tex; mode=display">r_t = \begin{cases}-\beta \cdot \log \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{ref}}(a_t|s_t)} & \text{if } t < T \\r_{\text{RM}} - \beta \cdot \log \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{ref}}(a_t|s_t)} & \text{if } t = T\end{cases}</script><p>其中:</p><ul><li>$T$:序列长度(生成到 EOS 或最大长度)</li><li>$r_{\text{RM}}$:Reward Model 在序列末端给出的奖励</li><li>$\beta$:KL 惩罚的权重(典型值 0.05-0.2)</li></ul><p><strong>关键特点</strong>:</p><ol><li><p><strong>稀疏 RM 奖励</strong>:</p><ul><li>中间 token 的奖励为 0($t < T$)</li><li>只有最后一个 token($t = T$)才获得 RM 的评分 $r_{\text{RM}}$</li><li>这反映了现实:我们只在完整回答后才评分</li></ul></li><li><p><strong>密集 KL 惩罚</strong>:</p><ul><li>每个 token 都会产生 KL 惩罚</li><li>防止策略在追求高奖励时偏离参考模型太远</li></ul></li></ol><p><strong>直观理解</strong>:</p><ul><li>前 $T-1$ 个 token:只受 KL 约束的负值惩罚</li><li>第 $T$ 个 token:既获得奖励鼓励,也受 KL 约束</li></ul><h3 id="5-2-时序差分误差与-GAE"><a href="#5-2-时序差分误差与-GAE" class="headerlink" title="5.2 时序差分误差与 GAE"></a>5.2 时序差分误差与 GAE</h3><h4 id="时序差分(TD)误差"><a href="#时序差分(TD)误差" class="headerlink" title="时序差分(TD)误差"></a>时序差分(TD)误差</h4><p>基于上面的即时奖励,我们可以计算每个 token 位置的 TD 误差:</p><script type="math/tex; mode=display">\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)</script><p><strong>三项含义</strong>:</p><ul><li>$r_t$:该 token 在该步获得的即时奖励</li><li>$\gamma V(s_{t+1})$:对下一状态价值的折扣期望($\gamma$ 通常设置为 1)</li><li>$V(s_t)$:当前状态价值的预测</li></ul><p><strong>$\delta_t$ 的含义</strong>:</p><ul><li>这是一个”惊喜”信号:实际观察($r<em>t + \gamma V(s</em>{t+1})$)vs 预期($V(s_t)$)</li><li>如果 $\delta_t > 0$:实际比预期更好</li><li>如果 $\delta_t < 0$:实际比预期更差</li><li>Critic 模型通过最小化 $\delta_t^2$ 来改进价值估计</li></ul><h4 id="广义优势估计(GAE)"><a href="#广义优势估计(GAE)" class="headerlink" title="广义优势估计(GAE)"></a>广义优势估计(GAE)</h4><p>直接使用 $\delta_t$ 作为优势估计会有方差大的问题。GAE(Generalized Advantage Estimation)通过加权求和多个 TD 误差来平衡偏差和方差:</p><script type="math/tex; mode=display">A_t^{\text{GAE}(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}</script><p>其中 $\lambda \in [0,1]$ 是平衡参数。</p><p><strong>展开式(有限长度)</strong>:</p><p>对于长度为 $T$ 的序列,$t$ 时刻的 GAE 为:</p><script type="math/tex; mode=display">A_t = \sum_{l=0}^{T-t-1} (\gamma \lambda)^l \delta_{t+l}</script><p>这可以递归计算(效率更高):</p><script type="math/tex; mode=display">A_t = \delta_t + \gamma \lambda A_{t+1}</script><p><strong>GAE 的两个极端</strong>:</p><ol><li><p>当 $\lambda = 0$ 时:$A_t^{\text{GAE}} = \delta_t$(单步 TD)</p><ul><li>低方差,但可能有高偏差(如果 Critic 不准确)</li></ul></li><li><p>当 $\lambda = 1$ 时:$A<em>t^{\text{GAE}} = \sum</em>{l=0}^{\infty} \gamma^l \delta_{t+l}$(蒙特卡洛)</p><ul><li>低偏差(无模型依赖),但高方差(依赖长期轨迹)</li></ul></li></ol><p><strong>$\lambda$ 的选择</strong>:</p><ul><li>典型值:$\lambda = 0.95$</li><li>这个值在低偏差和低方差之间取得很好的平衡</li></ul><h3 id="5-3-完整的优势计算流程"><a href="#5-3-完整的优势计算流程" class="headerlink" title="5.3 完整的优势计算流程"></a>5.3 完整的优势计算流程</h3><p>让我们整理一个完整的算法流程:</p><p><strong>算法:Token 级别优势估计</strong></p><p>输入:</p><ul><li>生成序列:$s_0, a_0, s_1, a_1, \ldots, s_T, a_T$</li><li>Reward Model 输出:$r_{\text{RM}}$(在 $t=T$ 时)</li><li>Critic Model:$V_\phi$(对序列中每个位置的评估)</li><li>超参数:$\gamma$(折扣因子,通常 1),$\lambda$(GAE 参数,通常 0.95),$\beta$(KL 权重)</li></ul><p><strong>步骤 1:计算即时奖励</strong></p><script type="math/tex; mode=display">r_t = \begin{cases}-\beta \log\frac{\pi_\theta(a_t\mid s_t)}{\pi_{\mathrm{ref}}(a_t\mid s_t)}, & t < T, \\r_{\mathrm{RM}} - \beta \log\frac{\pi_\theta(a_T\mid s_T)}{\pi_{\mathrm{ref}}(a_T\mid s_T)}, & t = T.\end{cases}</script><p><strong>步骤 2:计算 TD 误差</strong></p><script type="math/tex; mode=display">\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \quad t = 0,\dots,T,</script><p>其中当 $t=T$ 时,约定 $V(s_{T+1})=0$。</p><p><strong>步骤 3:计算 GAE 优势</strong></p><p>末端初始化并递归计算:</p><script type="math/tex; mode=display">A_T = \delta_T,</script><p>并自后向前递归:</p><script type="math/tex; mode=display">A_t = \delta_t + \gamma \lambda A_{t+1}, \quad t = T-1, T-2,\dots,0.</script><p><strong>步骤 4:使用优势进行更新</strong></p><p>得到优势 $A[0], A[1], \ldots, A[T]$ 后:</p><ul><li>用于 PPO 策略更新:$L^{\text{CLIP}} = \mathbb{E}[\min(r_t A_t, \text{clip}(r_t) A_t)]$</li><li>用于 Critic 更新:$L^{\text{VF}} = \mathbb{E}[(V(s<em>t) - (r_t + \gamma V(s</em>{t+1})))^2]$</li></ul><h3 id="5-4-一个具体例子"><a href="#5-4-一个具体例子" class="headerlink" title="5.4 一个具体例子"></a>5.4 一个具体例子</h3><p>让我们用一个简单的数值例子来演示整个计算过程。</p><p><strong>假设</strong>:</p><ul><li>序列长度 $T = 3$(3 个 token)</li><li>$\gamma = 1$,$\lambda = 0.95$</li><li>$r_{\text{RM}} = 5$(这是一个好回答)</li><li>$\beta = 0.1$</li><li>Critic 的价值估计:$V(s_0)=2, V(s_1)=3, V(s_2)=4, V(s_3)=0$(边界)</li></ul><p><strong>计算过程</strong>:</p><ol><li><p><strong>即时奖励</strong>(假设 KL 项都是 -0.1)</p><ul><li>$r[0] = -0.1$</li><li>$r[1] = -0.1$</li><li>$r[2] = 5 - 0.1 = 4.9$</li></ul></li><li><p><strong>TD 误差</strong></p><ul><li>$\delta[0] = -0.1 + 1 \times 3 - 2 = 0.9$</li><li>$\delta[1] = -0.1 + 1 \times 4 - 3 = 0.9$</li><li>$\delta[2] = 4.9 + 1 \times 0 - 4 = 0.9$</li></ul></li><li><p><strong>GAE 优势</strong>(从后向前)</p><ul><li>$A[2] = 0.9$</li><li>$A[1] = 0.9 + 0.95 \times 0.9 = 1.755$</li><li>$A[0] = 0.9 + 0.95 \times 1.755 = 2.567$</li></ul></li></ol><p><strong>结果解释</strong>:</p><ul><li>所有 token 都有正的优势,鼓励增加其生成概率</li><li>越早的 token 由于能”看到”后续的奖励,优势越大</li><li>最后一个 token 的优势最小(0.9),因为它直接看到 RM 的奖励</li></ul><h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>PPO 是一个巧妙且实用的强化学习算法,在 LLM 的 RLHF 训练中发挥了关键作用。它的核心优势包括:</p><ol><li><strong>稳定的策略更新</strong>:通过裁剪机制限制策略偏离,避免训练不稳定</li><li><strong>灵活的约束机制</strong>:通过 KL 散度约束保持模型的原有能力</li><li><strong>高效的方差减少</strong>:通过优势函数作为基线大大降低学习信号的方差</li><li><strong>实用的设计</strong>:多个可调的超参数让算法在不同场景下都能表现良好</li></ol><p>理解 PPO 的这些核心原理对于:</p><ul><li>深入掌握 LLM 的后训练方法</li><li>调优 RLHF 训练过程</li><li>开发新的强化学习训练方法</li></ul><p>都是至关重要的。希望通过本文的详细讲解,能帮助你充分理解 PPO 在 LLM 训练中的应用与实现细节。</p>]]></content>
<summary type="html">
<p>本文详细讲解大模型 RLHF 阶段使用的 PPO(Proximal Policy Optimization)训练原理和实现细节。我们将从强化学习基础概念开始,逐步深入到 PPO 在 LLM 中的具体应用,最后分析 PPO 的损失函数和优势估计计算。</p>
</summary>
<category term="LLM" scheme="https://murphypei.github.io/categories/LLM/"/>
<category term="LLM" scheme="https://murphypei.github.io/tags/LLM/"/>
<category term="RLHF" scheme="https://murphypei.github.io/tags/RLHF/"/>
<category term="PPO" scheme="https://murphypei.github.io/tags/PPO/"/>
<category term="强化学习" scheme="https://murphypei.github.io/tags/%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0/"/>
</entry>
<entry>
<title>图像生成基础:DDPM</title>
<link href="https://murphypei.github.io/blog/2024/07/aigc-ddpm.html"/>
<id>https://murphypei.github.io/blog/2024/07/aigc-ddpm.html</id>
<published>2024-07-31T09:44:51.000Z</published>
<updated>2026-01-07T08:41:01.085Z</updated>
<content type="html"><![CDATA[<p>目前所采用的扩散模型大都是来自于 2020 年的工作 DDPM。DDPM 对之前的扩散模型进行了简化,并通过变分推断(variational inference)来进行建模,这主要是因为扩散模型也是一个隐变量模型(latent variable model),相比 VAE 这样的隐变量模型,扩散模型的隐变量是和原始数据是同维度的,而且推理过程(即扩散过程)往往是固定的。</p><span id="more"></span><h2 id="引言"><a href="#引言" class="headerlink" title="引言"></a>引言</h2><p>扩散模型包括两个过程:前向过程(forward process)和反向过程(reverse process),其中前向过程又称为扩散过程(diffusion process),如下图所示。无论是前向过程还是反向过程都是一个参数化的马尔可夫链(Markov chain),其中反向过程可以用来生成数据,这里我们将通过变分推断来进行建模和求解。</p><p><img src="/images/posts/aigc/ddpm/1.webp" alt=""></p><h3 id="前向扩散过程(Forward-Diffusion-Process)"><a href="#前向扩散过程(Forward-Diffusion-Process)" class="headerlink" title="前向扩散过程(Forward Diffusion Process)"></a>前向扩散过程(Forward Diffusion Process)</h3><p>解释扩散之前先介绍一个基本的数学表示:</p><script type="math/tex; mode=display">\mathcal{N}(x_t; \mu, \Sigma)</script><p>一个正态分布,其中 $\mu$ 是均值,$\Sigma$ 是协方差矩阵。在这个过程中,$x_t$ 服从一个 $\mu$ 为均值、$\Sigma$ 为协方差矩阵的正态分布。</p><p>扩散就是对图像数据进行加噪声的过程,<strong>最核心的数学公式</strong>表示如下:</p><script type="math/tex; mode=display">q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t \mathbf{I})</script><p>$x_0$ 是原始数据,$x_t$ 是在 $t$ 时刻的样本,$\mathcal{N}$ 表示正态分布,$\beta_t$ 表示在第 $t$ 步的方差(噪音量),它是一个介于 0 和 1 之间的值,$\sqrt{1 - \beta_t}$ 表示输入数据的缩放系数,$\beta_t \mathbf{I}$ 表示加的噪音的方差。</p><p>这个公式表示,给定 $x<em>{t-1}$ 的情况下,$x_t$ 是以 $\sqrt{1 - \beta_t} x</em>{t-1}$ 为均值、$\beta<em>t \mathbf{I}$ 为协方差矩阵的正态分布。可以简单理解为,$x_t$ 是 $x</em>{t-1}$ 加上高斯噪音后的结果。</p><p>前向过程就是这么简单。当我们逐渐加大 $\beta_t$ 时,$x_t$ 逐渐变得模糊,最终变成一个高斯噪声图像。</p><script type="math/tex; mode=display">q(x_T|x_0) \approx \mathcal{N}(0, I)</script><p>这里也有一个推导,就是通过 $x_0$,可以直接表示 $x_T$,因为高斯分布可以直接相加。</p><h3 id="逆向生成过程(Reverse-Generation-Process)"><a href="#逆向生成过程(Reverse-Generation-Process)" class="headerlink" title="逆向生成过程(Reverse Generation Process)"></a>逆向生成过程(Reverse Generation Process)</h3><p>训练过程中,DDPM 学习从噪声生成数据的逆向过程。我们<strong>假设逆向过程也是一个高斯过程</strong>,但参数未知:</p><script type="math/tex; mode=display">p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \sigma_\theta^2(x_t, t) I)</script><p>这里,模型的任务是学习 $\mu<em>\theta$ 和 $\sigma</em>\theta$ 的参数化形式,使得可以从噪声生成逼真的数据样本。</p><h3 id="训练目标"><a href="#训练目标" class="headerlink" title="训练目标"></a>训练目标</h3><p>训练的目标是最小化前向过程和逆向过程之间的差异。具体来说,训练目标可以表示为以下 KL 散度的和:</p><script type="math/tex; mode=display">L = \sum_{t=1}^{T} D_{KL}\left(q(x_{t-1}|x_t, x_0) \| p_\theta(x_{t-1}|x_t)\right)</script><p>每一个 KL 项衡量在第 $t$ 个时间步长上真实分布和模型估计分布之间的差异。</p><h3 id="损失函数"><a href="#损失函数" class="headerlink" title="损失函数"></a>损失函数</h3><p>为了简化训练过程,我们可以<strong>重参数化</strong>损失函数为一个去噪过程的预测任务。目标变为预测加入噪声的程度(噪声项)的均值和方差。</p><blockquote><p>这里重参数的推导很长,可以网上找一下。</p></blockquote><script type="math/tex; mode=display">L = \mathbb{E}_{q(x_0, \epsilon)} \left[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\right]</script><p>其中 $\epsilon$ 是在前向过程加入的数据噪声,$\epsilon_\theta$ 是通过神经网络预测的噪声。所以神经网络的任务就是抽取一个 $t$(1~$T$ 之间),通过 $x_0$ 和加噪过程,计算得到 $x_t$,然后神经网络预测噪声,计算预测的噪声和实际噪声的分布差异。</p><h3 id="训练步骤"><a href="#训练步骤" class="headerlink" title="训练步骤"></a>训练步骤</h3><ol><li><strong>采样数据 $x_0$</strong> 从真实数据分布中。</li><li><strong>采样噪声 $\epsilon$</strong> 从标准正态分布中。</li><li><strong>计算 $x_t$</strong> 通过前向扩散过程,将噪声加入数据。</li><li><strong>计算预测的噪声 $\epsilon_\theta(x_t, t)$</strong> 使用神经网络。</li><li><strong>计算损失 $L$</strong> 并通过反向传播更新模型参数。</li></ol><h3 id="生成步骤"><a href="#生成步骤" class="headerlink" title="生成步骤"></a>生成步骤</h3><p>生成数据时,从标准正态分布中采样 $x_T$,然后逐步通过逆向生成过程去噪,生成数据 $x_0$。</p><p>在 DDPM 中,会将原始图像的像素值从 [0, 255] 范围归一化到 [-1, 1],像素值属于离散化值。</p><h3 id="背后原理"><a href="#背后原理" class="headerlink" title="背后原理"></a>背后原理</h3><p>DDPM 通过一个称为”马尔科夫链”的过程,逐步将噪声转化为数据。其核心思想是分阶段进行去噪,每个阶段只去除一小部分噪声,使得每一步的去噪过程更为简单和稳定。</p><p>总的来说,DDPM 在生成任务中表现出色,特别是生成图像和其他复杂结构的数据类型。这是因为它通过多步生成过程有效地捕捉了数据的复杂结构和细节。</p><p>DDPM 的推导过程中,最重要的就是重参数技巧,这个技巧在很多生成模型中都有应用,比如 VAE、GAN 等。</p>]]></content>
<summary type="html">
<p>目前所采用的扩散模型大都是来自于 2020 年的工作 DDPM。DDPM 对之前的扩散模型进行了简化,并通过变分推断(variational inference)来进行建模,这主要是因为扩散模型也是一个隐变量模型(latent variable model),相比 VAE 这样的隐变量模型,扩散模型的隐变量是和原始数据是同维度的,而且推理过程(即扩散过程)往往是固定的。</p>
</summary>
<category term="AIGC" scheme="https://murphypei.github.io/categories/AIGC/"/>
<category term="AIGC" scheme="https://murphypei.github.io/tags/AIGC/"/>
<category term="DDPM" scheme="https://murphypei.github.io/tags/DDPM/"/>
<category term="图像生成" scheme="https://murphypei.github.io/tags/%E5%9B%BE%E5%83%8F%E7%94%9F%E6%88%90/"/>
</entry>
<entry>
<title>STL 旋转序列算法 rotate</title>
<link href="https://murphypei.github.io/blog/2022/12/stl-rotate.html"/>
<id>https://murphypei.github.io/blog/2022/12/stl-rotate.html</id>
<published>2022-12-07T03:35:51.000Z</published>
<updated>2026-01-07T08:41:01.085Z</updated>
<content type="html"><![CDATA[<p>最近开发需要不管刷新缓冲区,发现了一个有用的 STL 算法。</p><span id="more"></span><p>先说明应用场景:我有一块缓冲区 vector,不断接收数据和消费数据(生产消费模型),接收数据就放在末尾,消费头部数据,消费完删除。之前用 realloc 和 memmove 来操作,改为 vector 之后如果每次搬移数据就很麻烦了,查了一下发现 <a href="https://en.cppreference.com/w/cpp/algorithm/rotate">rotate</a> 配合 resize 可以搞定。</p><p>std::rotate() 的第一个参数是这个序列的开始迭代器;第二个参数是指向新的第一个元素的迭代器,<strong>它必定在序列之内</strong>。第三个参数是这个序列的结束迭代器。意思是将第二个参数的元素旋转到第一个参数的位置,旋转的序列是第一个参数到第三个参数的范围。</p><p>你可以想象第一个参数 ~ 第三个参数之间的元素序列组成一个圆盘,左转就是逆时针旋转,直到第二个参数转到第一个参数的位置,旋转结束。</p><p>可以参数<a href="http://c.biancheng.net/view/609.html">图解</a></p><p>旋转完成后,头部就变到 vector 末尾了,用 resize 可以标记删除掉这些元素。</p>]]></content>
<summary type="html">
<p>最近开发需要不管刷新缓冲区,发现了一个有用的 STL 算法。</p>
</summary>
<category term="C/C++" scheme="https://murphypei.github.io/categories/C-C/"/>
<category term="C++" scheme="https://murphypei.github.io/tags/C/"/>
<category term="STL" scheme="https://murphypei.github.io/tags/STL/"/>
<category term="stl" scheme="https://murphypei.github.io/tags/stl/"/>
<category term="rotate" scheme="https://murphypei.github.io/tags/rotate/"/>
</entry>
<entry>
<title>vscode C++ 开发之使用 clangd、C/C++、clang-format</title>
<link href="https://murphypei.github.io/blog/2022/12/vscode-clang-format.html"/>
<id>https://murphypei.github.io/blog/2022/12/vscode-clang-format.html</id>
<published>2022-12-07T03:25:55.000Z</published>
<updated>2026-01-07T08:41:01.085Z</updated>
<content type="html"><![CDATA[<p>最近比较忙,废话少说,vscode 开发 C/C++ 需要很繁琐的配置,之前也说过 launch 和 tasks 的配置。这篇文章主要结合自身使用经历讲讲 C++ 相关插件。</p><span id="more"></span><p>vscode 最常用的几个 C++ 插件(不包含 cmake)就是微软的 C/C++、LLVM 的 clangd,以前我也使用 C/C++,但是智能补全和提示、include 路径都太差劲了,转投 clangd 了,确实好用。所以不废话,直接推荐使用 clangd,不过 C/C++ 也在用,为了二者不冲突,需要配置如下:</p><figure class="highlight json"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br></pre></td><td class="code"><pre><span class="line"><span class="attr">"C_Cpp.autocomplete"</span><span class="punctuation">:</span> <span class="string">"Disabled"</span><span class="punctuation">,</span></span><br><span class="line"><span class="attr">"C_Cpp.clang_format_fallbackStyle"</span><span class="punctuation">:</span> <span class="string">"Visual Studio"</span><span class="punctuation">,</span></span><br><span class="line"><span class="attr">"C_Cpp.clang_format_sortIncludes"</span><span class="punctuation">:</span> <span class="literal"><span class="keyword">true</span></span><span class="punctuation">,</span></span><br><span class="line"><span class="attr">"C_Cpp.clang_format_style"</span><span class="punctuation">:</span> <span class="string">"file"</span><span class="punctuation">,</span></span><br><span class="line"><span class="attr">"C_Cpp.default.compilerPath"</span><span class="punctuation">:</span> <span class="string">"/usr/bin/g++"</span><span class="punctuation">,</span></span><br><span class="line"><span class="attr">"C_Cpp.default.configurationProvider"</span><span class="punctuation">:</span> <span class="string">"ms-vscode.cmake-tools"</span><span class="punctuation">,</span></span><br><span class="line"><span class="attr">"C_Cpp.default.cppStandard"</span><span class="punctuation">:</span> <span class="string">"c++11"</span><span class="punctuation">,</span></span><br><span class="line"><span class="attr">"C_Cpp.default.cStandard"</span><span class="punctuation">:</span> <span class="string">"c99"</span><span class="punctuation">,</span></span><br><span class="line"><span class="attr">"C_Cpp.default.intelliSenseMode"</span><span class="punctuation">:</span> <span class="string">"gcc-x64"</span><span class="punctuation">,</span></span><br><span class="line"><span class="attr">"C_Cpp.errorSquiggles"</span><span class="punctuation">:</span> <span class="string">"Disabled"</span><span class="punctuation">,</span></span><br><span class="line"><span class="attr">"C_Cpp.intelliSenseEngine"</span><span class="punctuation">:</span> <span class="string">"Disabled"</span><span class="punctuation">,</span></span><br><span class="line"><span class="attr">"clangd.arguments"</span><span class="punctuation">:</span> <span class="punctuation">[</span></span><br><span class="line"><span class="comment">// 在后台自动分析文件(基于complie_commands)</span></span><br><span class="line"><span class="string">"--background-index"</span><span class="punctuation">,</span></span><br><span class="line"><span class="string">"--compile-commands-dir=${workspaceFolder}/build"</span><span class="punctuation">,</span></span><br><span class="line"><span class="string">"-j=8"</span><span class="punctuation">,</span></span><br><span class="line"><span class="comment">// 支持 .clangd 配置</span></span><br><span class="line"><span class="string">"--enable-config"</span><span class="punctuation">,</span></span><br><span class="line"><span class="string">"--clang-tidy"</span><span class="punctuation">,</span></span><br><span class="line"><span class="string">"--clang-tidy-checks=performance-*,bugprone-*"</span><span class="punctuation">,</span></span><br><span class="line"><span class="string">"--log=verbose"</span><span class="punctuation">,</span></span><br><span class="line"><span class="string">"--pretty"</span><span class="punctuation">,</span></span><br><span class="line"><span class="comment">// 全局补全(会自动补充头文件)</span></span><br><span class="line"><span class="string">"--all-scopes-completion"</span><span class="punctuation">,</span></span><br><span class="line"><span class="comment">// 更详细的补全内容</span></span><br><span class="line"><span class="string">"--completion-style=detailed"</span><span class="punctuation">,</span></span><br><span class="line"><span class="comment">// 补充头文件的形式</span></span><br><span class="line"><span class="string">"--header-insertion=iwyu"</span><span class="punctuation">,</span></span><br><span class="line"><span class="comment">// pch优化的位置</span></span><br><span class="line"><span class="string">"--pch-storage=memory"</span><span class="punctuation">,</span></span><br><span class="line"><span class="string">"--function-arg-placeholders"</span><span class="punctuation">,</span></span><br><span class="line"><span class="punctuation">]</span><span class="punctuation">,</span></span><br></pre></td></tr></table></figure><p>clangd 的 include 可以通过如下配置:</p><figure class="highlight json"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="attr">"clangd.fallbackFlags"</span><span class="punctuation">:</span> <span class="punctuation">[</span></span><br><span class="line"> <span class="string">"-std=c++11"</span><span class="punctuation">,</span></span><br><span class="line"> <span class="string">"-I/usr/include/c++/9"</span><span class="punctuation">,</span></span><br><span class="line"> <span class="string">"-I/usr/include/opencv4"</span><span class="punctuation">,</span></span><br><span class="line"> <span class="string">"-I${workspaceFolder}/src/"</span><span class="punctuation">,</span></span><br><span class="line"><span class="punctuation">]</span></span><br></pre></td></tr></table></figure><p>clangd 虽然很香,但是有个明显的缺点,就是它一定要使用自身的 clang-format 来格式化,而且无法配置使用 .clang-format 文件。为此,需要安装另一个插件 xaver clang-format。安装完成后配置格式化的程序:</p><figure class="highlight json"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="attr">"[cpp]"</span><span class="punctuation">:</span> <span class="punctuation">{</span></span><br><span class="line"><span class="comment">// "editor.defaultFormatter": "llvm-vs-code-extensions.vscode-clangd"</span></span><br><span class="line"><span class="attr">"editor.defaultFormatter"</span><span class="punctuation">:</span> <span class="string">"xaver.clang-format"</span></span><br><span class="line"><span class="punctuation">}</span><span class="punctuation">,</span></span><br></pre></td></tr></table></figure><p>这个插件可以直接调用项目根目录下的 .clang-format 文件来格式化。</p><p>最后,有条件的推荐使用 clion 来开发和调试 C++。</p>]]></content>
<summary type="html">
<p>最近比较忙,废话少说,vscode 开发 C/C++ 需要很繁琐的配置,之前也说过 launch 和 tasks 的配置。这篇文章主要结合自身使用经历讲讲 C++ 相关插件。</p>
</summary>
<category term="C/C++" scheme="https://murphypei.github.io/categories/C-C/"/>
<category term="C++" scheme="https://murphypei.github.io/tags/C/"/>
<category term="vscode" scheme="https://murphypei.github.io/tags/vscode/"/>
<category term="clangd" scheme="https://murphypei.github.io/tags/clangd/"/>
<category term="clang-format" scheme="https://murphypei.github.io/tags/clang-format/"/>
</entry>
<entry>
<title>golang select 机制和超时</title>
<link href="https://murphypei.github.io/blog/2022/06/go-select-timeout.html"/>
<id>https://murphypei.github.io/blog/2022/06/go-select-timeout.html</id>
<published>2022-06-25T06:24:59.000Z</published>
<updated>2026-01-07T08:41:01.085Z</updated>
<content type="html"><![CDATA[<p>golang 中的协程使用非常方便,但是协程什么时候结束是一个控制问题,可以用 select 配合使用。</p><span id="more"></span><p>首先声明,golang 使用并不熟悉,本文仅仅是记录使用过程中遇到的一些坑。</p><p>子协程和父协程的通信通常用 context 或者 chan。我遇到一个通常的使用场景,在子协程中尝试多次处理,父协程等待一段时间超时,我选择用 chan 实现。我以为 select 和 C++ 中 switch 类似,所以最开始代码类似如下:</p><figure class="highlight golang"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> {</span><br><span class="line"> <span class="keyword">select</span> {</span><br><span class="line"> <span class="keyword">case</span> <-ctx.Done():</span><br><span class="line"> <span class="comment">// process ctx done</span></span><br><span class="line"> <span class="keyword">case</span> <-time.After(time.Second * <span class="number">3</span>):</span><br><span class="line"> <span class="comment">// process after</span></span><br><span class="line"> <span class="keyword">default</span>:</span><br><span class="line"> <span class="comment">// process code</span></span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>测试发现无法实现 timeout,又仔细查看文档,才发现 golang 中 select 另有玄机。废话少说,直接总结要点:</p><ul><li>select 中的 case 必须是进行 chan 的手法操作,也就是只能在 case 中操作 chan,并且是<strong>非阻塞接收</strong>。</li><li>select 中的 case 是同时监听的,多个 case 同时操作,并未 switch 中一个个顺序判断。如果多个 case 满足要求,随机执行一个,如果一个没有则阻塞当前的协程(没有 default 情况下)。<strong>很类似 Linux 文件符操作的 select 语义</strong>。</li><li>上面说的阻塞是没有 default 的情况下,如果有 default,则执行 default,然后退出 select,也就是不会阻塞当前协程。</li></ul><p>回到上述代码,我这个 select 会一直不断的执行 default,<code>time.After</code> 生成的 chan 并不会被阻塞判断,所以根本无法完成我想要的效果。理解了之后重新修改代码:</p><figure class="highlight golang"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line">done := <span class="built_in">make</span>(char <span class="type">int</span>)</span><br><span class="line"><span class="keyword">go</span> <span class="function"><span class="keyword">func</span><span class="params">(c <span class="keyword">chan</span> <span class="type">int</span>)</span></span> {</span><br><span class="line"> <span class="keyword">for</span> {</span><br><span class="line"> <span class="comment">// process code</span></span><br><span class="line"> <span class="keyword">if</span> {</span><br><span class="line"> c <- <span class="number">1</span></span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> c <- <span class="number">0</span></span><br><span class="line">}(done)</span><br><span class="line"></span><br><span class="line"><span class="keyword">select</span> {</span><br><span class="line"> <span class="keyword">case</span> <-ctx.Done():</span><br><span class="line"> <span class="comment">// process ctx done</span></span><br><span class="line"> <span class="keyword">case</span> <-time.After(time.Second * <span class="number">3</span>):</span><br><span class="line"> <span class="comment">// process after</span></span><br><span class="line"> <span class="keyword">case</span> <-done:</span><br><span class="line"> <span class="comment">// process code</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>开一个新的协程去不断尝试,在外的三个 case 有一个满足,则会执行。但是这里有一个问题非常需要注意:<strong>子协程什么时候退出?</strong>。</p><p>因为 gorountine 不能被强制 kill,所以在上述超时的情况下,select 语句执行 <code>case time.After</code> 之后退出,<code>done</code> 这个 chan 已经没有接受方了,因此既没有接受者,又没有缓冲区,结合 chan 的特性,则子协程会一直阻塞无法退出,所以本质上这个实现会导致子协程累积下去,也就是<strong>协程泄露</strong>,可能会使资源耗尽。</p><p>如何避免上述问题呢?一个很简单的想法就是提供缓冲区,<code>done := make(char int, 1)</code>,这样即使没有接收方,子协程也能完成发送,不会被阻塞。</p><p>还要一种办法,上面说了,select 操作 chan,并且可以指定 default,那是不是有思路了呢?</p><figure class="highlight golang"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span> {</span><br><span class="line"> <span class="keyword">select</span> {</span><br><span class="line"> <span class="keyword">case</span> done <- <span class="number">1</span>:</span><br><span class="line"> <span class="keyword">default</span>:</span><br><span class="line"> <span class="keyword">return</span></span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure><p>我们尝试往 chan 中发送,如果发不出去,则就退出,也实现了目的。</p><p>最后总结一下,goroutine 泄露的防范条例:</p><ul><li>创建 goroutine 时就要想好该 goroutine 该如何结束。</li><li>使用 chan 时,要考虑到 chan 阻塞时协程可能的行为。</li><li>实现循环语句时注意循环的退出条件,避免死循环。</li></ul>]]></content>
<summary type="html">
<p>golang 中的协程使用非常方便,但是协程什么时候结束是一个控制问题,可以用 select 配合使用。</p>
</summary>
<category term="Golang" scheme="https://murphypei.github.io/categories/Golang/"/>
<category term="select" scheme="https://murphypei.github.io/tags/select/"/>
<category term="golang" scheme="https://murphypei.github.io/tags/golang/"/>
<category term="case" scheme="https://murphypei.github.io/tags/case/"/>
<category term="timeout" scheme="https://murphypei.github.io/tags/timeout/"/>
<category term="after" scheme="https://murphypei.github.io/tags/after/"/>
<category term="chan" scheme="https://murphypei.github.io/tags/chan/"/>
</entry>
<entry>
<title>C++ 链接一个不需要的库(--no-as-needed)</title>
<link href="https://murphypei.github.io/blog/2022/04/link-noneed-lib.html"/>
<id>https://murphypei.github.io/blog/2022/04/link-noneed-lib.html</id>
<published>2022-04-18T09:16:49.000Z</published>
<updated>2026-01-07T08:41:01.085Z</updated>
<content type="html"><![CDATA[<p>使用 libtorch 的 C++ 动态链接库遇到了一个非常诡异的问题…</p><span id="more"></span><p>我使用 libtorch 的库编译了一个语音识别程序,使用 CPU 推理,能够完美运行,然后在 go 中对这个程序封装了一层 GRPC,也都 OK。</p><p>但是当我想用 GPU 推理的时候,我直接下载了 libtorch 的 <a href="https://download.pytorch.org/libtorch/cu113/libtorch-cxx11-abi-shared-with-deps-1.11.0%2Bcu113.zip">GPU 库</a>,然后直接编译语音程序(需要修改 <code>torch::Device</code>),可以直接跑在 GPU 上了,很开心。</p><p>但是我用第二次编译出来的库放到 go 程序中,则出现了诡异的错误,运行加载模型的时候,<code>model->to_device</code>,而且 <code>device_count</code> 为 0,很明显,程序没找到 GPU。</p><p>利用 ldd 查看 go 编译出来的可执行文件,发现没有链接到 <code>torch_cuda_*</code> 这些库,怎么会这么奇怪呢?我明明把这些库放到编译的 flags 中了。为此我反复调整了链接的 flag,包括库的顺序,库的路径等等,但是都无济于事。</p><p>几经辗转,终于找到一个和我类似的错误了。<a href="https://github.com/pytorch/pytorch/issues/72396">https://github.com/pytorch/pytorch/issues/72396</a></p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br></pre></td><td class="code"><pre><span class="line">Could not run 'aten::empty_strided' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::empty_strided' is only available for these backends: [CPU, Meta, BackendSelect, Python, Named, Conjugate, Negative, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradLazy, AutogradXPU, AutogradMLC, AutogradHPU, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, UNKNOWN_TENSOR_TYPE_ID, Autocast, Batched, VmapMode].</span><br><span class="line"></span><br><span class="line">CPU: registered at aten\src\ATen\RegisterCPU.cpp:18433 [kernel]</span><br><span class="line">Meta: registered at aten\src\ATen\RegisterMeta.cpp:12703 [kernel]</span><br><span class="line">BackendSelect: registered at aten\src\ATen\RegisterBackendSelect.cpp:665 [kernel]</span><br><span class="line">Python: registered at ..\..\aten\src\ATen\core\PythonFallbackKernel.cpp:47 [backend fallback]</span><br><span class="line">Named: registered at ..\..\aten\src\ATen\core\NamedRegistrations.cpp:7 [backend fallback]</span><br><span class="line">Conjugate: fallthrough registered at ..\..\aten\src\ATen\ConjugateFallback.cpp:22 [kernel]</span><br><span class="line">Negative: fallthrough registered at ..\..\aten\src\ATen\native\NegateFallback.cpp:22 [kernel]</span><br><span class="line">ADInplaceOrView: fallthrough registered at ..\..\aten\src\ATen\core\VariableFallbackKernel.cpp:64 [backend fallback]</span><br><span class="line">AutogradOther: registered at ..\..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]</span><br><span class="line">AutogradCPU: registered at ..\..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]</span><br><span class="line">AutogradCUDA: registered at ..\..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]</span><br><span class="line">AutogradXLA: registered at ..\..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]</span><br><span class="line">AutogradLazy: registered at ..\..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]</span><br><span class="line">AutogradXPU: registered at ..\..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]</span><br><span class="line">AutogradMLC: registered at ..\..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]</span><br><span class="line">AutogradHPU: registered at ..\..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]</span><br><span class="line">AutogradNestedTensor: registered at ..\..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]</span><br><span class="line">AutogradPrivateUse1: registered at ..\..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]</span><br><span class="line">AutogradPrivateUse2: registered at ..\..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]</span><br><span class="line">AutogradPrivateUse3: registered at ..\..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]</span><br><span class="line">Tracer: registered at ..\..\torch\csrc\autograd\generated\TraceType_2.cpp:11423 [kernel]</span><br><span class="line">UNKNOWN_TENSOR_TYPE_ID: fallthrough registered at ..\..\aten\src\ATen\autocast_mode.cpp:466 [backend fallback]</span><br><span class="line">Autocast: fallthrough registered at ..\..\aten\src\ATen\autocast_mode.cpp:305 [backend fallback]</span><br><span class="line">Batched: registered at ..\..\aten\src\ATen\BatchingRegistrations.cpp:1016 [backend fallback]</span><br><span class="line">VmapMode: fallthrough registered at ..\..\aten\src\ATen\VmapModeRegistrations.cpp:33 [backend fallback]</span><br><span class="line"></span><br><span class="line">Exception raised from reportError at ..\..\aten\src\ATen\core\dispatch\OperatorEntry.cpp:431 (most recent call first):</span><br><span class="line">00007FFEE7CAA29200007FFEE7CAA230 c10.dll!c10::Error::Error [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFEE7C843C500007FFEE7C84350 c10.dll!c10::NotImplementedError::NotImplementedError [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F015C7100007FFD5F015AA0 torch_cpu.dll!c10::impl::OperatorEntry::reportError [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F6C6AF000007FFD5F66DBB0 torch_cpu.dll!at::_ops::xlogy_Tensor::redispatch [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F8E73F100007FFD5F8CF610 torch_cpu.dll!at::_ops::zeros_out::redispatch [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F8E3F7400007FFD5F8CF610 torch_cpu.dll!at::_ops::zeros_out::redispatch [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F6EB6E800007FFD5F6EB520 torch_cpu.dll!at::_ops::empty_strided::call [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5EF259CB00007FFD5EF258D0 torch_cpu.dll!at::empty_strided [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F2C24D100007FFD5F2C2130 torch_cpu.dll!at::native::_to_copy [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5FA7C3D600007FFD5FA7BF10 torch_cpu.dll!at::compositeexplicitautograd::xlogy_ [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5FA5A8FB00007FFD5FA3F310 torch_cpu.dll!at::compositeexplicitautograd::bitwise_xor_outf [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F4EB5AD00007FFD5F45B290 torch_cpu.dll!at::TensorMaker::make_tensor [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F8DED7700007FFD5F8CF610 torch_cpu.dll!at::_ops::zeros_out::redispatch [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F8E36EB00007FFD5F8CF610 torch_cpu.dll!at::_ops::zeros_out::redispatch [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F4EB5AD00007FFD5F45B290 torch_cpu.dll!at::TensorMaker::make_tensor [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F56326800007FFD5F563190 torch_cpu.dll!at::_ops::_to_copy::redispatch [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD60A27F0000007FFD60A27A30 torch_cpu.dll!at::redispatch::_thnn_fused_lstm_cell_backward [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD60A4031D00007FFD60A34930 torch_cpu.dll!torch::jit::Node::c_ [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F50C12B00007FFD5F50BF70 torch_cpu.dll!at::_ops::_to_copy::call [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F2C2E7900007FFD5F2C2BD0 torch_cpu.dll!at::native::to_dense_backward [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F2C2B0C00007FFD5F2C29E0 torch_cpu.dll!at::native::to [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5FB6A66800007FFD5FB63F10 torch_cpu.dll!at::compositeimplicitautograd::where [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5FB4DB5D00007FFD5FB1BE50 torch_cpu.dll!at::compositeimplicitautograd::broadcast_to [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5F7E6F4600007FFD5F7E6D70 torch_cpu.dll!at::_ops::to_dtype_layout::call [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5EF4AA8800007FFD5EF4A970 torch_cpu.dll!at::Tensor::to [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFD5EF9EAE900007FFD5EF9E9F0 torch_cpu.dll!at::tensor [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FF7714295A200007FF7714294B0 SplinterlandsSimulator.exe!main [C:\Users\xargo\source\repos\SplinterlandsSimulator\SplinterlandsSimulator\SplinterlandsSimulator.cpp @ 390]</span><br><span class="line">00007FF77144164C00007FF771441540 SplinterlandsSimulator.exe!__scrt_common_main_seh [d:\a01\_work\20\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl @ 288]</span><br><span class="line">00007FFF47C554E000007FFF47C554D0 KERNEL32.DLL!BaseThreadInitThunk [<unknown file> @ <unknown line number>]</span><br><span class="line">00007FFF48DA485B00007FFF48DA4830 ntdll.dll!RtlUserThreadStart [<unknown file> @ <unknown line number>]</span><br></pre></td></tr></table></figure><p>上述报错跟我的很像,而且从下面的回复来看,也是没能链接到 cuda 相应的库。下面的回复给我了启发:<strong>如果我的 go 程序没用到 libtorch 的 cuda 接口,是不是不会主动链接到 libtorch 相应的 cuda 的库</strong>?</p><p>前面说了,ldd 查看的确实没有,那怎么让编译器强制链接到 libtorch 的 cuda 相应的库呢?显然是的,编译器默认使用了 <code>--as-needed</code> 编译参数,这也是合理的,我们没必要链接所有的动态库,动态库本来就是按需链接,但是在我们的这个使用场景中,会遇到这种特殊情况,使用 <code>--no-as-needed</code> 强制链接到 libtorch cuda 的相应库,结果就没有问题了。</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">--as-needed</span><br><span class="line">--no-as-needed</span><br><span class="line">This option affects ELF DT_NEEDED tags for dynamic libraries mentioned on the command line after the --as-needed option. Normally the linker will add a DT_NEEDED tag for each dynamic library mentioned on the command line, regardless of whether the library is actually needed or not. --as-needed causes a DT_NEEDED tag to only be emitted for a library that satisfies an undefined symbol reference from a regular object file or, if the library is not found in the DT_NEEDED lists of other libraries linked up to that point, an undefined symbol reference from another dynamic library. --no-as-needed restores the default behaviour.</span><br></pre></td></tr></table></figure>]]></content>
<summary type="html">
<p>使用 libtorch 的 C++ 动态链接库遇到了一个非常诡异的问题…</p>
</summary>
<category term="C/C++" scheme="https://murphypei.github.io/categories/C-C/"/>
<category term="C++" scheme="https://murphypei.github.io/tags/C/"/>
<category term="link" scheme="https://murphypei.github.io/tags/link/"/>
<category term="no-as-needed" scheme="https://murphypei.github.io/tags/no-as-needed/"/>
<category term="undef" scheme="https://murphypei.github.io/tags/undef/"/>
</entry>
</feed>