Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
a536ff4
[Common] Add Newton-Schulz inverse square root C API via cuSolverMp
vcherepanov-nv Feb 8, 2026
e8fea44
[PyTorch] Add Newton-Schulz PyTorch bindings and distributed tests
vcherepanov-nv Feb 8, 2026
2e5d826
[Common] Fix cuSolverMp API signatures in Newton-Schulz implementation
vcherepanov-nv Feb 8, 2026
48f549b
[PyTorch] Propagate NVTE_WITH_CUSOLVERMP define to PyTorch extension …
vcherepanov-nv Feb 8, 2026
c154d98
[PyTorch] Fix NCCL comm extraction and pass global dims to Newton-Schulz
vcherepanov-nv Feb 9, 2026
2f62321
[Common] Cache cuSolverMp handle and grid in Newton-Schulz context
vcherepanov-nv Feb 18, 2026
e4a9999
[Common] Create dedicated CUDA stream in Newton-Schulz context
vcherepanov-nv Feb 18, 2026
0cf4327
[Common] Fix Newton-Schulz zero output with event-based stream sync
vcherepanov-nv Feb 18, 2026
f24dd8f
[Common] Fix Newton-Schulz NaNs by keeping host workspace alive
vcherepanov-nv Feb 18, 2026
3badc16
[Common] Cache CUDA event in Newton-Schulz context
vcherepanov-nv Feb 18, 2026
8a11b4e
[Common] Use separate in/out events for Newton-Schulz stream sync
vcherepanov-nv Feb 18, 2026
5c5d206
Correct coefficients
vcherepanov-nv Feb 18, 2026
b0b1367
No stream synchronize
vcherepanov-nv Feb 18, 2026
7cfc57c
[Test] Verify Newton-Schulz result with XAX=I identity check
vcherepanov-nv Feb 18, 2026
f64d8f6
Change test - it approximates orthogonal matrix, not inverse square root
vcherepanov-nv Feb 19, 2026
c634d61
Generalize number of iterations in tests
vcherepanov-nv Feb 19, 2026
de423aa
Remove extra info diag - everything should be in logs
vcherepanov-nv Feb 25, 2026
6d3a4dc
Add Newton-Schulz tests to the QA script
vcherepanov-nv Feb 25, 2026
e424057
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 25, 2026
e5ca4b3
Fix outdated comments
vcherepanov-nv Feb 25, 2026
f86f8bb
Remove unused variable
vcherepanov-nv Feb 25, 2026
9d503e0
Move magic numbers from tests to impl
vcherepanov-nv Feb 26, 2026
89c5594
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2026
879fd38
Fix outdated comments
vcherepanov-nv Feb 26, 2026
9a7386b
Check num_coefficients
vcherepanov-nv Feb 26, 2026
ff78aa3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2026
823a2f5
Auto-detect cuSolverMp support from common library binary
vcherepanov-nv Feb 27, 2026
257cc43
Conditionally exclude Newton-Schulz API from PyTorch extension
vcherepanov-nv Feb 27, 2026
274c06d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2026
a1026fb
Make symbol detection errors fatal in common_lib_has_symbol
vcherepanov-nv Feb 27, 2026
295504e
Search for libtransformer_engine.so via installed module location first
vcherepanov-nv Mar 4, 2026
b9c6bc8
Add site packages to search paths for TE common
vcherepanov-nv Mar 4, 2026
8de97d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2026
9811950
Revert "Auto-detect cuSolverMp support from common library binary"
vcherepanov-nv Mar 4, 2026
8cadb0d
Remove unused import
vcherepanov-nv Mar 4, 2026
4913a9d
Fix incorrect 'inverse square root' references in Newton-Schulz comments
vcherepanov-nv Mar 5, 2026
a9411e1
[PyTorch] Expose cuSolverMp context creation/destruction as public API
vcherepanov-nv Mar 5, 2026
c825455
[PyTorch] Strengthen input validation in newton_schulz
vcherepanov-nv Mar 5, 2026
842ed71
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
dd3c2e4
Use ncclMemAlloc for cuSolverMp Newton-Schulz workspace
vcherepanov-nv Apr 1, 2026
5015e58
Add Newton-Schulz reference tests
vcherepanov-nv Apr 1, 2026
3acab29
Fix Newton-Schulz reference test logic
vcherepanov-nv Apr 1, 2026
a389e14
Fix column-major usage of cuSOLVERMp; add rectangular test cases
vcherepanov-nv Apr 2, 2026
61cff6f
Avoid explicit transpose
vcherepanov-nv Apr 3, 2026
960dd0f
Cleanup
vcherepanov-nv Apr 3, 2026
e2576a7
More cleanup
vcherepanov-nv Apr 3, 2026
33bb8fd
Cleanup
vcherepanov-nv Apr 3, 2026
7e53e11
Update transformer_engine/common/newton_schulz/newton_schulz.cpp
vcherepanov-nv Apr 3, 2026
da9dea3
Fix syntax
vcherepanov-nv Apr 3, 2026
70d2ea8
Apply suggestions from code review
vcherepanov-nv Apr 6, 2026
fc47fc0
Add timeout
vcherepanov-nv Apr 6, 2026
eca8616
Use RAII for cusolvermp CUDA resources
vcherepanov-nv Apr 6, 2026
d4d3c93
Make NS API declared unconditional, with stub / runtime errors withou…
vcherepanov-nv Apr 6, 2026
2b8d56a
Merge branch 'main' into newton-schulz
vcherepanov-nv Apr 6, 2026
bfe7484
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 6, 2026
ce0c44b
Fix index in diag
vcherepanov-nv Apr 6, 2026
739fd08
CMake fixes
vcherepanov-nv Apr 6, 2026
c99e42c
Update transformer_engine/pytorch/newton_schulz.py
vcherepanov-nv Apr 13, 2026
1ee7dd8
Fix a typo
vcherepanov-nv Apr 13, 2026
ae4f539
Cleanup context management
vcherepanov-nv Apr 13, 2026
72335af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2026
8e14fa7
Borrow more coefficient sets from Emerging Optimizers
vcherepanov-nv Apr 13, 2026
f48bbfc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2026
5d6cc7b
Couple num_iterations with coeff types in tests
vcherepanov-nv Apr 15, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py"


# debug tests
Expand Down
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def setup_common_extension() -> CMakeExtension:
).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")

if bool(int(os.getenv("NVTE_WITH_CUSOLVERMP", "0"))):
cmake_flags.append("-DNVTE_WITH_CUSOLVERMP=ON")
cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr")
cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}")

# Add custom CMake arguments from environment variable
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
if nvte_cmake_extra_args:
Expand Down
127 changes: 127 additions & 0 deletions tests/pytorch/distributed/run_newton_schulz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Distributed Newton-Schulz test worker.

Launched via torchrun from test_newton_schulz.py.
"""

import argparse
import sys

import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record

from transformer_engine.pytorch.newton_schulz import (
CusolverMpCtx,
get_coefficients,
newton_schulz,
)


def newton_schulz_reference(in_x: torch.Tensor, coefficients: list[float]) -> torch.Tensor:
"""Local Newton-Schulz reference mirroring the provided Octave update."""
x = in_x.clone()
for i in range(len(coefficients) // 3):
a, b, c = coefficients[3 * i : 3 * (i + 1)]
xxt = x @ x.mT
x = a * x + b * xxt @ x + c * xxt @ xxt @ x
return x


@record
def main():
parser = argparse.ArgumentParser(description="Newton-Schulz distributed test")
parser.add_argument(
"--check", type=str, default="orthogonality", choices=["orthogonality", "reference"]
)
parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"])
parser.add_argument("--matrix-rows", type=int, default=256)
parser.add_argument("--matrix-cols", type=int, default=None)
parser.add_argument("--num-iterations", type=int, default=5)
parser.add_argument("--coeff-type", type=str, default="quintic")
parser.add_argument("--atol", type=float, default=1e-2)
parser.add_argument("--rtol", type=float, default=1e-2)
args = parser.parse_args()

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)

dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
m = args.matrix_rows
n = args.matrix_cols if args.matrix_cols is not None else args.matrix_rows
coefficients = get_coefficients(args.num_iterations, args.coeff_type)

# Ensure the distributed column dimension is divisible by world_size.
assert n % world_size == 0, f"Matrix columns {n} must be divisible by world_size {world_size}"

# Create a random matrix on rank 0 with singular values in (0, 1),
# which keeps the Newton-Schulz iterations in the convergence regime.
if rank == 0:
torch.manual_seed(42)
k = min(m, n)
U, _ = torch.linalg.qr(
torch.randn(m, k, device="cuda", dtype=torch.float32), mode="reduced"
)
V, _ = torch.linalg.qr(
torch.randn(n, k, device="cuda", dtype=torch.float32), mode="reduced"
)
singular_values = torch.rand(k, device="cuda", dtype=torch.float32) * 0.8 + 0.1
A = U @ torch.diag(singular_values) @ V.T
A = A.to(dtype)
else:
A = torch.empty(m, n, device="cuda", dtype=dtype)

# Broadcast the full matrix to all ranks
dist.broadcast(A, src=0)

# Scatter columns to each rank
local_cols = n // world_size
x_local = A[:, rank * local_cols : (rank + 1) * local_cols].contiguous()

ctx = CusolverMpCtx(dist.group.WORLD)
try:
newton_schulz(x_local, ctx, args.num_iterations, coefficients=coefficients)
finally:
ctx.destroy()

# Gather results
gathered = [torch.empty_like(x_local) for _ in range(world_size)]
dist.all_gather(gathered, x_local)
X = torch.cat(gathered, dim=1)

# Check: the resulting matrix should be orthogonal, or match a local reference.
if rank == 0:
if args.check == "orthogonality":
if m <= n:
gram = X @ X.t()
expected = torch.eye(m, device=gram.device, dtype=gram.dtype)
max_diff = (gram - expected).abs().max().item()
print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True)
else:
gram = X.t() @ X
expected = torch.eye(n, device=gram.device, dtype=gram.dtype)
max_diff = (gram - expected).abs().max().item()
print(f"Max |X.t() @ X - I|: {max_diff:.6e}", flush=True)
passed = torch.allclose(gram, expected, atol=args.atol, rtol=args.rtol)
else:
reference = newton_schulz_reference(A.float(), coefficients).to(dtype)
max_diff = (X - reference).abs().max().item()
print(f"Max |distributed - reference|: {max_diff:.6e}", flush=True)
passed = torch.allclose(X, reference, atol=args.atol, rtol=args.rtol)

if passed:
print("NUMERICAL CHECK PASSED", flush=True)
else:
print("NUMERICAL CHECK FAILED", flush=True, file=sys.stderr)
sys.exit(1)

dist.destroy_process_group()


if __name__ == "__main__":
main()
69 changes: 69 additions & 0 deletions tests/pytorch/distributed/test_newton_schulz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Tests for distributed Newton-Schulz matrix orthogonalization."""

import os
import subprocess
from pathlib import Path

import pytest
import torch

if torch.cuda.device_count() < 2:
pytest.skip("Newton-Schulz tests require at least 2 GPUs.", allow_module_level=True)

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS = torch.cuda.device_count()
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
ORTHOGONALITY_SHAPES = [
(NUM_PROCS * 64, NUM_PROCS * 64),
(NUM_PROCS * 64, NUM_PROCS * 96),
(NUM_PROCS * 96, NUM_PROCS * 64),
]
REFERENCE_SHAPES = [(NUM_PROCS * 64, NUM_PROCS * 64)]


def _run_test(dtype, matrix_shape, num_iterations, coeff_type, check):
rows, cols = matrix_shape
test_path = TEST_ROOT / "run_newton_schulz.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
f"--check={check}",
f"--dtype={dtype}",
f"--matrix-rows={rows}",
f"--matrix-cols={cols}",
f"--num-iterations={num_iterations}",
f"--coeff-type={coeff_type}",
]
if dtype == "bfloat16":
test_cmd += ["--atol=5e-2", "--rtol=5e-2"]

result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False, timeout=300)
if (
result.returncode != 0
or "NUMERICAL CHECK FAILED" in result.stderr.decode()
or "NUMERICAL CHECK PASSED" not in result.stdout.decode()
):
raise AssertionError(
"Newton-Schulz test failed.\n"
f"stdout: {result.stdout.decode()}\n"
f"stderr: {result.stderr.decode()}"
)


@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
@pytest.mark.parametrize("matrix_shape", ORTHOGONALITY_SHAPES)
@pytest.mark.parametrize("num_iterations,coeff_type", [(5, "quintic"), (8, "polar_express")])
def test_orthogonality(dtype, matrix_shape, num_iterations, coeff_type):
"""Test distributed Newton-Schulz orthogonality."""
_run_test(dtype, matrix_shape, num_iterations, coeff_type, "orthogonality")


@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
@pytest.mark.parametrize("matrix_shape", REFERENCE_SHAPES)
@pytest.mark.parametrize("num_iterations,coeff_type", [(5, "quintic"), (8, "polar_express")])
def test_against_reference(dtype, matrix_shape, num_iterations, coeff_type):
"""Test distributed Newton-Schulz against a local reference implementation."""
_run_test(dtype, matrix_shape, num_iterations, coeff_type, "reference")
21 changes: 20 additions & 1 deletion transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ list(APPEND transformer_engine_cpp_sources
util/rtc.cpp
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/comm_gemm_overlap.cpp)
comm_gemm_overlap/comm_gemm_overlap.cpp
newton_schulz/newton_schulz.cpp
)

list(APPEND transformer_engine_cuda_sources
common.cu
Expand Down Expand Up @@ -303,6 +305,23 @@ if (NVTE_WITH_CUBLASMP)
message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
endif()

option(NVTE_WITH_CUSOLVERMP "Use cuSolverMp for distributed Newton-Schulz" OFF)
if (NVTE_WITH_CUSOLVERMP)
target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUSOLVERMP)
target_include_directories(transformer_engine PRIVATE ${CUSOLVERMP_DIR}/include)
find_library(CUSOLVERMP_LIB
NAMES cusolverMp libcusolverMp
PATHS ${CUSOLVERMP_DIR}
PATH_SUFFIXES lib
REQUIRED)
find_library(NCCL_LIB
NAMES nccl libnccl
PATH_SUFFIXES lib
REQUIRED)
target_link_libraries(transformer_engine PRIVATE ${NCCL_LIB} ${CUSOLVERMP_LIB})
message(STATUS "Using cuSolverMp at: ${CUSOLVERMP_DIR}")
endif()

# Number of philox4x32 rounds for stochastic rounding (build-time constant).
set(NVTE_BUILD_NUM_PHILOX_ROUNDS_STR $ENV{NVTE_BUILD_NUM_PHILOX_ROUNDS})
if (NOT NVTE_BUILD_NUM_PHILOX_ROUNDS_STR)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

/*! \file newton_schulz.h
* \brief Functions for distributed Newton-Schulz matrix orthogonalization.
*
* This API is a TE-native binding to the cuSolverMp library.
* It computes an iterative Newton-Schulz matrix orthogonalization on a distributed matrix.
*/

#ifndef TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_
#define TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_

#include <nccl.h>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unconditional #include <nccl.h> in a public header.

newton_schulz.h is installed as a public header (under include/transformer_engine/). The unconditional #include <nccl.h> means that any downstream project that includes this header — even one with no interest in Newton-Schulz — now requires NCCL in its include path.

ncclComm_t is only used in the function signatures of nvte_cusolvermp_ctx_create and nvte_newton_schulz, which are themselves only meaningful when NVTE_WITH_CUSOLVERMP is defined. Guarding the include and the declarations together would prevent the leakage:

Suggested change
#include <nccl.h>
#ifdef NVTE_WITH_CUSOLVERMP
#include <nccl.h>
// ... struct and function declarations ...
#endif // NVTE_WITH_CUSOLVERMP

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also a little uneasy at exposing NCCL as a required dependency, but I see we already import the NCCL header elsewhere in the codebase:

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use NCCL types (ncclComm_t) in nvte_ctx_create API, so if we're taking the route of always defining the functions - we need to pull NCCL headers unconditionally.

#include <stdint.h>

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

typedef struct NVTECusolverMpCtx NVTECusolverMpCtx;

/*! \brief Create a cuSolverMp context for Newton-Schulz operations.
*
* Creates a dedicated CUDA stream internally (cuSolverMp requires a
* non-default stream).
*
* \param[in] comm NCCL communicator.
* \param[in] nranks Number of ranks.
* \param[in] rank Local rank.
*/
NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank);

/*! \brief Destroy a cuSolverMp context.
*
* \param[in] ctx Context to destroy.
*/
void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx);

/*! \brief Compute Newton-Schulz matrix orthogonalization in-place.
*
* \param[in] ctx cuSolverMp context.
* \param[in] m Global number of rows.
* \param[in] n Global number of columns.
* \param[in,out] x Local part of the matrix (modified in-place).
* \param[in] num_iterations Number of Newton-Schulz iterations.
* \param[in] coefficients Array of polynomial coefficients (length depends on polynomial
* degree used internally by cuSolverMp).
* \param[in] num_coefficients Number of elements in the coefficients array.
* \param[in] caller_stream CUDA stream on which the caller produced the input tensor.
* Used for event-based synchronisation with the internal stream.
*/
void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x,
int64_t num_iterations, const float* coefficients, int64_t num_coefficients,
cudaStream_t caller_stream);

#ifdef __cplusplus
} // extern "C"
#endif

#endif // TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_
Loading
Loading