gptqmodel/eora/eora.py: 5 changes (3 additions, 2 deletions)
```diff
@@ -20,6 +20,7 @@
 
 from ..utils.logger import setup_logger
 from ..utils.rocm import IS_ROCM
+from ..utils.torch import TORCH_GTE_210
 
 log = setup_logger()
 
@@ -91,7 +92,7 @@ def eora_compute_lora(
     # save this later for SVD
     raw_scaling_diag_matrix = eigen_scaling_diag_matrix.to(device=device, dtype=torch.float64)
 
-    if IS_ROCM:
+    if IS_ROCM and not TORCH_GTE_210:
         # hip cannot resolve linalg ops
         original_backend = torch.backends.cuda.preferred_linalg_library()
         torch.backends.cuda.preferred_linalg_library(backend="magma")
@@ -131,7 +132,7 @@ def eora_compute_lora(
     del truc_s, truc_u, truc_v, truc_sigma, sqrtS
 
     # revert linalg backend
-    if IS_ROCM:
+    if IS_ROCM and not TORCH_GTE_210:
         torch.backends.cuda.preferred_linalg_library(original_backend)
 
     return A, B
```
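The new gate pairs IS_ROCM with TORCH_GTE_210 so the MAGMA detour only applies to ROCm builds on PyTorch older than 2.1.0, where HIP could not resolve these linalg ops. The flag is imported from gptqmodel/utils/torch.py, which is not part of this diff; below is a minimal sketch of how such a flag could be derived, offered as an assumption rather than the repository's actual code:

```python
# Hypothetical sketch of TORCH_GTE_210; the real gptqmodel/utils/torch.py
# may compute it differently.
import torch
from packaging.version import Version

# Drop local build tags such as "+rocm5.7" before comparing versions.
TORCH_GTE_210 = Version(torch.__version__.split("+")[0]) >= Version("2.1.0")
```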
gptqmodel/utils/linalg_warmup.py: 30 changes (25 additions, 5 deletions)
```diff
@@ -8,11 +8,32 @@
 import threading
 
 import torch
+from .torch import TORCH_GTE_210
 
 
 _GLOBAL_WARMUP_LOCK = threading.Lock()
 
 
+def _get_cuda_preferred_linalg_library():
+    preferred = getattr(torch.backends.cuda, "preferred_linalg_library", None)
+    if preferred is None:
+        return None
+    if callable(preferred):
+        return preferred()
+    return preferred
+
+
+def _set_cuda_preferred_linalg_library(backend) -> bool:
+    preferred = getattr(torch.backends.cuda, "preferred_linalg_library", None)
+    if preferred is None:
+        return False
+    if callable(preferred):
+        preferred(backend=backend)
+        return True
+    setattr(torch.backends.cuda, "preferred_linalg_library", backend)
+    return True
+
+
 def _make_spd(size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
     """Generate a small symmetric positive definite matrix."""
     base = torch.randn((size, size), device=device, dtype=dtype)
@@ -64,19 +85,18 @@ def run_torch_linalg_warmup(device: torch.device) -> None:
         _run_qr(device, dtype)
 
     if device.type == "cuda" and hasattr(torch.backends, "cuda"):
-        preferred = getattr(torch.backends.cuda, "preferred_linalg_library", None)
-        if callable(preferred):
-            current = preferred()
+        current = _get_cuda_preferred_linalg_library()
+        if current is not None and not TORCH_GTE_210:
             # Core warmup already ran using the currently preferred backend above.
             # Some installations fall back to MAGMA when the primary solver is unavailable,
             # so we pre-initialize MAGMA as well when it differs from the preferred backend.
             if current and current != "magma":
                 with contextlib.suppress(Exception):
-                    torch.backends.cuda.preferred_linalg_library(backend="magma")
+                    _set_cuda_preferred_linalg_library("magma")
                     _run_cholesky_and_eigh(device, torch.float32)
             if current:
                 with contextlib.suppress(Exception):
-                    torch.backends.cuda.preferred_linalg_library(backend=current)
+                    _set_cuda_preferred_linalg_library(current)
 
 
 __all__ = ["run_torch_linalg_warmup"]
```
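Since all backend access now flows through the two guarded helpers, the warmup degrades gracefully when preferred_linalg_library is missing or not callable, and skips the MAGMA pre-initialization entirely on torch >= 2.1.0. A short usage sketch of the exported entry point; the device selection here is illustrative:

```python
# Illustrative call of the exported warmup helper; safe to run once at startup.
import torch

from gptqmodel.utils.linalg_warmup import run_torch_linalg_warmup

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
run_torch_linalg_warmup(device)  # pre-runs cholesky/eigh/qr so later calls skip lazy solver init
```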
tests/__init__.py: 1 change (1 addition, 0 deletions)
```diff
@@ -0,0 +1 @@
+"""Test package namespace used for local test helpers."""
```
tests/test_linalg.py: 13 changes (9 additions, 4 deletions)
```diff
@@ -14,6 +14,7 @@
 import torch  # noqa: E402
 from logbar import LogBar  # noqa: E402
 from parameterized import parameterized  # noqa: E402
+from gptqmodel.utils.torch import TORCH_GTE_210
 
 
 log = LogBar.shared()
@@ -44,11 +45,15 @@ def test_linalg_eigh(self, dtype: torch.dtype, size: int):
         ]
     )
     def test_linalg_eigh_magma(self, dtype: torch.dtype, size: int):
-        # force `magma` backend for linalg
-        original_backend = torch.backends.cuda.preferred_linalg_library()
-        torch.backends.cuda.preferred_linalg_library(backend="magma")
+        # force `magma` backend for linalg when available and allowed
+        restore_backend = None
+        preferred_linalg_library = getattr(torch.backends.cuda, "preferred_linalg_library", None)
+        if not TORCH_GTE_210 and callable(preferred_linalg_library):
+            restore_backend = preferred_linalg_library()
+            preferred_linalg_library(backend="magma")
 
         matrix = torch.randn([size, size], device=ROCM, dtype=dtype)
         torch.linalg.eigh(matrix)
 
-        torch.backends.cuda.preferred_linalg_library(backend=original_backend)
+        if restore_backend is not None:
+            preferred_linalg_library(backend=restore_backend)
```
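One caveat in the updated test: if torch.linalg.eigh raises, the method exits before the restore branch runs and the MAGMA backend stays active for later tests. A hedged alternative shape, not part of this PR, that guarantees restoration:

```python
# Alternative shape (not in this PR): restore the backend even when eigh fails.
try:
    matrix = torch.randn([size, size], device=ROCM, dtype=dtype)
    torch.linalg.eigh(matrix)
finally:
    if restore_backend is not None:
        preferred_linalg_library(backend=restore_backend)
```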