Newton-Schulz via cuSOLVERMp #2706
Merged
64 commits
a536ff4 [Common] Add Newton-Schulz inverse square root C API via cuSolverMp (vcherepanov-nv)
e8fea44 [PyTorch] Add Newton-Schulz PyTorch bindings and distributed tests (vcherepanov-nv)
2e5d826 [Common] Fix cuSolverMp API signatures in Newton-Schulz implementation (vcherepanov-nv)
48f549b [PyTorch] Propagate NVTE_WITH_CUSOLVERMP define to PyTorch extension … (vcherepanov-nv)
c154d98 [PyTorch] Fix NCCL comm extraction and pass global dims to Newton-Schulz (vcherepanov-nv)
2f62321 [Common] Cache cuSolverMp handle and grid in Newton-Schulz context (vcherepanov-nv)
e4a9999 [Common] Create dedicated CUDA stream in Newton-Schulz context (vcherepanov-nv)
0cf4327 [Common] Fix Newton-Schulz zero output with event-based stream sync (vcherepanov-nv)
f24dd8f [Common] Fix Newton-Schulz NaNs by keeping host workspace alive (vcherepanov-nv)
3badc16 [Common] Cache CUDA event in Newton-Schulz context (vcherepanov-nv)
8a11b4e [Common] Use separate in/out events for Newton-Schulz stream sync (vcherepanov-nv)
5c5d206 Correct coefficients (vcherepanov-nv)
b0b1367 No stream synchronize (vcherepanov-nv)
7cfc57c [Test] Verify Newton-Schulz result with XAX=I identity check (vcherepanov-nv)
f64d8f6 Change test - it approximates orthogonal matrix, not inverse square root (vcherepanov-nv)
c634d61 Generalize number of iterations in tests (vcherepanov-nv)
de423aa Remove extra info diag - everything should be in logs (vcherepanov-nv)
6d3a4dc Add Newton-Schulz tests to the QA script (vcherepanov-nv)
e424057 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
e5ca4b3 Fix outdated comments (vcherepanov-nv)
f86f8bb Remove unused variable (vcherepanov-nv)
9d503e0 Move magic numbers from tests to impl (vcherepanov-nv)
89c5594 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
879fd38 Fix outdated comments (vcherepanov-nv)
9a7386b Check num_coefficients (vcherepanov-nv)
ff78aa3 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
823a2f5 Auto-detect cuSolverMp support from common library binary (vcherepanov-nv)
257cc43 Conditionally exclude Newton-Schulz API from PyTorch extension (vcherepanov-nv)
274c06d [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
a1026fb Make symbol detection errors fatal in common_lib_has_symbol (vcherepanov-nv)
295504e Search for libtransformer_engine.so via installed module location first (vcherepanov-nv)
b9c6bc8 Add site packages to search paths for TE common (vcherepanov-nv)
8de97d5 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
9811950 Revert "Auto-detect cuSolverMp support from common library binary" (vcherepanov-nv)
8cadb0d Remove unused import (vcherepanov-nv)
4913a9d Fix incorrect 'inverse square root' references in Newton-Schulz comments (vcherepanov-nv)
a9411e1 [PyTorch] Expose cuSolverMp context creation/destruction as public API (vcherepanov-nv)
c825455 [PyTorch] Strengthen input validation in newton_schulz (vcherepanov-nv)
842ed71 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
dd3c2e4 Use ncclMemAlloc for cuSolverMp Newton-Schulz workspace (vcherepanov-nv)
5015e58 Add Newton-Schulz reference tests (vcherepanov-nv)
3acab29 Fix Newton-Schulz reference test logic (vcherepanov-nv)
a389e14 Fix column-major usage of cuSOLVERMp; add rectangular test cases (vcherepanov-nv)
61cff6f Avoid explicit transpose (vcherepanov-nv)
960dd0f Cleanup (vcherepanov-nv)
e2576a7 More cleanup (vcherepanov-nv)
33bb8fd Cleanup (vcherepanov-nv)
7e53e11 Update transformer_engine/common/newton_schulz/newton_schulz.cpp (vcherepanov-nv)
da9dea3 Fix syntax (vcherepanov-nv)
70d2ea8 Apply suggestions from code review (vcherepanov-nv)
fc47fc0 Add timeout (vcherepanov-nv)
eca8616 Use RAII for cusolvermp CUDA resources (vcherepanov-nv)
d4d3c93 Make NS API declared unconditional, with stub / runtime errors withou… (vcherepanov-nv)
2b8d56a Merge branch 'main' into newton-schulz (vcherepanov-nv)
bfe7484 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
ce0c44b Fix index in diag (vcherepanov-nv)
739fd08 CMake fixes (vcherepanov-nv)
c99e42c Update transformer_engine/pytorch/newton_schulz.py (vcherepanov-nv)
1ee7dd8 Fix a typo (vcherepanov-nv)
ae4f539 Cleanup context management (vcherepanov-nv)
72335af [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
8e14fa7 Borrow more coefficient sets from Emerging Optimizers (vcherepanov-nv)
f48bbfc [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
5d6cc7b Couple num_iterations with coeff types in tests (vcherepanov-nv)
@@ -0,0 +1,127 @@
```python
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Distributed Newton-Schulz test worker.

Launched via torchrun from test_newton_schulz.py.
"""

import argparse
import sys

import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record

from transformer_engine.pytorch.newton_schulz import (
    CusolverMpCtx,
    get_coefficients,
    newton_schulz,
)


def newton_schulz_reference(in_x: torch.Tensor, coefficients: list[float]) -> torch.Tensor:
    """Local Newton-Schulz reference mirroring the provided Octave update."""
    x = in_x.clone()
    for i in range(len(coefficients) // 3):
        a, b, c = coefficients[3 * i : 3 * (i + 1)]
        xxt = x @ x.mT
        x = a * x + b * xxt @ x + c * xxt @ xxt @ x
    return x


@record
def main():
    parser = argparse.ArgumentParser(description="Newton-Schulz distributed test")
    parser.add_argument(
        "--check", type=str, default="orthogonality", choices=["orthogonality", "reference"]
    )
    parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"])
    parser.add_argument("--matrix-rows", type=int, default=256)
    parser.add_argument("--matrix-cols", type=int, default=None)
    parser.add_argument("--num-iterations", type=int, default=5)
    parser.add_argument("--coeff-type", type=str, default="quintic")
    parser.add_argument("--atol", type=float, default=1e-2)
    parser.add_argument("--rtol", type=float, default=1e-2)
    args = parser.parse_args()

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
    m = args.matrix_rows
    n = args.matrix_cols if args.matrix_cols is not None else args.matrix_rows
    coefficients = get_coefficients(args.num_iterations, args.coeff_type)

    # Ensure the distributed column dimension is divisible by world_size.
    assert n % world_size == 0, f"Matrix columns {n} must be divisible by world_size {world_size}"

    # Create a random matrix on rank 0 with singular values in (0, 1),
    # which keeps the Newton-Schulz iterations in the convergence regime.
    if rank == 0:
        torch.manual_seed(42)
        k = min(m, n)
        U, _ = torch.linalg.qr(
            torch.randn(m, k, device="cuda", dtype=torch.float32), mode="reduced"
        )
        V, _ = torch.linalg.qr(
            torch.randn(n, k, device="cuda", dtype=torch.float32), mode="reduced"
        )
        singular_values = torch.rand(k, device="cuda", dtype=torch.float32) * 0.8 + 0.1
        A = U @ torch.diag(singular_values) @ V.T
        A = A.to(dtype)
    else:
        A = torch.empty(m, n, device="cuda", dtype=dtype)

    # Broadcast the full matrix to all ranks
    dist.broadcast(A, src=0)

    # Scatter columns to each rank
    local_cols = n // world_size
    x_local = A[:, rank * local_cols : (rank + 1) * local_cols].contiguous()

    ctx = CusolverMpCtx(dist.group.WORLD)
    try:
        newton_schulz(x_local, ctx, args.num_iterations, coefficients=coefficients)
    finally:
        ctx.destroy()

    # Gather results
    gathered = [torch.empty_like(x_local) for _ in range(world_size)]
    dist.all_gather(gathered, x_local)
    X = torch.cat(gathered, dim=1)

    # Check: the resulting matrix should be orthogonal, or match a local reference.
    if rank == 0:
        if args.check == "orthogonality":
            if m <= n:
                gram = X @ X.t()
                expected = torch.eye(m, device=gram.device, dtype=gram.dtype)
                max_diff = (gram - expected).abs().max().item()
                print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True)
            else:
                gram = X.t() @ X
                expected = torch.eye(n, device=gram.device, dtype=gram.dtype)
                max_diff = (gram - expected).abs().max().item()
                print(f"Max |X.t() @ X - I|: {max_diff:.6e}", flush=True)
            passed = torch.allclose(gram, expected, atol=args.atol, rtol=args.rtol)
        else:
            reference = newton_schulz_reference(A.float(), coefficients).to(dtype)
            max_diff = (X - reference).abs().max().item()
            print(f"Max |distributed - reference|: {max_diff:.6e}", flush=True)
            passed = torch.allclose(X, reference, atol=args.atol, rtol=args.rtol)

        if passed:
            print("NUMERICAL CHECK PASSED", flush=True)
        else:
            print("NUMERICAL CHECK FAILED", flush=True, file=sys.stderr)
            sys.exit(1)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
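The worker's flow (build a matrix with known singular values, iterate, check that the result is close to orthogonal) can be tried on a single process without GPUs or cuSolverMp. The following is a minimal NumPy sketch under stated assumptions: it uses the classical cubic Newton-Schulz coefficients (1.5, -0.5, 0.0) for every iteration, not the values returned by `get_coefficients` (which are not shown in this diff), and the helper names `reference_iteration`, `make_test_matrix`, and `orthogonality_error` are hypothetical.

```python
import numpy as np


def reference_iteration(x, coefficients):
    """Apply x <- a*x + b*(x x^T) x + c*(x x^T)^2 x once per (a, b, c) triple."""
    x = x.copy()
    for i in range(len(coefficients) // 3):
        a, b, c = coefficients[3 * i : 3 * (i + 1)]
        xxt = x @ x.T
        x = a * x + b * (xxt @ x) + c * (xxt @ xxt @ x)
    return x


def make_test_matrix(m, n, rng):
    """Random m x n matrix whose singular values lie in [0.1, 0.9)."""
    k = min(m, n)
    u, _ = np.linalg.qr(rng.standard_normal((m, k)))  # orthonormal columns
    v, _ = np.linalg.qr(rng.standard_normal((n, k)))  # orthonormal columns
    s = rng.uniform(0.1, 0.9, size=k)
    return (u * s) @ v.T  # equals U @ diag(s) @ V^T


def orthogonality_error(x):
    """Max |G - I|, where G is the smaller of the two Gram matrices of x."""
    m, n = x.shape
    gram = x @ x.T if m <= n else x.T @ x
    return float(np.abs(gram - np.eye(min(m, n))).max())


rng = np.random.default_rng(42)

# Assumption: classical cubic coefficients, repeated for each iteration.
num_iterations = 12
coefficients = [1.5, -0.5, 0.0] * num_iterations

for shape in [(64, 64), (64, 96), (96, 64)]:
    a = make_test_matrix(*shape, rng)
    x = reference_iteration(a, coefficients)
    print(shape, f"before={orthogonality_error(a):.3e}", f"after={orthogonality_error(x):.3e}")
```

The square, wide, and tall shapes correspond to the three kinds of cases in the orthogonality test; for a tall matrix only `X.T @ X` can approach the identity, which is why the check switches Gram matrices on `m <= n`.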
@@ -0,0 +1,69 @@
```python
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Tests for distributed Newton-Schulz matrix orthogonalization."""

import os
import subprocess
from pathlib import Path

import pytest
import torch

if torch.cuda.device_count() < 2:
    pytest.skip("Newton-Schulz tests require at least 2 GPUs.", allow_module_level=True)

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS = torch.cuda.device_count()
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
ORTHOGONALITY_SHAPES = [
    (NUM_PROCS * 64, NUM_PROCS * 64),
    (NUM_PROCS * 64, NUM_PROCS * 96),
    (NUM_PROCS * 96, NUM_PROCS * 64),
]
REFERENCE_SHAPES = [(NUM_PROCS * 64, NUM_PROCS * 64)]


def _run_test(dtype, matrix_shape, num_iterations, coeff_type, check):
    rows, cols = matrix_shape
    test_path = TEST_ROOT / "run_newton_schulz.py"
    test_cmd = LAUNCH_CMD + [
        str(test_path),
        f"--check={check}",
        f"--dtype={dtype}",
        f"--matrix-rows={rows}",
        f"--matrix-cols={cols}",
        f"--num-iterations={num_iterations}",
        f"--coeff-type={coeff_type}",
    ]
    if dtype == "bfloat16":
        test_cmd += ["--atol=5e-2", "--rtol=5e-2"]

    result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False, timeout=300)
    if (
        result.returncode != 0
        or "NUMERICAL CHECK FAILED" in result.stderr.decode()
        or "NUMERICAL CHECK PASSED" not in result.stdout.decode()
    ):
        raise AssertionError(
            "Newton-Schulz test failed.\n"
            f"stdout: {result.stdout.decode()}\n"
            f"stderr: {result.stderr.decode()}"
        )


@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
@pytest.mark.parametrize("matrix_shape", ORTHOGONALITY_SHAPES)
@pytest.mark.parametrize("num_iterations,coeff_type", [(5, "quintic"), (8, "polar_express")])
def test_orthogonality(dtype, matrix_shape, num_iterations, coeff_type):
    """Test distributed Newton-Schulz orthogonality."""
    _run_test(dtype, matrix_shape, num_iterations, coeff_type, "orthogonality")


@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
@pytest.mark.parametrize("matrix_shape", REFERENCE_SHAPES)
@pytest.mark.parametrize("num_iterations,coeff_type", [(5, "quintic"), (8, "polar_express")])
def test_against_reference(dtype, matrix_shape, num_iterations, coeff_type):
    """Test distributed Newton-Schulz against a local reference implementation."""
    _run_test(dtype, matrix_shape, num_iterations, coeff_type, "reference")
```
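The pass/fail decision in `_run_test` rests on three signals from the child process: the exit code, a failure marker on stderr, and a success marker on stdout. A small stdlib-only sketch of that pattern follows; it swaps the `torchrun` worker for stand-in child processes started with `sys.executable -c`, and the names `run_and_check`, `ok_cmd`, `bad_cmd`, and `silent_cmd` are hypothetical.

```python
import subprocess
import sys

PASS_MARKER = "NUMERICAL CHECK PASSED"
FAIL_MARKER = "NUMERICAL CHECK FAILED"


def run_and_check(cmd, timeout=60):
    """Run cmd; raise AssertionError unless it exits 0 and reports PASS without FAIL."""
    result = subprocess.run(cmd, capture_output=True, check=False, timeout=timeout)
    stdout = result.stdout.decode()
    stderr = result.stderr.decode()
    if result.returncode != 0 or FAIL_MARKER in stderr or PASS_MARKER not in stdout:
        raise AssertionError(f"child failed.\nstdout: {stdout}\nstderr: {stderr}")
    return stdout


# Stand-in workers (assumptions for illustration, not the real torchrun command).
ok_cmd = [sys.executable, "-c", f"print({PASS_MARKER!r})"]
bad_cmd = [
    sys.executable,
    "-c",
    f"import sys; print({FAIL_MARKER!r}, file=sys.stderr); sys.exit(1)",
]
silent_cmd = [sys.executable, "-c", "pass"]  # exits 0 but never reports success

print(run_and_check(ok_cmd).strip())
for cmd in (bad_cmd, silent_cmd):
    try:
        run_and_check(cmd)
    except AssertionError:
        print("rejected")
```

Requiring the success marker in addition to a zero exit code guards against a worker that exits cleanly without ever reaching the numerical check, which is the case `silent_cmd` stands in for.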
66 changes: 66 additions & 0 deletions
transformer_engine/common/include/transformer_engine/newton_schulz.h
@@ -0,0 +1,66 @@
```c
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file newton_schulz.h
 * \brief Functions for distributed Newton-Schulz matrix orthogonalization.
 *
 * This API is a TE-native binding to the cuSolverMp library.
 * It computes an iterative Newton-Schulz matrix orthogonalization on a distributed matrix.
 */

#ifndef TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_
#define TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_

#include <nccl.h>
#include <stdint.h>

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

typedef struct NVTECusolverMpCtx NVTECusolverMpCtx;

/*! \brief Create a cuSolverMp context for Newton-Schulz operations.
 *
 * Creates a dedicated CUDA stream internally (cuSolverMp requires a
 * non-default stream).
 *
 * \param[in] comm   NCCL communicator.
 * \param[in] nranks Number of ranks.
 * \param[in] rank   Local rank.
 */
NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank);

/*! \brief Destroy a cuSolverMp context.
 *
 * \param[in] ctx Context to destroy.
 */
void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx);

/*! \brief Compute Newton-Schulz matrix orthogonalization in-place.
 *
 * \param[in]     ctx              cuSolverMp context.
 * \param[in]     m                Global number of rows.
 * \param[in]     n                Global number of columns.
 * \param[in,out] x                Local part of the matrix (modified in-place).
 * \param[in]     num_iterations   Number of Newton-Schulz iterations.
 * \param[in]     coefficients     Array of polynomial coefficients (length depends on polynomial
 *                                 degree used internally by cuSolverMp).
 * \param[in]     num_coefficients Number of elements in the coefficients array.
 * \param[in]     caller_stream    CUDA stream on which the caller produced the input tensor.
 *                                 Used for event-based synchronisation with the internal stream.
 */
void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x,
                        int64_t num_iterations, const float* coefficients, int64_t num_coefficients,
                        cudaStream_t caller_stream);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_
```
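The role of `coefficients` and `num_coefficients` is easier to see on a single singular value. Writing X = U S Vᵀ gives (X Xᵀ) X = U S³ Vᵀ, so an update of the form a·X + b·(X Xᵀ)X + c·(X Xᵀ)²X maps each singular value s to a·s + b·s³ + c·s⁵. The pure-Python sketch below assumes the flat three-per-iteration layout used by the test reference in this PR; whether the C API enforces exactly that length is not shown here. `check_layout` and `singular_value_after` are hypothetical helpers, and the coefficient values are the classical cubic ones rather than the library's.

```python
def check_layout(coefficients, num_iterations):
    """Hypothetical validation: expect one (a, b, c) triple per iteration."""
    if len(coefficients) != 3 * num_iterations:
        raise ValueError(
            f"expected {3 * num_iterations} coefficients, got {len(coefficients)}"
        )


def singular_value_after(s, coefficients):
    """Track one singular value through s <- a*s + b*s**3 + c*s**5 per triple."""
    for i in range(len(coefficients) // 3):
        a, b, c = coefficients[3 * i : 3 * (i + 1)]
        s = a * s + b * s**3 + c * s**5
    return s


# Assumption: classical cubic Newton-Schulz coefficients for every iteration.
num_iterations = 12
coefficients = [1.5, -0.5, 0.0] * num_iterations
check_layout(coefficients, num_iterations)

for s0 in (0.1, 0.5, 0.9):
    print(s0, singular_value_after(s0, coefficients))
```

Under these assumptions every starting value in the tested range is driven toward 1, which is the scalar counterpart of the matrix becoming orthogonal.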
Unconditional `#include <nccl.h>` in a public header.
`newton_schulz.h` is installed as a public header (under `include/transformer_engine/`). The unconditional `#include <nccl.h>` means that any downstream project that includes this header, even one with no interest in Newton-Schulz, now requires NCCL in its include path. `ncclComm_t` is only used in the function signatures of `nvte_cusolvermp_ctx_create` and `nvte_newton_schulz`, which are themselves only meaningful when `NVTE_WITH_CUSOLVERMP` is defined. Guarding the include and the declarations together would prevent the leakage.

I'm also a little uneasy about exposing NCCL as a required dependency, but I see we already include the NCCL header elsewhere in the codebase:
TransformerEngine/transformer_engine/common/util/logging.h
Line 15 in 580e7aa

We use NCCL types (`ncclComm_t`) in the `nvte_cusolvermp_ctx_create` API, so if we're taking the route of always defining the functions, we need to pull in the NCCL headers unconditionally.