From 9e371703e45d18b9d0815ce515dba6fbf07063a2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rony=20Lepp=C3=A4nen?=
Date: Mon, 23 Feb 2026 20:40:09 +0000
Subject: [PATCH 1/4] add fixture to gpu cache cleanup

---
 tests/conftest.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 7bb16e204b..4261e71f1d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,4 @@
+import gc
 import json
 import os
 import types
@@ -203,7 +204,27 @@ def is_cuda_oom_error_str(e: str) -> bool:
     return "CUDA" in e and "out of memory" in e
 
 
-@pytest.hookimpl(wrapper=True)
+def clear_cuda_cache(device: torch.device) -> None:
+    total_memory = get_device_properties(device).total_memory
+    reserved_memory = torch.cuda.memory_reserved()
+
+    # FLASHINFER_TEST_MEMORY_THRESHOLD: threshold for PyTorch reserved memory usage (default: 0.75)
+    threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.75"))
+
+    if reserved_memory > threshold * total_memory:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+
+@pytest.fixture(autouse=True, scope="function")
+def clear_gpu_memory():
+    yield
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        clear_cuda_cache(device)  # Use the existing function
+
+
+@pytest.hookimpl(tryfirst=True)
 def pytest_runtest_call(item):
     # skip OOM error and missing JIT cache errors
     try:

From 9706b419c046800db0695b676369c4261ea9beef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rony=20Lepp=C3=A4nen?=
Date: Mon, 23 Feb 2026 21:58:46 +0000
Subject: [PATCH 2/4] add gpu memory cleanup hook to rocm tests

---
 tests/rocm_tests/conftest.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/rocm_tests/conftest.py b/tests/rocm_tests/conftest.py
index 6ae60760b2..d75645f61d 100644
--- a/tests/rocm_tests/conftest.py
+++ b/tests/rocm_tests/conftest.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import os
+import pytest
+import torch
 
 from flashinfer.hip_utils import get_available_gpu_count
 
@@ -56,3 +59,23 @@ def pytest_xdist_auto_num_workers(config):
             "Check HIP_VISIBLE_DEVICES or ROCm installation."
         )
     return n
+
+
+def _maybe_clear_gpu_memory(device: torch.device) -> None:
+    total_memory = torch.cuda.get_device_properties(device).total_memory
+    reserved_memory = torch.cuda.memory_reserved()
+
+    # FLASHINFER_TEST_MEMORY_THRESHOLD: threshold for PyTorch reserved memory usage (default: 0.75)
+    threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.75"))
+
+    if reserved_memory > threshold * total_memory:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+
+@pytest.fixture(autouse=True, scope="function")
+def clear_gpu_memory():
+    yield
+    if torch.cuda.is_available():
+        # Assume single GPU per worker due to device pinning in pytest_configure
+        _maybe_clear_gpu_memory(torch.device("cuda"))

From c687759771c05cb9b73fbf8180b8b51915503ca6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rony=20Lepp=C3=A4nen?=
Date: Mon, 23 Feb 2026 21:59:28 +0000
Subject: [PATCH 3/4] restore original conftest

---
 tests/conftest.py | 23 +----------------------
 1 file changed, 1 insertion(+), 22 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 4261e71f1d..7bb16e204b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,3 @@
-import gc
 import json
 import os
 import types
@@ -204,27 +203,7 @@ def is_cuda_oom_error_str(e: str) -> bool:
     return "CUDA" in e and "out of memory" in e
 
 
-def clear_cuda_cache(device: torch.device) -> None:
-    total_memory = get_device_properties(device).total_memory
-    reserved_memory = torch.cuda.memory_reserved()
-
-    # FLASHINFER_TEST_MEMORY_THRESHOLD: threshold for PyTorch reserved memory usage (default: 0.75)
-    threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.75"))
-
-    if reserved_memory > threshold * total_memory:
-        gc.collect()
-        torch.cuda.empty_cache()
-
-
-@pytest.fixture(autouse=True, scope="function")
-def clear_gpu_memory():
-    yield
-    if torch.cuda.is_available():
-        device = torch.device("cuda:0")
-        clear_cuda_cache(device)  # Use the existing function
-
-
-@pytest.hookimpl(tryfirst=True)
+@pytest.hookimpl(wrapper=True)
 def pytest_runtest_call(item):
     # skip OOM error and missing JIT cache errors
     try:

From cb94af2a0800699d6bfdc01fcd2d20aa99729dd6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rony=20Lepp=C3=A4nen?=
Date: Mon, 23 Feb 2026 22:10:47 +0000
Subject: [PATCH 4/4] missing device param

---
 tests/rocm_tests/conftest.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/rocm_tests/conftest.py b/tests/rocm_tests/conftest.py
index d75645f61d..98e4d81460 100644
--- a/tests/rocm_tests/conftest.py
+++ b/tests/rocm_tests/conftest.py
@@ -63,7 +63,7 @@ def pytest_xdist_auto_num_workers(config):
 
 def _maybe_clear_gpu_memory(device: torch.device) -> None:
     total_memory = torch.cuda.get_device_properties(device).total_memory
-    reserved_memory = torch.cuda.memory_reserved()
+    reserved_memory = torch.cuda.memory_reserved(device)
 
     # FLASHINFER_TEST_MEMORY_THRESHOLD: threshold for PyTorch reserved memory usage (default: 0.75)
     threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.75"))