
feat: add enable_fp32_lm_head configuration for model precision#849

Merged
LLLLKKKK merged 1 commit into alibaba:main from Vinkle-hzt:fix_gemm
Apr 10, 2026

Conversation

@Vinkle-hzt
Collaborator

No description provided.

@Vinkle-hzt Vinkle-hzt marked this pull request as ready for review April 1, 2026 04:02
@Vinkle-hzt Vinkle-hzt requested a review from LLLLKKKK as a code owner April 1, 2026 04:02
@LLLLKKKK
Collaborator

LLLLKKKK commented Apr 2, 2026

🤖 Code Review (v1) — LGTM

Verdict: LGTM ✅

The FP32 cast in the logits computation in forwardPostLayers is moved from before the GEMM to after it:

// Before:
auto logits = torch::mm(last_hidden.to(torch::kFloat32), lm_head->kernel.to(torch::kFloat32).t());

// After:
auto logits = torch::mm(last_hidden, lm_head->kernel.t()).to(torch::kFloat32);

A correct performance optimization. The original implementation cast both tensors to FP32 before the GEMM, wasting memory and forgoing tensor cores. The new code runs the GEMM in the original precision and casts to FP32 only at the end, which is standard practice.
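The trade-off can be illustrated with a small numpy sketch (a hedged illustration: numpy stands in for the torch tensors, and all shapes and values here are made up):

```python
import numpy as np

# Illustrative shapes only; numpy stands in for the torch tensors.
rng = np.random.default_rng(0)
hidden = rng.standard_normal((4, 8)).astype(np.float16)   # last_hidden
weight = rng.standard_normal((16, 8)).astype(np.float16)  # lm_head kernel [vocab, hidden]

# Old path: cast both inputs to FP32, then GEMM (two extra FP32 copies).
logits_old = hidden.astype(np.float32) @ weight.astype(np.float32).T

# New path: GEMM in the native dtype, cast only the (smaller) result.
logits_new = (hidden @ weight.T).astype(np.float32)

assert logits_new.dtype == np.float32
# The two paths agree up to half-precision accumulation error.
assert np.allclose(logits_old, logits_new, atol=0.25)
```

On GPU the new path additionally lets the GEMM run on FP16/BF16 tensor cores, which the numpy sketch cannot show.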


Automated review by CI Bot

@Vinkle-hzt Vinkle-hzt force-pushed the fix_gemm branch 4 times, most recently from 8d7174a to c0e7f5b on April 9, 2026
@Vinkle-hzt Vinkle-hzt force-pushed the fix_gemm branch 2 times, most recently from b6ee732 to 114a195 on April 10, 2026
@LLLLKKKK
Collaborator

🤖 AI Code Review (incremental) — PR #849
Head SHA: 114a195a8cbe | Previous: b6ee732a6078 | Verdict: LGTM

Changes since last review

Single new commit covering two themes:

  1. Core optimization in PyWrappedModel.cc: lm_head GEMM now runs in the kernel's native dtype (e.g. FP16/BF16) and casts the result to FP32 afterward, instead of casting both inputs to FP32 before GEMM. Controlled by enable_fp32_lm_head (default true = old behavior preserved).

  2. Refactoring NormalBatchStreamProcessor: the monolithic ~650-line .cc file is decomposed into three focused classes:

    • NormalModelInputGatherer — buffer allocation + model input gathering
    • NormalSamplerInputGatherer — sampler input gathering
    • NormalOutputDispatcher — output dispatch to streams
  3. Supporting changes: StreamGroups splits block update copy counts by decode/context; MemoryLayoutStrategy avoids temp tensor creation in getBlockPtr(); config plumbing for enable_fp32_lm_head through ModelArgs/ModelConfig/ModelWeightInfo/CLI.
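The decomposition above can be sketched in Python (a hedged illustration: the three class names come from this review, but every method name, signature, and data shape here is an assumption, not the repository's actual C++ API):

```python
from dataclasses import dataclass, field

class NormalModelInputGatherer:
    """Allocates buffers and gathers per-batch model inputs."""
    def gather(self, streams):
        return {"input_ids": [s["input_ids"] for s in streams]}

class NormalSamplerInputGatherer:
    """Gathers per-stream sampling parameters."""
    def gather(self, streams):
        return {"temperature": [s.get("temperature", 1.0) for s in streams]}

class NormalOutputDispatcher:
    """Dispatches sampled tokens back to their originating streams."""
    def dispatch(self, streams, tokens):
        for s, t in zip(streams, tokens):
            s.setdefault("output", []).append(t)

@dataclass
class NormalBatchStreamProcessor:
    """The former monolith becomes a thin coordinator over the three parts."""
    inputs: NormalModelInputGatherer = field(default_factory=NormalModelInputGatherer)
    sampler: NormalSamplerInputGatherer = field(default_factory=NormalSamplerInputGatherer)
    outputs: NormalOutputDispatcher = field(default_factory=NormalOutputDispatcher)
```

The design win is that each class can now be unit-tested in isolation instead of through the full batch pipeline.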

Findings

No blocking issues. Clean decomposition, correct config plumbing, tests updated properly.

  • P2 (cosmetic): Minor indentation inconsistency in NormalOutputDispatcher.cc method implementations.
  • Default enable_fp32_lm_head=True preserves backward compat. Good.
  • getBlockPtr bounds check is good defensive coding for the perf-critical path change.

LGTM overall.

@LLLLKKKK
Collaborator

🤖 AI Code Review (incremental) — PR #849
Head SHA: db5100077ae0 | Previous: 114a195a8cbe | Verdict: P1

Changes since last review

New commit db5100077ae0 adds enable_fp32_lm_head config (default True) controlling whether lm_head weight is loaded as FP32. The core change moves the float32 cast from before the GEMM to after — matmul now runs in the kernel's native dtype, then casts to FP32. Applies to both C++ (PyWrappedModel.cc) and Python (auto_model.py) paths.

Findings

P1: Bug — .dtype() called as method on torch.Tensor in Python
rtp_llm/models_py/standalone/auto_model.py ~line 285:

hidden_states.to(self.lm_head_weight.dtype())

self.lm_head_weight is a torch.Tensor. In Python, torch.Tensor.dtype is a property, not a method — calling .dtype() will raise TypeError: cannot create 'torch.dtype' instances. The C++ counterpart (lm_head->kernel.dtype()) is correct (libtorch member function), but the Python code incorrectly mirrors that convention.
Fix: self.lm_head_weight.dtype (remove parentheses).
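The failure mode can be reproduced without torch (a minimal stand-in: `Tensor` here mimics only the property behavior, and the error message differs from the real torch one quoted above):

```python
# Plain-Python reproduction of the pitfall; Tensor is a stand-in for
# torch.Tensor, whose dtype is exposed as a property, not a method.
class Tensor:
    def __init__(self, dtype):
        self._dtype = dtype

    @property
    def dtype(self):
        return self._dtype

t = Tensor("float16")
assert t.dtype == "float16"   # correct: property access, no parentheses

try:
    t.dtype()                 # the bug: this calls the property's *value*
except TypeError as e:
    print(f"TypeError: {e}")  # prints: TypeError: 'str' object is not callable
```

The C++ convention (`lm_head->kernel.dtype()`) cannot be mirrored verbatim in Python precisely because libtorch exposes a member function where PyTorch exposes a property.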

P2: Default True forces FP32 lm_head even for quantized models
_fix_fp32_lm_head unconditionally sets weight.data_type = torch.float32. This preserves old behavior (previous code cast to FP32 before matmul), but worth confirming this is intentional for all quantization schemes.

@LLLLKKKK
Collaborator

🤖 AI Code Review (incremental) — PR #849
Head SHA: ac67fd9cdf0e | Previous: db5100077ae0 | Verdict: P2

Previous P1 status

  • .dtype() parentheses bug: FIXED. auto_model.py now correctly uses .dtype (property) instead of .dtype(). The C++ .dtype() in PyWrappedModel.cc is fine (libtorch method call).

Changes since last review

One new commit ac67fd9cdf0e adds enable_fp32_lm_head config (default True) and changes the GEMM strategy: hidden is cast to the lm_head weight's dtype before matmul, then the result is cast to fp32 — instead of casting both inputs to fp32 upfront.

Findings

P2: Redundant .to(kFloat32) in default path
When enable_fp32_lm_head=True (default), the lm_head weight is already fp32, so last_hidden.to(lm_head->kernel.dtype()) casts to fp32 and the trailing .to(torch::kFloat32) is a no-op. Not a bug, but could be guarded. Same in auto_model.py.
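One way the no-op could be guarded, sketched with numpy (a hedged sketch: `cast_logits` is a hypothetical helper, not code from the PR; in torch the trailing `.to()` on a matching dtype already returns the tensor unchanged, so the guard mainly documents intent):

```python
import numpy as np

def cast_logits(last_hidden, weight):
    """GEMM in the weight's dtype, returning FP32 logits; the final
    cast is skipped when the product is already FP32 (the default
    enable_fp32_lm_head=True path)."""
    logits = last_hidden.astype(weight.dtype) @ weight.T
    if logits.dtype != np.float32:
        logits = logits.astype(np.float32)
    return logits
```

With an FP32 weight the function performs no extra conversion of the result; with an FP16 weight it casts once at the end.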

P2: No test coverage for new config
No unit test for enable_fp32_lm_head verifying dtype propagation or correctness in both modes.
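A test along these lines could cover both modes (a hedged sketch: `lm_head_matmul` is a stand-in for the forwardPostLayers logits path, with numpy in place of the model tensors; the real test would exercise the C++/Python model code):

```python
import numpy as np

# Stand-in for the lm_head logits path; not the repository's actual code.
def lm_head_matmul(hidden, weight, enable_fp32_lm_head=True):
    if enable_fp32_lm_head:
        weight = weight.astype(np.float32)  # weight loaded as FP32
    return (hidden.astype(weight.dtype) @ weight.T).astype(np.float32)

def test_enable_fp32_lm_head_dtypes():
    hidden = np.ones((2, 4), dtype=np.float16)
    weight = np.ones((8, 4), dtype=np.float16)
    for flag in (True, False):
        logits = lm_head_matmul(hidden, weight, enable_fp32_lm_head=flag)
        assert logits.dtype == np.float32   # output dtype in both modes
        assert logits.shape == (2, 8)

test_enable_fp32_lm_head_dtypes()
```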

Overall the previous P1 is resolved and the new code is functionally correct. The default True preserves backward compatibility. Minor items only.

@Vinkle-hzt Vinkle-hzt changed the title from "chore: trans logits after gemm" to "feat: add enable_fp32_lm_head configuration for model precision" on Apr 10, 2026
@LLLLKKKK LLLLKKKK merged commit bbb6749 into alibaba:main Apr 10, 2026
2 checks passed