[Draft] Newton-Schulz via cuSOLVERMp#2706

Open
vcherepanov-nv wants to merge 26 commits into NVIDIA:main from vcherepanov-nv:newton-schulz

Conversation

@vcherepanov-nv
Collaborator

Description

Adds an API for running the Newton-Schulz method on a distributed tensor.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Integrate cuSOLVERMp as a new dependency
  • Add corresponding API to TE/common
  • Add PyTorch binding and tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vcherepanov-nv and others added 19 commits February 8, 2026 22:38
Add a new distributed Newton-Schulz inverse square root API to Transformer
Engine's common C library. This wraps the cusolverMpNewtonSchulz library
function, following the same pattern as the existing cuBLASMp integration
for comm_gemm.

New files:
- newton_schulz.h: Public C API header with context management and
  computation functions
- newton_schulz/newton_schulz.cpp: Implementation with RAII wrappers
  for cuSolverMp handles

Build integration:
- New NVTE_WITH_CUSOLVERMP CMake option and CUSOLVERMP_HOME env var
- NVTE_CHECK_CUSOLVERMP error checking macro in logging.h
- Conditional compilation guarded by NVTE_WITH_CUSOLVERMP

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Add PyTorch-level bindings for the cuSolverMp Newton-Schulz inverse
square root API introduced in the previous commit.

New files:
- pytorch/csrc/extensions/newton_schulz.cpp: C++ extension wrapping
  the C API with PyTorch tensor support
- pytorch/newton_schulz.py: Python wrapper that extracts NCCL
  communicator from torch.distributed ProcessGroup
- tests/pytorch/distributed/test_newton_schulz.py: pytest launcher
- tests/pytorch/distributed/run_newton_schulz.py: distributed test
  worker with reference implementation for numerical validation

Modified files:
- pytorch/csrc/extensions.h: Function declarations
- pytorch/csrc/extensions/pybind.cpp: pybind11 registrations
- pytorch/__init__.py: Public API export

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Fix API mismatches discovered during compilation:
- cusolverMpCreate takes (handle*, deviceId, stream), not (handle*, stream)
- cusolverMpCreateDeviceGrid takes handle as first arg with different
  parameter order
- Use cusolverMpGridMapping_t (not cusolverMpGridLayout_t) and
  CUSOLVERMP_GRID_MAPPING_COL_MAJOR
- cusolverMpCreateMatrixDesc has different parameter order: (desc*,
  grid, dtype, M, N, MB, NB, RSRC, CSRC, LLD)
- cusolverMpNewtonSchulzDescriptorCreate takes only (nsDesc*) with no
  iteration/coefficient args
- No cusolverMpStreamSet exists; create handle per-call with user stream
- cusolverMpNewtonSchulz requires computeType and info parameters
- Switch from generic template RAII to explicit deleter structs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
…build

Add NVTE_WITH_CUSOLVERMP compiler define and cusolverMp include/library
paths to the PyTorch C++ extension build, following the same pattern as
NVTE_UB_WITH_MPI and NVTE_ENABLE_NVSHMEM.

Without this, the #ifdef NVTE_WITH_CUSOLVERMP guards in the PyTorch
extension code would never be active since the define was only set as
PRIVATE in the CMake build for the common library.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Two fixes:
- Use ProcessGroupNCCL._comm_ptr() to extract the raw NCCL communicator
  pointer instead of the non-existent get_nccl_comm() method
- Pass global matrix dimensions (m, n) from Python to C++ instead of
  using local tensor dimensions, which would produce incorrect
  ScaLAPACK block sizes in the distributed computation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
cuSolverMp handle and grid creation are expensive operations. Move them
from per-call creation in nvte_newton_schulz into the NVTECusolverMpCtx,
which is their natural home — the context exists to encapsulate the grid.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
cuSolverMp cannot work with the default CUDA stream. Create a dedicated
stream inside nvte_cusolvermp_ctx_create and remove the stream parameter
from both C API functions since the context now owns its stream.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
The internal dedicated stream was reading the input tensor before the
caller's stream had finished producing it, resulting in all-zero output.

Add event-based synchronisation: the internal stream waits for the
caller's input to be ready, and the caller's stream waits for the
output to be written. Replaces the blocking cudaStreamSynchronize.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
cuSolverMp is asynchronous and uses the host workspace during multi-GPU
execution. The event-based output sync did not block the host, so the
local workspace_host vector was destroyed while the GPU was still
reading from it. Restore cudaStreamSynchronize to ensure the host
workspace remains valid for the full duration of the operation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Avoid creating and destroying a cudaEvent_t on every
nvte_newton_schulz call by making it a persistent member of
NVTECusolverMpCtx, matching the existing pattern for the stream.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Replace single event with in_ready and out_ready events. After the
cuSolverMp call, record out_ready on the internal stream and make the
caller's stream wait on it, ensuring the output tensor is ready before
the caller uses it.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Replace reference-comparison test with a direct arithmetic check:
if X is the inverse square root of A, then X @ A @ X must equal the
identity matrix. This is more robust and removes the need for a
separate reference implementation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Feb 25, 2026

Greptile Summary

This PR adds distributed Newton-Schulz matrix orthogonalization to TransformerEngine via cuSolverMp integration. The implementation introduces a new optional dependency with proper build system integration, a C API layer with RAII wrappers, and PyTorch bindings.

Key Changes

  • Integrates cuSolverMp library as an optional dependency (enabled via NVTE_WITH_CUSOLVERMP)
  • Adds C API (nvte_newton_schulz) with context management for cuSolverMp operations
  • Provides Python API (te.pytorch.newton_schulz) for distributed tensor operations
  • Includes distributed tests using torchrun with multiple GPUs

Issues Found

  • Critical: Missing tensor contiguity validation - non-contiguous tensors will produce incorrect results since C++ code assumes contiguous memory layout
  • Uses private PyTorch APIs (_get_backend, _comm_ptr) that may break in future PyTorch versions
  • Test validates orthogonality (X @ X.t() ≈ I) but documentation claims "inverse square root" - semantic mismatch between implementation and documentation
  • Context created/destroyed per call wastes resources - heavyweight operations (stream/event creation, cuSolverMp setup) happen on every invocation
  • Row distribution assumes even splitting without validation - incorrect global dimension calculation if matrix rows not evenly divisible by number of ranks
  • Synchronous cudaMalloc/cudaFree on hot path causes device synchronization, negating async benefits

Confidence Score: 2/5

  • Not safe to merge - has a critical data correctness issue with missing contiguity validation
  • The missing contiguity check is a critical bug that will cause incorrect results for non-contiguous tensors. Combined with multiple existing issues around performance, API stability, and semantic correctness, this PR needs significant fixes before merge.
  • transformer_engine/pytorch/newton_schulz.py requires immediate attention for contiguity validation. Also review the test verification logic and documentation to clarify whether this computes inverse square root or orthogonalization.

Important Files Changed

Filename | Overview
transformer_engine/common/newton_schulz/newton_schulz.cpp | C++ implementation with proper RAII wrappers and error handling, but has performance concerns with context recreation and synchronous memory allocation
transformer_engine/pytorch/newton_schulz.py | Python API with missing contiguity validation (critical), uses private PyTorch APIs, and doesn't validate even row distribution
transformer_engine/pytorch/__init__.py | unconditionally imports optional feature - exposes newton_schulz even when built without NVTE_WITH_CUSOLVERMP
tests/pytorch/distributed/run_newton_schulz.py | distributed test worker with incorrect verification (checks orthogonality instead of inverse square root property)

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant PyAPI as newton_schulz.py
    participant Ext as C++ Extension
    participant Ctx as CusolverMpCtx
    participant cuSolver as cuSolverMp Library
    participant NCCL as NCCL Comm

    User->>PyAPI: newton_schulz(x, group, iterations)
    PyAPI->>PyAPI: Extract NCCL comm ptr from group
    PyAPI->>PyAPI: Compute global matrix dims (m = x.size(0) * nranks)
    PyAPI->>Ext: cusolvermp_ctx_create(comm_ptr, nranks, rank)
    Ext->>Ctx: Create context (stream, events, handle, grid)
    Ctx-->>Ext: ctx_ptr
    Ext-->>PyAPI: ctx_ptr
    
    PyAPI->>Ext: newton_schulz(ctx_ptr, m, n, x, iterations, coefficients)
    Ext->>Ext: Build NVTETensor from PyTorch tensor
    Ext->>Ctx: nvte_newton_schulz(ctx, m, n, x_tensor, ...)
    
    Note over Ctx: Event-based stream sync
    Ctx->>Ctx: Record event on caller_stream
    Ctx->>Ctx: Wait for event on internal stream
    
    Ctx->>Ctx: Compute ScaLAPACK distribution (mb, local_rows)
    Ctx->>cuSolver: cusolverMpNewtonSchulz_bufferSize(...)
    cuSolver-->>Ctx: workspace_size
    Ctx->>Ctx: Allocate/grow workspace (cudaMalloc)
    
    Ctx->>cuSolver: cusolverMpNewtonSchulz(...)
    cuSolver->>NCCL: Distributed matrix operations
    NCCL->>NCCL: Inter-rank communication
    cuSolver-->>Ctx: Result (in-place in x)
    
    Note over Ctx: Event-based stream sync
    Ctx->>Ctx: Record event on internal stream
    Ctx->>Ctx: Wait for event on caller_stream
    
    Ctx-->>Ext: Success
    Ext-->>PyAPI: Success
    
    PyAPI->>Ext: cusolvermp_ctx_destroy(ctx_ptr)
    Ext->>Ctx: Destroy context (free workspace, destroy handles)
    Ctx-->>Ext: Done
    Ext-->>PyAPI: Done
    PyAPI-->>User: Result (x modified in-place)

Last reviewed commit: 8eb6028

@greptile-apps greptile-apps bot left a comment

14 files reviewed, 12 comments

Comment on lines 93 to 98
# Check: if X = A^{-1/2}, then X @ A @ X should be the identity matrix
if rank == 0:
    XXT = X @ X.t()
    I = torch.eye(N, device=XXT.device, dtype=XXT.dtype)
    max_diff = (XXT - I).abs().max().item()
    print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True)

verification doesn't match the comment - if X = A^{-1/2}, the check should be X @ A @ X ≈ I, not X @ X.t() ≈ I. The current check verifies X is orthogonal, not that X is the inverse square root of A. Note that A_orig is created on line 76 but never used.

Suggested change
- # Check: if X = A^{-1/2}, then X @ A @ X should be the identity matrix
- if rank == 0:
-     XXT = X @ X.t()
-     I = torch.eye(N, device=XXT.device, dtype=XXT.dtype)
-     max_diff = (XXT - I).abs().max().item()
-     print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True)
+ # Check: if X = A^{-1/2}, then X @ A @ X should be the identity matrix
+ XAX = X @ A_orig @ X
+ I = torch.eye(N, device=XAX.device, dtype=XAX.dtype)
+ max_diff = (XAX - I).abs().max().item()
+ print(f"Max |X @ A @ X - I|: {max_diff:.6e}", flush=True)
+ if torch.allclose(XAX, I, atol=args.atol, rtol=args.rtol):
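
To make the distinction between the two checks concrete, here is a minimal pure-Python sketch. The matrices `A`, `inv_sqrt`, and `rotation` and the helper names are hypothetical and chosen only for illustration; none of them come from the PR. It shows that a true inverse square root passes the `X @ A @ X` check but fails the orthogonality check, while an orthogonal matrix does the opposite.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*a)]


def max_abs_diff_from_identity(m):
    """Largest absolute entry of (m - I) for a square matrix m."""
    n = len(m)
    return max(abs(m[i][j] - (1.0 if i == j else 0.0))
               for i in range(n) for j in range(n))


def inverse_sqrt_residual(x, a):
    """Residual of the inverse-square-root property: max |X A X - I|."""
    return max_abs_diff_from_identity(matmul(matmul(x, a), x))


def orthogonality_residual(x):
    """Residual of the orthogonality property: max |X X^T - I|."""
    return max_abs_diff_from_identity(matmul(x, transpose(x)))


# Hypothetical example: A = diag(4, 9), so A^{-1/2} = diag(1/2, 1/3).
A = [[4.0, 0.0], [0.0, 9.0]]
inv_sqrt = [[0.5, 0.0], [0.0, 1.0 / 3.0]]
# A 90-degree rotation is orthogonal but is not an inverse square root of A.
rotation = [[0.0, -1.0], [1.0, 0.0]]

for name, x in (("inv_sqrt", inv_sqrt), ("rotation", rotation)):
    print(name, inverse_sqrt_residual(x, A), orthogonality_residual(x))
```

Whichever property the library actually computes, the test and the documentation should name the same one.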

Comment on lines +31 to +32
nccl_backend = group._get_backend(torch.device("cuda"))
return nccl_backend._comm_ptr()

uses private PyTorch APIs (_get_backend, _comm_ptr) that may change in future versions

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment on lines 39 to 58
quintic_coefficients = [
    4.0848,
    -6.8946,
    2.9270,
    3.9505,
    -6.3029,
    2.6377,
    3.7418,
    -5.5913,
    2.3037,
    2.8769,
    -3.1427,
    1.2046,
    2.8366,
    -3.0525,
    1.2012,
]
coefficients = (
    quintic_coefficients if args.num_iterations == 5 else [1.5, -0.5, 0.0] * args.num_iterations
)

coefficients mismatch with API defaults - test uses 15 coefficients for 5 iterations, but newton_schulz.py defaults to 5 coefficients. This inconsistency means default API behavior isn't tested.
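
For context on why three coefficients per iteration is a natural layout, here is a small single-process reference sketch. It assumes each triple (a, b, c) parameterizes the update X ← aX + b(XXᵀ)X + c(XXᵀ)²X and that the input is first scaled by its Frobenius norm; neither assumption is confirmed for cuSOLVERMp in this thread, and `newton_schulz_reference` is a hypothetical name, not part of the PR.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*a)]


def newton_schulz_reference(x, num_iterations, coefficients):
    """Apply num_iterations updates X <- a*X + b*(X X^T) X + c*(X X^T)^2 X.

    Assumption of this sketch: one (a, b, c) triple per iteration, and the
    input is normalized by its Frobenius norm so all singular values are <= 1.
    """
    if len(coefficients) != 3 * num_iterations:
        raise ValueError(
            f"Expected {3 * num_iterations} coefficients, got {len(coefficients)}"
        )
    fro = sum(v * v for row in x for v in row) ** 0.5
    x = [[v / fro for v in row] for row in x]
    for it in range(num_iterations):
        a, b, c = coefficients[3 * it: 3 * it + 3]
        gram = matmul(x, transpose(x))   # X X^T
        gx = matmul(gram, x)             # (X X^T) X
        ggx = matmul(gram, gx)           # (X X^T)^2 X
        x = [[a * x[i][j] + b * gx[i][j] + c * ggx[i][j]
              for j in range(len(x[0]))] for i in range(len(x))]
    return x


def orthogonality_residual(x):
    """Largest absolute entry of (X X^T - I)."""
    xxt = matmul(x, transpose(x))
    n = len(xxt)
    return max(abs(xxt[i][j] - (1.0 if i == j else 0.0))
               for i in range(n) for j in range(n))


# The fallback triple from the test script, repeated once per iteration.
num_iterations = 12
q = newton_schulz_reference([[2.0, 1.0], [0.0, 3.0]], num_iterations,
                            [1.5, -0.5, 0.0] * num_iterations)
residual = orthogonality_residual(q)
print(f"max |Q Q^T - I| = {residual:.3e}")
```

With the triple (1.5, -0.5, 0.0) this reduces to the classic cubic iteration X ← 1.5X − 0.5(XXᵀ)X, which drives the singular values toward 1. A reference like this could also be used to test whatever default the Python API ships with.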

vcherepanov-nv and others added 2 commits February 26, 2026 00:45
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment on lines 8 to 12
* \brief Functions for distributed Newton-Schulz inverse square root.
*
* This API is a TE-native binding to the cuSolverMp library.
* It computes an iterative Newton-Schulz inverse square root
* approximation on a distributed matrix.

Documentation claims this computes "inverse square root" but the test validates orthogonality (X @ X.t() ≈ I), and commit dd1dd0b states "it approximates orthogonal matrix, not inverse square root". If this computes the polar decomposition (orthogonal factor), the documentation should be updated to reflect that. Inverse square root would satisfy X @ A @ X ≈ I, which is different from orthogonality.

Comment on lines +79 to +80
m = x.size(0) * nranks # rows are distributed across ranks
n = x.size(1)

Assumes rows are evenly distributed (m = x.size(0) * nranks) but doesn't validate this. If matrix size isn't divisible by nranks, the computed global size m will be incorrect, leading to wrong results from cuSOLVERMp. Consider adding validation:

Suggested change
- m = x.size(0) * nranks # rows are distributed across ranks
- n = x.size(1)
+ # Global matrix dimensions
+ # Rows must be evenly distributed across ranks
+ local_rows = x.size(0)
+ m = local_rows * nranks
+ n = x.size(1)

Then add a validation check that all ranks have the same local_rows via dist.all_reduce.
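
One possible shape for that validation is sketched below. It takes the per-rank row counts as a plain list; in a real run that list would have to be collected with a collective (for example `torch.distributed.all_gather_object`, or the `all_reduce` approach mentioned above). The helper name `global_row_count` is hypothetical and not part of the PR.

```python
def global_row_count(local_rows_per_rank):
    """Return the global row count m, or raise if the row split is uneven.

    local_rows_per_rank: one entry per rank, each the value of x.size(0)
    on that rank (assumed to have been gathered beforehand).
    """
    if not local_rows_per_rank:
        raise ValueError("Expected at least one rank")
    first = local_rows_per_rank[0]
    if any(rows != first for rows in local_rows_per_rank):
        raise ValueError(
            f"Rows are not evenly distributed across ranks: {local_rows_per_rank}"
        )
    return first * len(local_rows_per_rank)


# Four ranks holding 128 rows each give a 512-row global matrix.
m = global_row_count([128, 128, 128, 128])
print(m)
```

Failing loudly here is cheaper than debugging a silently wrong global dimension inside the distributed solver.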

    num_iterations: int = 5,
    coefficients: Optional[List[float]] = None,
) -> None:
    """Compute Newton-Schulz inverse square root in-place on a distributed matrix.

Docstring says "inverse square root" but test checks orthogonality. Update to match actual behavior (see comment on header file).

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment on lines +136 to +138
void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x,
                        int64_t num_iterations, const float* coefficients, int64_t num_coefficients,
                        cudaStream_t caller_stream) {

num_coefficients parameter is unused

The num_coefficients parameter is accepted but never referenced in the function body. Neither cusolverMpNewtonSchulz_bufferSize (line 166) nor cusolverMpNewtonSchulz (line 183) receive this value. If cuSolverMp infers the count from num_iterations internally, then num_coefficients is dead code that should be removed from the API. If cuSolverMp actually needs it, then it should be passed to the cuSolverMp calls — otherwise the library may read out of bounds on the coefficients array.

from transformer_engine.pytorch import optimizers
from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
from transformer_engine.pytorch.newton_schulz import newton_schulz

Unconditional import of optional feature

newton_schulz is unconditionally imported and exported as part of the public API, even when TE is built without NVTE_WITH_CUSOLVERMP. While the function itself raises a runtime error when called, this exposes the symbol to all users and makes it appear as a supported feature in auto-complete and docs. Consider guarding this import behind a check (similar to how other optional features are handled), or at minimum adding a note in the docstring that the function requires NVTE_WITH_CUSOLVERMP=1 at build time.
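
A guard along these lines might look like the sketch below. It assumes availability can be detected by whether a module attribute resolves, which may not match how TE detects build flags; `try_import_attr` and the module names used are hypothetical stand-ins, not TE code.

```python
import importlib


def try_import_attr(module_name, attr_name):
    """Return module.attr_name, or None if the module or attribute is missing."""
    try:
        return getattr(importlib.import_module(module_name), attr_name)
    except (ImportError, AttributeError):
        return None


# Stand-ins: "math.sqrt" plays the role of a feature that was built,
# "no_such_module_xyz.newton_schulz" plays the role of one that was not.
candidates = {
    "sqrt": try_import_attr("math", "sqrt"),
    "newton_schulz": try_import_attr("no_such_module_xyz", "newton_schulz"),
}

# Export only the names that actually resolved.
exported = [name for name, obj in candidates.items() if obj is not None]
print(exported)
```

The trade-off is that users of a build without the feature get an `ImportError` or `AttributeError` at import time instead of a descriptive runtime error at call time.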

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +170 to +177
// Allocate/grow device workspace
if (ctx->workspace_size < wrksp_size_device) {
  if (ctx->workspace) {
    NVTE_CHECK_CUDA(cudaFree(ctx->workspace));
  }
  NVTE_CHECK_CUDA(cudaMalloc(&ctx->workspace, wrksp_size_device));
  ctx->workspace_size = wrksp_size_device;
}

Synchronous cudaFree/cudaMalloc on hot path

cudaFree followed by cudaMalloc inside nvte_newton_schulz will synchronize with the device each time the workspace needs to grow. Since the context is recreated on every call from newton_schulz.py (lines 82-86 create and destroy ctx on each invocation), the workspace will never be reused across calls, so the grow-only caching here is ineffective. Consider either:

  1. Allowing callers to keep the context alive across calls, or
  2. Using cudaMallocAsync/cudaFreeAsync on ctx->stream to avoid synchronous stalls.

Comment on lines +82 to +86
ctx_ptr = tex.cusolvermp_ctx_create(nccl_comm_ptr, nranks, rank)
try:
    tex.newton_schulz(ctx_ptr, m, n, x, num_iterations, coefficients)
finally:
    tex.cusolvermp_ctx_destroy(ctx_ptr)

Context created/destroyed per call wastes resources

A new NVTECusolverMpCtx is created and destroyed on every invocation of newton_schulz. Context creation involves cudaStreamCreate, two cudaEventCreate calls, cusolverMpCreate, and cusolverMpCreateDeviceGrid — all of which are heavyweight operations. And since the context is destroyed afterward, the grow-only workspace caching in the C++ layer (lines 170-177 of newton_schulz.cpp) is never actually reused.

Consider caching the context (e.g., in a module-level dict keyed by (nccl_comm_ptr, nranks, rank)) and reusing it across calls, or exposing the context lifecycle to callers so they can amortize the cost when calling newton_schulz repeatedly in a training loop.
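
A minimal sketch of the caching option follows. `ContextCache` is a hypothetical helper, and the `create` and `destroy` callables are placeholders; in the PR they would presumably be `tex.cusolvermp_ctx_create` and `tex.cusolvermp_ctx_destroy`, while the demo below uses fakes so it runs without a GPU.

```python
class ContextCache:
    """Keep one context per (comm_ptr, nranks, rank) key and reuse it."""

    def __init__(self, create, destroy):
        self._create = create      # callable(comm_ptr, nranks, rank) -> ctx
        self._destroy = destroy    # callable(ctx) -> None
        self._contexts = {}

    def get(self, comm_ptr, nranks, rank):
        """Return the cached context for this key, creating it on first use."""
        key = (comm_ptr, nranks, rank)
        if key not in self._contexts:
            self._contexts[key] = self._create(comm_ptr, nranks, rank)
        return self._contexts[key]

    def clear(self):
        """Destroy every cached context, e.g. at process-group teardown."""
        while self._contexts:
            _, ctx = self._contexts.popitem()
            self._destroy(ctx)


# Fake create/destroy functions that only record what was called.
created, destroyed = [], []


def fake_create(comm_ptr, nranks, rank):
    created.append((comm_ptr, nranks, rank))
    return {"key": (comm_ptr, nranks, rank)}


cache = ContextCache(fake_create, destroyed.append)
first = cache.get(0xABC, 4, 0)
second = cache.get(0xABC, 4, 0)   # same key: reuses the existing context
cache.clear()
print(len(created), len(destroyed))
```

One caveat with keying on the raw communicator pointer: if a process group is destroyed and a new communicator is later allocated at the same address, a stale context could be returned, so the cache would need to be cleared when the group goes away.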

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment on lines +65 to +67
assert (
    len(coefficients) == num_iterations * 3
), f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations"

use ValueError instead of assert for validation - assert can be disabled with Python's -O flag

Suggested change
- assert (
-     len(coefficients) == num_iterations * 3
- ), f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations"
+ if len(coefficients) != num_iterations * 3:
+     raise ValueError(
+         f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations"
+     )

Comment on lines +69 to +72
if x.dim() != 2:
    raise ValueError(f"Expected 2D tensor, got {x.dim()}D")
if not x.is_cuda:
    raise ValueError("Input tensor must be on CUDA device")

missing contiguity check - C++ code uses data_ptr() which requires contiguous memory. Non-contiguous tensors will cause incorrect results.

Suggested change
- if x.dim() != 2:
-     raise ValueError(f"Expected 2D tensor, got {x.dim()}D")
- if not x.is_cuda:
-     raise ValueError("Input tensor must be on CUDA device")
+ if x.dim() != 2:
+     raise ValueError(f"Expected 2D tensor, got {x.dim()}D")
+ if not x.is_cuda:
+     raise ValueError("Input tensor must be on CUDA device")
+ if not x.is_contiguous():
+     raise ValueError("Input tensor must be contiguous")
