
fix: use tied embedding for linear CE fusion output weight #2363

Open

jthomson04 wants to merge 2 commits into NVIDIA-NeMo:main from jthomson04:jthomson04/fix-qwen3-tied-embedding-linear-ce-fusion

Conversation


@jthomson04 jthomson04 commented Apr 29, 2026

Summary

  • _gpt_forward_with_linear_ce_fusion set output_weight_layer = self.output_layer.weight, which is None for models with tied embeddings (e.g. Qwen3 with share_embeddings_and_output_weights=True). The first positional argument to from_parallel_hidden_states_to_logprobs was therefore None, crashing logprob computation when policy.use_linear_ce_fusion_loss=True is set on a tied-embedding model.
  • Fetch the weight via shared_embedding_or_output_weight() when embeddings are tied, matching the logic already used for the second positional argument, and reuse the same tensor for both args (sketched below).
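
A minimal sketch of the corrected selection, assuming a Megatron-style GPT model with tied embeddings; surrounding code and argument order are abbreviated, and any names not quoted in this PR are illustrative rather than the actual nemo_rl source:

```python
# Sketch only: surrounding code and other arguments are omitted.
if self.share_embeddings_and_output_weights:
    # Tied embeddings (e.g. Qwen3): self.output_layer.weight is None, so the
    # shared embedding matrix must be fetched explicitly.
    output_weight_layer = self.shared_embedding_or_output_weight()
else:
    # Untied embeddings (e.g. Llama): the output layer owns its own weight.
    output_weight_layer = self.output_layer.weight

logprobs = from_parallel_hidden_states_to_logprobs(
    # ... other arguments unchanged ...
    output_weight_layer,
    output_weight_layer,  # same tensor reused for both weight arguments
    # ... other arguments unchanged ...
)
```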

Test plan

  • Run a logprob/training step on a Qwen3 model with policy.use_linear_ce_fusion_loss=True and confirm no crash and matching logprobs vs. the non-fusion path.
  • Confirm untied-embedding models (e.g. Llama) still produce identical logprobs (the non-tied branch is unchanged).

🤖 Generated with Claude Code

@jthomson04 jthomson04 requested a review from a team as a code owner April 29, 2026 21:08

copy-pr-bot Bot commented Apr 29, 2026

Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

jthomson04 (Author) commented:

/ok to test f3466a3

@terrykong terrykong requested review from guyueh1 and terrykong May 5, 2026 19:47
terrykong (Collaborator) previously approved these changes May 5, 2026

@terrykong terrykong left a comment

LGTM — clean, minimal bug fix. The tied-embedding path was passing None (self.output_layer.weight) as the first positional arg to from_parallel_hidden_states_to_logprobs, while the correct weight was only computed for the second arg (which turns out to be unused in the callee). The fix correctly resolves the weight via shared_embedding_or_output_weight() for tied-embedding models.

Minor observation (non-blocking): from_parallel_hidden_states_to_logprobs accepts an output_weight parameter (its 3rd positional arg) that is never referenced in the function body — only output_weight_layer (2nd positional arg) gets passed into ChunkedDistributedHiddenStatesToLogprobs.apply(). This dead parameter pre-dates this PR, but could be cleaned up in a follow-up to avoid future confusion.
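
For illustration, a rough sketch of the shape being described (an assumed, simplified signature, not the exact code in nemo_rl/distributed/model_utils.py, which takes more parameters):

```python
def from_parallel_hidden_states_to_logprobs(
    hidden_states,
    output_weight_layer,
    output_weight,  # dead parameter: accepted but never read below
    # ... further parameters omitted ...
):
    # Only output_weight_layer reaches the fused kernel; output_weight is ignored.
    return ChunkedDistributedHiddenStatesToLogprobs.apply(
        hidden_states, output_weight_layer
    )
```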

Generated by Claude Code

Comment thread on nemo_rl/distributed/model_utils.py (outdated):
self.shared_embedding_or_output_weight()
if self.share_embeddings_and_output_weights
else self.output_layer.weight,
output_weight_layer,
Collaborator commented at nemo_rl/distributed/model_utils.py:2165:

Nit: both args now pass the same output_weight_layer, which is correct since the callee's output_weight parameter (line 1836) is never actually read — only output_weight_layer is forwarded to ChunkedDistributedHiddenStatesToLogprobs.apply(). Consider removing the dead output_weight parameter in a follow-up to prevent future confusion.

jthomson04 (Author) replied:

Good catch — dropped the dead output_weight parameter in 68e1b23.

jthomson04 added a commit to jthomson04/RL that referenced this pull request May 6, 2026
…en_states_to_logprobs

The third positional parameter was never read inside the function — only
output_weight_layer is forwarded to ChunkedDistributedHiddenStatesToLogprobs.apply().
Per @terrykong's review note on PR NVIDIA-NeMo#2363, removing the dead arg now to prevent
future confusion.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
jthomson04 and others added 2 commits May 5, 2026 20:12
In `_gpt_forward_with_linear_ce_fusion`, `output_weight_layer` was set
unconditionally from `self.output_layer.weight`, which is `None` for
models with tied embeddings (e.g. Qwen3) where
`share_embeddings_and_output_weights=True`. This crashed in
`from_parallel_hidden_states_to_logprobs` because the weight tensor
passed positionally was `None`.

Fetch the weight via `shared_embedding_or_output_weight()` when
embeddings are tied, mirroring the logic that was already used for the
second positional argument, and reuse the same tensor for both args.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
…en_states_to_logprobs

The third positional parameter was never read inside the function — only
output_weight_layer is forwarded to ChunkedDistributedHiddenStatesToLogprobs.apply().
Per @terrykong's review note on PR NVIDIA-NeMo#2363, removing the dead arg now to prevent
future confusion.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
jthomson04 force-pushed the jthomson04/fix-qwen3-tied-embedding-linear-ce-fusion branch from 68e1b23 to fc6c3a6 on May 6, 2026 03:12