fix: use tied embedding for linear CE fusion output weight #2363

jthomson04 wants to merge 2 commits into NVIDIA-NeMo:main
Conversation
Auto-sync is disabled for ready-for-review pull requests in this repository. Workflows must be run manually.
/ok to test f3466a3
terrykong left a comment
LGTM — clean, minimal bug fix. The tied-embedding path was passing None (self.output_layer.weight) as the first positional arg to from_parallel_hidden_states_to_logprobs, while the correct weight was only computed for the second arg (which turns out to be unused in the callee). The fix correctly resolves the weight via shared_embedding_or_output_weight() for tied-embedding models.
Minor observation (non-blocking): from_parallel_hidden_states_to_logprobs accepts an output_weight parameter (its 3rd positional arg) that is never referenced in the function body — only output_weight_layer (2nd positional arg) gets passed into ChunkedDistributedHiddenStatesToLogprobs.apply(). This dead parameter pre-dates this PR, but could be cleaned up in a follow-up to avoid future confusion.
Generated by Claude Code
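For context, a minimal sketch of the failure mode and the fix, assuming Megatron-style module attributes as described above; the exact call site in `nemo_rl` may differ:

```python
# Buggy path: for tied-embedding models (e.g. Qwen3 with
# share_embeddings_and_output_weights=True), output_layer.weight is None,
# so from_parallel_hidden_states_to_logprobs received None positionally.
output_weight_layer = self.output_layer.weight

# Fixed path: resolve the shared weight when embeddings are tied, and
# reuse the same tensor for both positional weight args.
output_weight_layer = (
    self.shared_embedding_or_output_weight()
    if self.share_embeddings_and_output_weights
    else self.output_layer.weight
)
```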
```python
self.shared_embedding_or_output_weight()
if self.share_embeddings_and_output_weights
else self.output_layer.weight,
output_weight_layer,
```
nemo_rl/distributed/model_utils.py:2165
Nit: both args now pass the same output_weight_layer, which is correct since the callee's output_weight parameter (line 1836) is never actually read — only output_weight_layer is forwarded to ChunkedDistributedHiddenStatesToLogprobs.apply(). Consider removing the dead output_weight parameter in a follow-up to prevent future confusion.
Good catch — dropped the dead output_weight parameter in 68e1b23.
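A hedged sketch of what that cleanup looks like; the parameter list here is illustrative and omits the function's other arguments:

```python
# Illustrative signatures only; the real function in
# nemo_rl/distributed/model_utils.py takes additional parameters.

# Before: the third positional arg (output_weight) was accepted but never read.
def from_parallel_hidden_states_to_logprobs(hidden_states, output_weight_layer, output_weight):
    return ChunkedDistributedHiddenStatesToLogprobs.apply(hidden_states, output_weight_layer)

# After 68e1b23: the dead parameter is gone, so callers can no longer
# pass a weight that silently goes unused.
def from_parallel_hidden_states_to_logprobs(hidden_states, output_weight_layer):
    return ChunkedDistributedHiddenStatesToLogprobs.apply(hidden_states, output_weight_layer)
```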
…en_states_to_logprobs

The third positional parameter was never read inside the function — only output_weight_layer is forwarded to ChunkedDistributedHiddenStatesToLogprobs.apply(). Per @terrykong's review note on PR NVIDIA-NeMo#2363, removing the dead arg now to prevent future confusion.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
In `_gpt_forward_with_linear_ce_fusion`, `output_weight_layer` was set unconditionally from `self.output_layer.weight`, which is `None` for models with tied embeddings (e.g. Qwen3) where `share_embeddings_and_output_weights=True`. This crashed in `from_parallel_hidden_states_to_logprobs` because the weight tensor passed positionally was `None`.

Fetch the weight via `shared_embedding_or_output_weight()` when embeddings are tied, mirroring the logic that was already used for the second positional argument, and reuse the same tensor for both args.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Force-pushed 68e1b23 to fc6c3a6
Summary
`_gpt_forward_with_linear_ce_fusion` set `output_weight_layer = self.output_layer.weight`, which is `None` for models with tied embeddings (e.g. Qwen3 with `share_embeddings_and_output_weights=True`). The first positional argument to `from_parallel_hidden_states_to_logprobs` was therefore `None`, crashing logprob computation when `policy.use_linear_ce_fusion_loss=True` is set on a tied-embedding model.

Fetch the weight via `shared_embedding_or_output_weight()` when embeddings are tied, matching the logic already used for the second positional argument, and reuse the same tensor for both args.

Test plan
Run a tied-embedding model with `policy.use_linear_ce_fusion_loss=True` and confirm no crash and matching logprobs vs. the non-fusion path; a sketch of that check follows below.

🤖 Generated with Claude Code
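As a sketch of that comparison, with hypothetical helper names (`compute_logprobs` stands in for whatever entry point the test suite actually uses):

```python
import torch

def check_fusion_logprobs_match(model, batch):
    # compute_logprobs is a hypothetical stand-in for the repo's logprob
    # computation entry point; the flag mirrors policy.use_linear_ce_fusion_loss.
    fused = compute_logprobs(model, batch, use_linear_ce_fusion_loss=True)
    reference = compute_logprobs(model, batch, use_linear_ce_fusion_loss=False)
    # The fused path should reproduce the non-fusion logprobs on a
    # tied-embedding model (e.g. Qwen3) within numerical tolerance.
    torch.testing.assert_close(fused, reference, rtol=0, atol=1e-5)
```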