diff --git a/tests/rocm_tests/conftest.py b/tests/rocm_tests/conftest.py index 6ae60760b2..98e4d81460 100644 --- a/tests/rocm_tests/conftest.py +++ b/tests/rocm_tests/conftest.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import os +import pytest +import torch from flashinfer.hip_utils import get_available_gpu_count @@ -56,3 +59,23 @@ def pytest_xdist_auto_num_workers(config): "Check HIP_VISIBLE_DEVICES or ROCm installation." ) return n + + +def _maybe_clear_gpu_memory(device: torch.device) -> None: + total_memory = torch.cuda.get_device_properties(device).total_memory + reserved_memory = torch.cuda.memory_reserved(device) + + # FLASHINFER_TEST_MEMORY_THRESHOLD: threshold for PyTorch reserved memory usage (default: 0.75) + threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.75")) + + if reserved_memory > threshold * total_memory: + gc.collect() + torch.cuda.empty_cache() + + +@pytest.fixture(autouse=True, scope="function") +def clear_gpu_memory(): + yield + if torch.cuda.is_available(): + # Assume single GPU per worker due to device pinning in pytest_configure + _maybe_clear_gpu_memory(torch.device("cuda"))