[Common][PyTorch] Add z_loss_weight and log_sum_exp output to parallel_cross_entropy#2707

Draft
bassoy wants to merge 2 commits into NVIDIA:main from bassoy:add_zloss_to_parallel_cross_entropy
Conversation

@bassoy bassoy commented Feb 26, 2026

Description

Adds z-loss regularization (see https://arxiv.org/abs/2202.08906) and log_sum_exp output to parallel_cross_entropy. The Triton kernel already computes lse = m + log(d) as part of the online softmax. We store it to a new output buffer and optionally add z_loss_weight * lse^2 to the loss. At z_loss_weight=0.0 all z-loss logic is dead-code-eliminated by Triton (tl.constexpr), so there is no overhead.
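The arithmetic the kernel performs can be sketched in plain Python (a hypothetical reference for the math, not the Triton code; `m` and `d` stand for the running max and running sum maintained by the online softmax):

```python
import math

def log_sum_exp(logits):
    # Online-softmax byproduct: lse = m + log(d), where m is the running
    # max and d is the running sum of exp(x - m).
    m = max(logits)
    d = sum(math.exp(x - m) for x in logits)
    return m + math.log(d)

def ce_with_z_loss(logits, target, z_loss_weight=0.0):
    # Cross entropy in lse form: loss = lse - logits[target];
    # the z-loss term adds z_loss_weight * lse**2 on top.
    lse = log_sum_exp(logits)
    return (lse - logits[target]) + z_loss_weight * lse ** 2, lse

loss, lse = ce_with_z_loss([2.0, -1.0, 0.5], target=0, z_loss_weight=1e-4)
```

With `z_loss_weight=0.0` this reduces to standard cross entropy, which is why the kernel can skip the extra term entirely when the weight is a compile-time zero.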

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • Add z_loss_weight (tl.constexpr) and log_sum_exp output buffer to Triton CE kernel
  • Expose via parallel_cross_entropy (z_loss_weight=0.0, return_log_sum_exp=False) -- backward compatible; default return unchanged (single Tensor)
  • Add 8 new tests; all 15 pass on H100 (CUDA 13.0) and A40 (CUDA 13.1)
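Since lse is already produced by the forward pass, the only new backward term is the z-loss gradient, 2 * z_loss_weight * lse * softmax(logits). A stdlib finite-difference check of that identity (hypothetical input values, independent of the kernel code):

```python
import math

def lse(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def z_loss(xs, w):
    return w * lse(xs) ** 2

# Analytic gradient of the z-loss term: 2 * w * lse * softmax(logits).
xs, w = [0.3, -1.2, 0.7], 1e-2
L = lse(xs)
softmax = [math.exp(x - L) for x in xs]
analytic = [2 * w * L * s for s in softmax]

# Verify against central finite differences.
eps = 1e-6
for i in range(len(xs)):
    hi = xs[:]; hi[i] += eps
    lo = xs[:]; lo[i] -= eps
    fd = (z_loss(hi, w) - z_loss(lo, w)) / (2 * eps)
    assert abs(fd - analytic[i]) < 1e-6
```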

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Cem Bassoy and others added 2 commits February 26, 2026 00:44
…l_cross_entropy

- Triton kernel: adds z_loss_weight (tl.constexpr) and a log_sum_exp output buffer;
  lse = m + log(d) reuses values already in registers — zero extra compute;
  z_loss blocks are dead-code-eliminated by Triton when z_loss_weight=0.0
- API: new z_loss_weight=0.0 and return_log_sum_exp=False parameters;
  default return is a single Tensor (backward compatible);
  return_log_sum_exp=True returns (loss, log_sum_exp) tuple;
  return type: Union[Tensor, Tuple[Tensor, Tensor]] matching TE convention
- Tests: 15 tests covering z_loss correctness, BF16, non-uniform backward
  gradients (loss masking), log_sum_exp semantics, and backward compatibility

Signed-off-by: Cem Bassoy <cem.bassoy@deepl.com>