Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tests/rocm_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os
import pytest
import torch

from flashinfer.hip_utils import get_available_gpu_count

Expand Down Expand Up @@ -56,3 +59,23 @@ def pytest_xdist_auto_num_workers(config):
"Check HIP_VISIBLE_DEVICES or ROCm installation."
)
return n


def _maybe_clear_gpu_memory(device: torch.device) -> None:
total_memory = torch.cuda.get_device_properties(device).total_memory
reserved_memory = torch.cuda.memory_reserved(device)

# FLASHINFER_TEST_MEMORY_THRESHOLD: threshold for PyTorch reserved memory usage (default: 0.75)
threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.75"))

if reserved_memory > threshold * total_memory:
gc.collect()
torch.cuda.empty_cache()


@pytest.fixture(autouse=True, scope="function")
def clear_gpu_memory():
yield
if torch.cuda.is_available():
# Assume single GPU per worker due to device pinning in pytest_configure
_maybe_clear_gpu_memory(torch.device("cuda"))
Loading