From 6c25d2224b45255d63e9ab954736236f409d2956 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 23 Apr 2026 23:23:09 +0000 Subject: [PATCH 01/20] docs: refresh README for amd-flashinfer library consumers Reframe the top-level README for developers embedding FlashInfer+ROCm. Add a minimal usage example, feature matrix with prefill backends (fa2, aiter, fa3_cdna3), consolidated GPU/ROCm/PyTorch support, AITER page-size constraints from prefill_rocm, notebook link, and a dedicated prefill backends section. Remove verbose docker details blocks in favor of inline context. Made-with: Cursor --- README.md | 337 ++++++++++++++---------------------------------------- 1 file changed, 87 insertions(+), 250 deletions(-) diff --git a/README.md b/README.md index 1446c2f4d9..167c50cbce 100644 --- a/README.md +++ b/README.md @@ -1,68 +1,81 @@ # FlashInfer+ROCm: An AMD ROCm port of FlashInfer -FlashInfer+ROCm is a port of the [FlashInfer](https://github.com/flashinfer-ai/flashinfer) library -that adds support for AMD Instinct GPUs. The project is in active development with current focus on -porting attention kernels to ROCm. +FlashInfer+ROCm is an AMD ROCm port of the [FlashInfer](https://github.com/flashinfer-ai/flashinfer) library of fast attention, RoPE, RMSNorm, sampling, and logits-processor kernels for LLM inference on AMD Instinct GPUs. This README is aimed at library consumers: developers embedding FlashInfer kernels into their own training or serving stack. -**Versioning:** The release tag format `+amd` ties each FlashInfer+ROCm release -to its corresponding upstream tag (e.g., `0.2.5+amd.2` is second release of amd-flashinfer based on upstream version `v0.2.5`). +**Status:** Active development, attention (single/batch prefill and decode) is the primary focus. See [CHANGELOG.md](CHANGELOG.md) for the full history. + +**Versioning:** Release tags use the form `+amd.` (for example, `0.5.3+amd.1` is the first AMD release based on upstream `v0.5.3`). 
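The tag convention above can be checked mechanically. The following is a minimal sketch, assuming the tag is an upstream version followed by `+amd.` and an integer release counter, as in the `0.5.3+amd.1` example; the helper name `parse_amd_tag` and the `AmdReleaseTag` type are hypothetical and are not part of the FlashInfer API.

```python
import re
from typing import NamedTuple


class AmdReleaseTag(NamedTuple):
    """Hypothetical container for the two parts of a release tag."""
    upstream: str      # upstream FlashInfer tag, e.g. "v0.5.3"
    amd_release: int   # AMD release counter on top of that upstream tag


# Assumed shape: "<upstream version>+amd.<n>", with an optional leading "v".
_TAG_RE = re.compile(r"^v?(?P<upstream>\d+(?:\.\d+)*)\+amd\.(?P<n>\d+)$")


def parse_amd_tag(tag: str) -> AmdReleaseTag:
    """Split a tag such as "0.5.3+amd.1" into upstream tag and AMD release number."""
    match = _TAG_RE.match(tag)
    if match is None:
        raise ValueError(f"not a FlashInfer+ROCm release tag: {tag!r}")
    return AmdReleaseTag("v" + match.group("upstream"), int(match.group("n")))


if __name__ == "__main__":
    print(parse_amd_tag("0.5.3+amd.1"))
```

A serving stack could use a check like this to confirm which upstream FlashInfer release an installed build tracks before enabling version-specific code paths.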
+ +## Minimal usage + +```python +import torch +import flashinfer + +# Device is still "cuda" on PyTorch+ROCm. +q = torch.randn(1024, 32, 128, dtype=torch.float16, device="cuda") +k = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda") # GQA 4:1 +v = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda") + +# Default backend = "fa2". Use backend="aiter" or "fa3_cdna3" to switch. +o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) +``` ## Table of Contents * [Feature Support Matrix](#feature-support-matrix) -* [GPU and ROCm Support](#gpu-and-rocm-support) +* [GPU, ROCm, and PyTorch Support](#gpu-rocm-and-pytorch-support) * [Getting Started](#getting-started) - * [Option 1: Get a Pre-built Docker Image](#option-1-get-a-pre-built-docker-image) + * [Option 1: Pre-built Docker Image](#option-1-pre-built-docker-image) * [Option 2: Install from a Wheel Package](#option-2-install-from-a-wheel-package) - * [Trying the Examples](#trying-the-examples) + * [Running the Examples](#running-the-examples) * [Build from Source](#build-from-source) - * [Setting up a Development Environment](#setting-up-a-development-environment) - * [Building and Installing a Wheel Package](#building-and-installing-a-wheel-package) + * [Development Environment](#development-environment) + * [Building and Installing a Wheel](#building-and-installing-a-wheel) * [Running Tests](#running-tests) -* [AITER Support](#aiter-support) - * [Single Prefill AITER example](#single-prefill-example) +* [Prefill Backends](#prefill-backends) +* [Contributing and License](#contributing-and-license) ## Feature Support Matrix -| Kernel Type | FP16 / BF16 | FP8 (E4M3, E5M2) | Has AITER backend | Notes | -| :--- | :---: | :---: | :---: | :--- | -| **Decode Attention** | ✅ | ✅ | No | Supports MHA, GQA, and MQA | -| **Prefill Attention** | ✅ | WIP | ✅ | Supports MHA, GQA, and MQA | -| **Cascade Attention** | TBD | TBD | No | Not Yet Ported | -| **MLA** | TBD | TBD | No | Not Yet Ported | -| 
**POD** | TBD | TBD | No | Not Yet Ported | -| **Positional Encoding** | TBD | TBD | No | Not Yet Ported | -| **Sampling** | ✅ | TBD | No | Supports Top-K/Top-P Sampling/OnlineSoftmax/SamplingFromLogits | -| **Logits Processor** | ✅ | TBD | No | | -| **Normalization** | ✅ | TBD | No | Supports RMS-Norm/Layer-Norm | +| Kernel | FP16 / BF16 | FP8 (E4M3, E5M2) | Backends | Notes | +| :--- | :---: | :---: | :--- | :--- | +| **Decode attention** | Yes | Yes | `fa2` | MHA, GQA, MQA | +| **Prefill attention** | Yes | WIP | `fa2`, `aiter`, `fa3_cdna3` | MHA, GQA, MQA | +| **RoPE** (incl. Llama 3.1, fused RoPE+FP8+paged-KV append) | Yes | - | `fa2` | | +| **RMSNorm / LayerNorm / Gemma variants** | Yes | - | `fa2` | | +| **Sampling** | Yes | - | `fa2` | Top-K, Top-P, OnlineSoftmax, SamplingFromLogits | +| **Logits processor** | Yes | - | `fa2` | | +| **Quantization** (`packbits`, `segment_packbits`) | Yes | - | `fa2` | | +| Cascade, MLA, POD, PosEncoding-mode variants | - | - | - | Not yet ported | -## GPU and ROCm Support +## GPU, ROCm, and PyTorch Support -**Supported GPU:** gfx942 (CDNA3 architecture), gfx950 (CDNA4 architecture) +**GPU architectures:** gfx942 (CDNA3 — MI300X, MI325X), gfx950 (CDNA4 — MI355X). -**Supported ROCm versions:** 7.0.2, 7.1.1, 7.2 +**ROCm:** 7.0.2, 7.1.1, 7.2. -## Torch Version Support +**PyTorch+ROCm:** 2.8.0, 2.9.1. Install the matching wheel from `repo.radeon.com`: -**Torch+ROCm:** 2.8.0, 2.9.1 +```bash +pip install torch==2.9.1 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2 +``` -**Note**: Other versions may work but have not been tested. Refer to (replacing `{rocm-version}` with the desired ROCm version, e.g., `7.0.2`) for available versions. +Other versions may work but are not tested. Replace `7.2` with the ROCm version you need; see for the full list. 
## Getting Started -### Option 1: Get a Pre-built Docker Image +### Option 1: Pre-built Docker Image -AMD validates and publishes [FlashInfer images](https://hub.docker.com/r/rocm/flashinfer/tags) -with ROCm backends on Docker Hub. The following Docker image tag and associated -inventories represent the latest available FlashInfer version from the official Docker Hub. +AMD validates and publishes [FlashInfer images](https://hub.docker.com/r/rocm/flashinfer/tags) on Docker Hub: | Docker image | ROCm | FlashInfer | PyTorch | Ubuntu | Python | GPU | -| ------------ | ---- | ---------- | ------- | ------ | ------ | --- | -| rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1 |7.2.0 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355x, MI325X, MI300X | -| rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.0.2_ubuntu24.04_py3.12_pytorch2.9.1 | 7.0.2 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355x, MI325X, MI300X | -| rocm/flashinfer:flashinfer-0.2.5.amd2_rocm7.1.1_ubuntu24.04_py3.12_pytorch2.8 | 7.1.1 | v0.2.5 | 2.8.0 | 24.04 | 3.12 | MI325X, MI300X | +| --- | --- | --- | --- | --- | --- | --- | +| `rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1` | 7.2.0 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355X, MI325X, MI300X | +| `rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.0.2_ubuntu24.04_py3.12_pytorch2.9.1` | 7.0.2 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355X, MI325X, MI300X | +| `rocm/flashinfer:flashinfer-0.2.5.amd2_rocm7.1.1_ubuntu24.04_py3.12_pytorch2.8` | 7.1.1 | v0.2.5 | 2.8.0 | 24.04 | 3.12 | MI325X, MI300X | -**Start a container:** +Start a container: ```bash docker run -it --privileged --network=host --device=/dev/kfd --device=/dev/dri \ @@ -70,63 +83,37 @@ docker run -it --privileged --network=host --device=/dev/kfd --device=/dev/dri \ --ipc=host --shm-size 128G --name= ``` -**Activate the environment and verify:** +Verify: ```bash -# Activate micromamba environment (Note: env name may vary based on the image) -micromamba activate base - -# 
Verify installation +micromamba activate base # env name may vary per image python -c "import flashinfer; print(flashinfer.__version__)" +# expected: 0.5.3+amd.1 ``` -Expected output: `0.5.3+amd.1` (with a possible JIT backend message) - ### Option 2: Install from a Wheel Package -Install from AMD's package repository: - ```bash pip install amd-flashinfer --index-url https://pypi.amd.com/simple/ -``` - -Install the needed ROCm-enabled torch package from : - -```bash pip install torch==2.9.1 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2 ``` -**NOTE**: The torch version should be exactly as available on repo.radeon.com otherwise a non-ROCm -torch version will get installed from pypi. +> `torch` is deliberately not a declared dependency because the ROCm wheel must come from `repo.radeon.com`, not PyPI. Installing without `-f` will pull a non-ROCm build. -### Trying the Examples - -Download and run example scripts from the repository: +### Running the Examples ```bash -# Download a single example wget https://raw.githubusercontent.com/ROCm/flashinfer/amd-integration/examples/single_prefill_example.py +wget https://raw.githubusercontent.com/ROCm/flashinfer/amd-integration/examples/batch_prefill_example.py +wget https://raw.githubusercontent.com/ROCm/flashinfer/amd-integration/examples/batch_decode_example.py python single_prefill_example.py - -# Download all examples -for example in single_prefill_example.py batch_prefill_example.py batch_decode_example.py; do - wget https://raw.githubusercontent.com/ROCm/flashinfer/amd-integration/examples/$example -done ``` -**Available examples:** - -* `single_prefill_example.py` - Single-sequence prefill attention -* `batch_prefill_example.py` - Batched prefill attention -* `batch_decode_example.py` - Batched decode attention -* `examples/amd_flashinfer_rocm_tutorial.ipynb` - Jupyter tutorial: environment verification (`hip_utils`), AITER-backed prefill examples, and `logits_processor` on ROCm -* 
`examples/run_jupyter_server.sh` - Start JupyterLab from the repo root (run inside your ROCm/FlashInfer environment or Docker container) +An end-to-end recommendation-system notebook that exercises the full public API is also available at [`examples/recommendation_system_flashinfer_rocm.ipynb`](examples/recommendation_system_flashinfer_rocm.ipynb). ## Build from Source -### Setting up a Development Environment - -Build the development Docker image with the repository's Dockerfile: +### Development Environment ```bash docker build \ @@ -138,210 +125,60 @@ docker build \ --build-arg USER_GID=$(id -g) \ -t flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1 \ -f .devcontainer/rocm/Dockerfile . -``` - - -
-Build argument descriptions - -* `ROCM_VERSION`: ROCm version (default: 7.2) -* `PY_VERSION`: Python version (default: 3.12) -* `TORCH_VERSION`: PyTorch version (default: 2.9.1) -* `USERNAME`: Username inside container (default: devuser) -* `USER_UID`: User ID for matching host permissions -* `USER_GID`: Group ID for matching host permissions - -
- - -**Run the development container:** -```bash docker run -it \ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ --ipc=host --privileged --shm-size=128G --network=host \ - --device=/dev/kfd --device=/dev/dri \ - --group-add video --group-add render \ - -v $PWD:/workspace \ - --name flashinfer-dev-container \ + --device=/dev/kfd --device=/dev/dri --group-add video --group-add render \ + -v $PWD:/workspace --name flashinfer-dev-container \ flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1 ``` - -
-Docker run argument descriptions - -* `--cap-add=SYS_PTRACE`: Enables debugging -* `--security-opt seccomp=unconfined`: Relaxes security for development -* `--ipc=host`: Shares host IPC for better performance -* `--privileged`: Required for GPU access -* `--shm-size=128G`: Shared memory size (adjust as needed) -* `--network=host`: Uses host networking -* `--device=/dev/kfd --device=/dev/dri`: Exposes AMD GPU devices -* `--group-add video --group-add render`: GPU access groups -* `-v :`: Mounts source code - -
- - -**Note:** Environment name varies based on Python, PyTorch, and ROCm versions. - -### Building and Installing a Wheel Package - -**Build with JIT (Just-in-Time) compilation only:** +### Building and Installing a Wheel ```bash -python -m pip wheel . --wheel-dir=./dist/ --no-deps --no-build-isolation -v -cd dist && pip install amd_flashinfer-*.whl -``` - -**Editable install for development:** +# Editable install (JIT kernels compile on first use) +python -m pip install --no-build-isolation -ve . -```bash -python -m pip install --no-build-isolation -ve. +# Wheel build +python -m pip wheel . --wheel-dir=./dist/ --no-deps --no-build-isolation -v +pip install dist/amd_flashinfer-*.whl ``` -**Note:** The `--no-deps` flag assumes dependencies are pre-installed. Omit it -to download dependencies during build. AOT builds take longer and use more disk -space but avoid JIT compilation at runtime. - ### Running Tests -The Python tests suite can be run with pytest: - -```bash -# Run default tests (configured in pyproject.toml) -pytest - -# Run specific test file -pytest tests/test_decode_kernels_hip.py - -# Run with pattern matching -pytest -k "test_decode_kernels_hip" - -# Verbose output -pytest -v - -# To run tests parallely on multiple GPUs -pytest -n auto # Uses all available GPUs -pytest -n 2 # Use only two GPUs -``` - -The default test configuration is specified in [pyproject.toml](pyproject.toml) under the `testpaths` setting. - -#### Recommended invocation on AMD CPX systems - -`pytest-rerunfailures` (declared in the `dev` extra — `pip install -e ".[dev]"`) -absorbs the residual transient HIP runtime crashes. 
Then for the full suite: - ```bash -# Fast path — skips heavy 1M-trial sampling-frequency tests and 4 GB -# speculative-sampling cases (~7 min on a CPX 8-card host): -pytest -n auto --reruns 2 -m "not slow" - -# Full coverage — including the slow tests (~20 min): -pytest -n auto --reruns 2 - -# Slow path only (~13 min): -pytest -n auto --reruns 2 -m "slow" +pytest # curated set from pyproject.toml +pytest tests/rocm_tests # AMD-authored HIP tests +pytest -n auto # across all visible GPUs +pytest -k "test_batch_decode_kernels_hip" ``` -**Notes** - -* `pytest -n auto` for the `tests/rocm_tests/` suite spawns **half as many xdist workers as physical AMD cards** (e.g. 4 workers on a CPX-mode 8-card MI308X / MI325X host). One worker per physical card was tried first but produced sporadic failures across rope, single_prefill, and logits_cap under residual concurrent load; halving the count produces reliable green runs. Each worker is pinned to its card via `HIP_VISIBLE_DEVICES`. On non-CPX systems the helper applies the same halving; users who want every device used can pass an explicit `-n N`. -* `--reruns 2` (from `pytest-rerunfailures`) absorbs the residual ~0.01 % of transient HIP runtime crashes (HSA exceptions, HIPBLAS handle-pool exhaustion, intermittent generator non-determinism) that worker pinning cannot fully eliminate. Successful tests are not duplicated; only failed tests are retried. -* The `slow` marker is registered in [pyproject.toml](pyproject.toml). It tags the 1M-trial sampling-frequency tests, the 4 GB-tensor speculative-sampling cases, and the entire `TestLogitsPipeCompilationHIP` class (every test there runs the sampling kernel twice per case for compile=True/False). -* The reference attention helper in `tests/attention_reference.py` wraps `torch.matmul` in a `_hipblas_safe_matmul` retry helper that catches `HIPBLAS_STATUS_ALLOC_FAILED` and retries with a short back-off — needed under heavy concurrent xdist load. 
- -## AITER Support +The default test set is pinned in `[tool.pytest.ini_options]` in [pyproject.toml](pyproject.toml). -FlashInfer+ROCm supports the use of [AITER](https://github.com/ROCm/aiter) as a -backend. The `aiter` backend is enabled for the `single_prefill` and `batch_prefill` kernels. +## Prefill Backends -**On gfx942/gfx950 GPUs, `backend="auto"` (the default) automatically selects the AITER backend** -when the call parameters are compatible (fp16/bf16, NHD layout, no custom mask, equal Q/K/V -dtypes and head dims, `pos_encoding_mode="NONE"`). It falls back to `fa2` with a one-time -`logger.warning` when any condition is not met. You can also pass `backend="aiter"` explicitly. +All prefill entry points accept a `backend=` keyword (default `"auto"`, which resolves to `"fa2"`). -Unless you are using the prebuilt docker image, AITER must also be installed on your system. You may follow one of the following ways to do so. +- **`fa2` (default)** — In-tree HIP port of FlashAttention-2. Broadest coverage: paged + ragged, single + batch, fp16/bf16. +- **`aiter`** — Wraps [AITER](https://github.com/ROCm/aiter) CK FMHA (`flash_attn_varlen_func`, `mha_batch_prefill_func`). NHD layout only; paged batch-prefill `page_size ∈ {16, 1024}` (or `{128, 256, 1024}` on `amd-aiter==0.1.10`). +- **`fa3_cdna3`** — MI300X-optimized single-prefill kernel for chunked prefill, `head_dim=256`, `q_len != kv_len`. Experimental; see [`benchmarks/rocm_benchmarks/bench_fa3_cdna3.py`](benchmarks/rocm_benchmarks/bench_fa3_cdna3.py) and [`examples/single_prefill_example.py`](examples/single_prefill_example.py). -### Install AITER from source - -```bash -git clone --recursive https://github.com/ROCm/aiter.git -cd aiter -python3 setup.py develop -``` - -### Install AITER wheel package - -Wheel packages are available from AMD's PyPI index: [pypi.amd.com/simple](https://pypi.amd.com/simple/). 
+Install AITER if you plan to use `backend="aiter"` outside the prebuilt Docker image: ```bash pip install amd-aiter --index-url https://pypi.amd.com/simple/ +# or: git clone --recursive https://github.com/ROCm/aiter.git && cd aiter && python3 setup.py develop ``` -### Known Limitations - -The AITER backend has the following constraints. With `backend="aiter"` the -call will error on the first group of conditions, or for the second group, -run but silently ignore the unsupported argument. - -**Conditions that fall back to `fa2` under `backend="auto"`:** - -* GPU is not gfx942 or gfx950 -* `kv_layout` is not `NHD` -* a custom attention mask tensor is supplied -* `q_dtype` is not `float16` / `bfloat16` (no fp32, fp8, or int8) -* `q_dtype != kv_dtype` (mixed-precision Q/KV is unsupported) -* `head_dim_qk != head_dim_vo` (e.g. DeepSeek-style MLA with 192/128 head dims) -* the `aiter` Python package is not importable - -**Features silently ignored on the AITER path** (the kwargs are accepted by -the FlashInfer wrapper but not forwarded to AITER, which can produce wrong -results — pass `backend="fa2"` explicitly if you need any of these): +Example: -* ALiBi slopes (`maybe_alibi_slopes`) -* in-kernel positional encoding modes (`pos_encoding_mode`, `rope_scale`, - `rope_theta`) -* attention sinks (`sinks`) -* multi-modal / prefix-cache helpers (`maybe_prefix_len_ptr`, - `maybe_token_pos_in_items_ptr`, `maybe_max_item_len_ptr`) -* FP8 dequant scales (`scale_q` / `scale_k` / `scale_v`) -* `use_fp16_qk_reduction`, `enable_pdl` - -**Other notes:** - -* Batch prefill: AITER's CK FMHA kernels natively support page sizes - `{16, 1024}` (or `{128, 256, 1024}` on `amd-aiter==0.1.10`). Other page - sizes still work but go through an extra GPU gather to flatten paged KV - before the AITER call. -* Ragged (non-paged) KV is not yet implemented on the AITER batch-prefill - path. 
`BatchPrefillWithRaggedKVCacheWrapper` therefore forces the backend - to `fa2` regardless of whether you pass `backend="auto"` or - `backend="aiter"` (a warning is logged in the latter case). +```python +o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, backend="aiter") +``` -### Single Prefill Example +## Contributing and License -This section provides an example on how to use Single Prefill with AITER. +See [CONTRIBUTING.md](CONTRIBUTING.md). Run `pre-commit run -a` and `pytest` before opening a PR. -```python -import torch -import flashinfer - -# Configuration -seq_len = 1024 # Prompt length -num_qo_heads = 32 # Number of query/output heads -num_kv_heads = 8 # Number of KV heads (GQA with 4:1 ratio) -head_dim = 128 - -# Create Q, K, V tensors (NHD layout: sequence, heads, dimension) -q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda") -k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda") -v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda") - -# Run single prefill attention with causal masking -# On gfx942/gfx950, backend="auto" (default) routes to AITER automatically. -# Pass backend="aiter" to require AITER explicitly, or backend="fa2" to skip it. -output = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, backend="auto") -``` +Upstream project: [flashinfer-ai/flashinfer](https://github.com/flashinfer-ai/flashinfer). Released under the Apache-2.0 License — [LICENSE](LICENSE), [NOTICE](NOTICE). 
From 3b7de1374a754210165bcc30a3e2ff79cba12256 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 04:33:26 +0000 Subject: [PATCH 02/20] docs: align README with amd-integration and refresh feature matrix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Restore the practical sections that the prior rewrite dropped (Docker tag table, source-build instructions, CPX-mode pytest guidance, AITER install recipes) and refresh the Feature Support Matrix to reflect what has actually landed on amd-integration: Cascade, MLA (AITER), RoPE, paged KV-cache append, RMSNorm/AITER, sliding-window decode, torch.compile. Drop the stale fa3_cdna3 backend mention — it has no Python dispatch entry. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 377 +++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 302 insertions(+), 75 deletions(-) diff --git a/README.md b/README.md index 167c50cbce..2adf5277c1 100644 --- a/README.md +++ b/README.md @@ -1,81 +1,115 @@ # FlashInfer+ROCm: An AMD ROCm port of FlashInfer -FlashInfer+ROCm is an AMD ROCm port of the [FlashInfer](https://github.com/flashinfer-ai/flashinfer) library of fast attention, RoPE, RMSNorm, sampling, and logits-processor kernels for LLM inference on AMD Instinct GPUs. This README is aimed at library consumers: developers embedding FlashInfer kernels into their own training or serving stack. +FlashInfer+ROCm is an AMD ROCm port of the +[FlashInfer](https://github.com/flashinfer-ai/flashinfer) attention, +RoPE, normalization, sampling, and logits-processor kernels for LLM +inference on AMD Instinct GPUs. The port targets CDNA3 (gfx942 — +MI300X / MI325X) and CDNA4 (gfx950 — MI355X), and is aimed at developers +embedding FlashInfer kernels into their own training or serving stack. + +The project is in active development with the primary focus on attention +(single and batch prefill / decode) and the surrounding KV-cache, RoPE, +and normalization kernels. 
See [CHANGELOG.md](CHANGELOG.md) for the +full release history. + +**Versioning:** The release tag format `+amd.` ties +each FlashInfer+ROCm release to its corresponding upstream tag (e.g. +`0.5.3+amd.1` is the first AMD release based on upstream `v0.5.3`). -**Status:** Active development, attention (single/batch prefill and decode) is the primary focus. See [CHANGELOG.md](CHANGELOG.md) for the full history. +## Table of Contents -**Versioning:** Release tags use the form `+amd.` (for example, `0.5.3+amd.1` is the first AMD release based on upstream `v0.5.3`). +* [Basic Usage](#basic-usage) +* [Feature Support Matrix](#feature-support-matrix) +* [GPU, ROCm, and PyTorch Support](#gpu-rocm-and-pytorch-support) +* [Getting Started](#getting-started) + * [Option 1: Get a Pre-built Docker Image](#option-1-get-a-pre-built-docker-image) + * [Option 2: Install from a Wheel Package](#option-2-install-from-a-wheel-package) + * [Trying the Examples](#trying-the-examples) +* [Build from Source](#build-from-source) + * [Setting up a Development Environment](#setting-up-a-development-environment) + * [Building and Installing a Wheel Package](#building-and-installing-a-wheel-package) + * [Running Tests](#running-tests) +* [AITER Support](#aiter-support) + * [Install AITER from source](#install-aiter-from-source) + * [Install AITER wheel package](#install-aiter-wheel-package) + * [Known Limitations](#known-limitations) + * [Single Prefill Example](#single-prefill-example) +* [License and Acknowledgements](#license-and-acknowledgements) -## Minimal usage +## Basic Usage ```python import torch import flashinfer -# Device is still "cuda" on PyTorch+ROCm. +# PyTorch+ROCm still uses device="cuda" for AMD GPUs. q = torch.randn(1024, 32, 128, dtype=torch.float16, device="cuda") k = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda") # GQA 4:1 v = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda") -# Default backend = "fa2". 
Use backend="aiter" or "fa3_cdna3" to switch. -o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) +# backend="auto" (default) routes to AITER when supported on gfx942/gfx950 +# and falls back to the in-tree fa2 HIP kernel otherwise. +output = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) ``` -## Table of Contents - -* [Feature Support Matrix](#feature-support-matrix) -* [GPU, ROCm, and PyTorch Support](#gpu-rocm-and-pytorch-support) -* [Getting Started](#getting-started) - * [Option 1: Pre-built Docker Image](#option-1-pre-built-docker-image) - * [Option 2: Install from a Wheel Package](#option-2-install-from-a-wheel-package) - * [Running the Examples](#running-the-examples) -* [Build from Source](#build-from-source) - * [Development Environment](#development-environment) - * [Building and Installing a Wheel](#building-and-installing-a-wheel) - * [Running Tests](#running-tests) -* [Prefill Backends](#prefill-backends) -* [Contributing and License](#contributing-and-license) +See [`examples/`](examples/) for batch prefill, batch decode, and a +Jupyter tutorial that walks through the full public API on ROCm. ## Feature Support Matrix -| Kernel | FP16 / BF16 | FP8 (E4M3, E5M2) | Backends | Notes | -| :--- | :---: | :---: | :--- | :--- | -| **Decode attention** | Yes | Yes | `fa2` | MHA, GQA, MQA | -| **Prefill attention** | Yes | WIP | `fa2`, `aiter`, `fa3_cdna3` | MHA, GQA, MQA | -| **RoPE** (incl. 
Llama 3.1, fused RoPE+FP8+paged-KV append) | Yes | - | `fa2` | | -| **RMSNorm / LayerNorm / Gemma variants** | Yes | - | `fa2` | | -| **Sampling** | Yes | - | `fa2` | Top-K, Top-P, OnlineSoftmax, SamplingFromLogits | -| **Logits processor** | Yes | - | `fa2` | | -| **Quantization** (`packbits`, `segment_packbits`) | Yes | - | `fa2` | | -| Cascade, MLA, POD, PosEncoding-mode variants | - | - | - | Not yet ported | +| Kernel Type | FP16 / BF16 | FP8 (E4M3, E5M2) | Has AITER backend | Notes | +| :--- | :---: | :---: | :---: | :--- | +| **Single / Batch Decode Attention** | ✅ | ✅ (E4M3FNUZ KV-cache) | ✅ (batch paged, fp16/bf16) | MHA, GQA, MQA; sliding-window on the AITER path; CUDA-graph support | +| **Single / Batch Prefill Attention** | ✅ | WIP | ✅ (single, batch-paged, batch-ragged) | MHA, GQA, MQA | +| **Cascade Attention** | ✅ | — | No | Two-level shared-prefix attention | +| **MLA (Multi-head Latent Attention)** | ✅ (bf16, `page_size=1`) | — | ✅ (AITER-only path) | DeepSeek-style 192/128 head-dim split; **requires AITER** on ROCm | +| **POD Attention** | TBD | TBD | No | Code present; **not yet validated on ROCm** | +| **RoPE (Positional Encoding)** | ✅ | — | No | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + paged-KV append | +| **Paged KV-Cache Append** | ✅ | ✅ | ✅ (opt-in) | `append_paged_kv_cache` | +| **Sampling** | ✅ | — | No | Top-K / Top-P / Min-P / OnlineSoftmax / SamplingFromLogits | +| **Logits Processor** | ✅ | — | No | Composable processor pipeline (cap, mask, temperature, …) | +| **Normalization** | ✅ | — | ✅ (RMSNorm only) | RMSNorm, LayerNorm, Gemma RMSNorm | +| **Activation** | ✅ | — | No | SiLU / GELU with fused gating | +| **Quantization** | ✅ | — | No | `packbits`, `segment_packbits` | +| **`torch.compile`** | ✅ (opt-in) | — | n/a | Enabled via the `FLASHINFER_ENABLE_TORCH_COMPILE` env flag | + +Every ✅ row above is exercised by a matching `tests/rocm_tests/test_*_hip.py`. 
## GPU, ROCm, and PyTorch Support -**GPU architectures:** gfx942 (CDNA3 — MI300X, MI325X), gfx950 (CDNA4 — MI355X). +**Supported GPUs:** gfx942 (CDNA3 — MI300X, MI325X), gfx950 (CDNA4 — MI355X). -**ROCm:** 7.0.2, 7.1.1, 7.2. +**Supported ROCm versions:** 7.0.2, 7.1.1, 7.2. -**PyTorch+ROCm:** 2.8.0, 2.9.1. Install the matching wheel from `repo.radeon.com`: +**Supported PyTorch+ROCm versions:** 2.8.0, 2.9.1. + +Install the matching ROCm-enabled PyTorch wheel from +: ```bash -pip install torch==2.9.1 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2 +pip install torch==2.9.1 --index-url https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/ ``` -Other versions may work but are not tested. Replace `7.2` with the ROCm version you need; see for the full list. +Other versions may work but have not been tested. Replace `7.2` with the +ROCm version you need; refer to + for +available wheels. ## Getting Started -### Option 1: Pre-built Docker Image +### Option 1: Get a Pre-built Docker Image -AMD validates and publishes [FlashInfer images](https://hub.docker.com/r/rocm/flashinfer/tags) on Docker Hub: +AMD validates and publishes [FlashInfer images](https://hub.docker.com/r/rocm/flashinfer/tags) +with ROCm backends on Docker Hub. 
The following Docker image tags +represent the latest available FlashInfer+ROCm releases: | Docker image | ROCm | FlashInfer | PyTorch | Ubuntu | Python | GPU | -| --- | --- | --- | --- | --- | --- | --- | +| ------------ | ---- | ---------- | ------- | ------ | ------ | --- | | `rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1` | 7.2.0 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355X, MI325X, MI300X | | `rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.0.2_ubuntu24.04_py3.12_pytorch2.9.1` | 7.0.2 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355X, MI325X, MI300X | | `rocm/flashinfer:flashinfer-0.2.5.amd2_rocm7.1.1_ubuntu24.04_py3.12_pytorch2.8` | 7.1.1 | v0.2.5 | 2.8.0 | 24.04 | 3.12 | MI325X, MI300X | -Start a container: +**Start a container:** ```bash docker run -it --privileged --network=host --device=/dev/kfd --device=/dev/dri \ @@ -83,37 +117,66 @@ docker run -it --privileged --network=host --device=/dev/kfd --device=/dev/dri \ --ipc=host --shm-size 128G --name= ``` -Verify: +**Activate the environment and verify:** ```bash -micromamba activate base # env name may vary per image +# Activate the micromamba environment (env name may vary based on the image) +micromamba activate base + +# Verify installation python -c "import flashinfer; print(flashinfer.__version__)" -# expected: 0.5.3+amd.1 ``` +Expected output: `0.5.3+amd.1` (with a possible JIT backend message). + ### Option 2: Install from a Wheel Package +Install from AMD's package repository: + ```bash pip install amd-flashinfer --index-url https://pypi.amd.com/simple/ -pip install torch==2.9.1 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2 ``` -> `torch` is deliberately not a declared dependency because the ROCm wheel must come from `repo.radeon.com`, not PyPI. Installing without `-f` will pull a non-ROCm build. 
+Install the matching ROCm-enabled torch package from : + +```bash +pip install torch==2.9.1 --index-url https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/ +``` + +**NOTE:** Use `--index-url` (not `-f`) so pip cannot silently fall back +to a CPU-only PyPI wheel. -### Running the Examples +### Trying the Examples + +Download and run example scripts from the repository: ```bash +# Download a single example wget https://raw.githubusercontent.com/ROCm/flashinfer/amd-integration/examples/single_prefill_example.py -wget https://raw.githubusercontent.com/ROCm/flashinfer/amd-integration/examples/batch_prefill_example.py -wget https://raw.githubusercontent.com/ROCm/flashinfer/amd-integration/examples/batch_decode_example.py python single_prefill_example.py + +# Download all examples +for example in single_prefill_example.py batch_prefill_example.py batch_decode_example.py; do + wget https://raw.githubusercontent.com/ROCm/flashinfer/amd-integration/examples/$example +done ``` -An end-to-end recommendation-system notebook that exercises the full public API is also available at [`examples/recommendation_system_flashinfer_rocm.ipynb`](examples/recommendation_system_flashinfer_rocm.ipynb). 
+**Available examples:** + +* `single_prefill_example.py` — single-sequence prefill attention +* `batch_prefill_example.py` — batched prefill attention +* `batch_decode_example.py` — batched decode attention +* `examples/amd_flashinfer_rocm_tutorial.ipynb` — Jupyter tutorial: + environment verification (`hip_utils`), AITER-backed prefill examples, + and `logits_processor` on ROCm +* `examples/run_jupyter_server.sh` — start JupyterLab from the repo root + (run inside your ROCm/FlashInfer environment or Docker container) ## Build from Source -### Development Environment +### Setting up a Development Environment + +Build the development Docker image with the repository's Dockerfile: ```bash docker build \ @@ -125,60 +188,224 @@ docker build \ --build-arg USER_GID=$(id -g) \ -t flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1 \ -f .devcontainer/rocm/Dockerfile . +``` + + +
+Build argument descriptions + +* `ROCM_VERSION`: ROCm version (default: 7.2) +* `PY_VERSION`: Python version (default: 3.12) +* `TORCH_VERSION`: PyTorch version (default: 2.9.1) +* `USERNAME`: Username inside container (default: devuser) +* `USER_UID`: User ID for matching host permissions +* `USER_GID`: Group ID for matching host permissions +
+ + +**Run the development container:** + +```bash docker run -it \ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ --ipc=host --privileged --shm-size=128G --network=host \ - --device=/dev/kfd --device=/dev/dri --group-add video --group-add render \ - -v $PWD:/workspace --name flashinfer-dev-container \ + --device=/dev/kfd --device=/dev/dri \ + --group-add video --group-add render \ + -v $PWD:/workspace \ + --name flashinfer-dev-container \ flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1 ``` -### Building and Installing a Wheel + +
+Docker run argument descriptions -```bash -# Editable install (JIT kernels compile on first use) -python -m pip install --no-build-isolation -ve . +* `--cap-add=SYS_PTRACE`: Enables debugging +* `--security-opt seccomp=unconfined`: Relaxes security for development +* `--ipc=host`: Shares host IPC for better performance +* `--privileged`: Required for GPU access +* `--shm-size=128G`: Shared memory size (adjust as needed) +* `--network=host`: Uses host networking +* `--device=/dev/kfd --device=/dev/dri`: Exposes AMD GPU devices +* `--group-add video --group-add render`: GPU access groups +* `-v :`: Mounts source code -# Wheel build +
+ + +**Note:** Environment name varies based on Python, PyTorch, and ROCm +versions. + +### Building and Installing a Wheel Package + +**Build with JIT (Just-in-Time) compilation only:** + +```bash python -m pip wheel . --wheel-dir=./dist/ --no-deps --no-build-isolation -v -pip install dist/amd_flashinfer-*.whl +cd dist && pip install amd_flashinfer-*.whl ``` +**Editable install for development:** + +```bash +python -m pip install --no-build-isolation -ve . +``` + +**Note:** The `--no-deps` flag assumes dependencies are pre-installed. +Omit it to download dependencies during build. AOT builds take longer +and use more disk space but avoid JIT compilation at runtime. + ### Running Tests +The Python tests suite can be run with pytest: + ```bash -pytest # curated set from pyproject.toml -pytest tests/rocm_tests # AMD-authored HIP tests -pytest -n auto # across all visible GPUs +# Run default tests (configured in pyproject.toml) +pytest + +# Run specific test file +pytest tests/rocm_tests/test_batch_decode_kernels_hip.py + +# Run with pattern matching pytest -k "test_batch_decode_kernels_hip" + +# Verbose output +pytest -v + +# Run tests in parallel across multiple GPUs +pytest -n auto # Uses all available GPUs +pytest -n 2 # Use only two GPUs +``` + +The default test configuration is specified in [pyproject.toml](pyproject.toml) +under the `testpaths` setting. + +#### Recommended invocation on AMD CPX systems + +`pytest-rerunfailures` (declared in the `dev` extra — `pip install -e ".[dev]"`) +absorbs the residual transient HIP runtime crashes. 
Then for the full suite: + +```bash +# Fast path — skips heavy 1M-trial sampling-frequency tests and 4 GB +# speculative-sampling cases (~7 min on a CPX 8-card host): +pytest -n auto --reruns 2 -m "not slow" + +# Full coverage — including the slow tests (~20 min): +pytest -n auto --reruns 2 + +# Slow path only (~13 min): +pytest -n auto --reruns 2 -m "slow" ``` -The default test set is pinned in `[tool.pytest.ini_options]` in [pyproject.toml](pyproject.toml). +**Notes** + +* `pytest -n auto` for the `tests/rocm_tests/` suite spawns **half as many xdist workers as physical AMD cards** (e.g. 4 workers on a CPX-mode 8-card MI308X / MI325X host). One worker per physical card was tried first but produced sporadic failures across rope, single_prefill, and logits_cap under residual concurrent load; halving the count produces reliable green runs. Each worker is pinned to its card via `HIP_VISIBLE_DEVICES`. On non-CPX systems the helper applies the same halving; users who want every device used can pass an explicit `-n N`. +* `--reruns 2` (from `pytest-rerunfailures`) absorbs the residual ~0.01 % of transient HIP runtime crashes (HSA exceptions, HIPBLAS handle-pool exhaustion, intermittent generator non-determinism) that worker pinning cannot fully eliminate. Successful tests are not duplicated; only failed tests are retried. +* The `slow` marker is registered in [pyproject.toml](pyproject.toml). It tags the 1M-trial sampling-frequency tests, the 4 GB-tensor speculative-sampling cases, and the entire `TestLogitsPipeCompilationHIP` class (every test there runs the sampling kernel twice per case for compile=True/False). +* The reference attention helper in `tests/attention_reference.py` wraps `torch.matmul` in a `_hipblas_safe_matmul` retry helper that catches `HIPBLAS_STATUS_ALLOC_FAILED` and retries with a short back-off — needed under heavy concurrent xdist load. 
-## Prefill Backends +## AITER Support -All prefill entry points accept a `backend=` keyword (default `"auto"`, which resolves to `"fa2"`). +FlashInfer+ROCm supports the use of [AITER](https://github.com/ROCm/aiter) as a +backend. The `aiter` backend is enabled for the `single_prefill`, +`batch_prefill` (paged and ragged), `batch_decode`, `append_paged_kv_cache`, +`rmsnorm`, and `MLA` paths. MLA on ROCm is **only** available via AITER — +there is no in-tree HIP MLA kernel yet. -- **`fa2` (default)** — In-tree HIP port of FlashAttention-2. Broadest coverage: paged + ragged, single + batch, fp16/bf16. -- **`aiter`** — Wraps [AITER](https://github.com/ROCm/aiter) CK FMHA (`flash_attn_varlen_func`, `mha_batch_prefill_func`). NHD layout only; paged batch-prefill `page_size ∈ {16, 1024}` (or `{128, 256, 1024}` on `amd-aiter==0.1.10`). -- **`fa3_cdna3`** — MI300X-optimized single-prefill kernel for chunked prefill, `head_dim=256`, `q_len != kv_len`. Experimental; see [`benchmarks/rocm_benchmarks/bench_fa3_cdna3.py`](benchmarks/rocm_benchmarks/bench_fa3_cdna3.py) and [`examples/single_prefill_example.py`](examples/single_prefill_example.py). +**On gfx942/gfx950 GPUs, `backend="auto"` (the default) automatically selects the AITER backend** +when the call parameters are compatible (fp16/bf16, NHD layout, no custom mask, equal Q/K/V +dtypes and head dims, `pos_encoding_mode="NONE"`). It falls back to `fa2` with a one-time +`logger.warning` when any condition is not met. You can also pass `backend="aiter"` explicitly. -Install AITER if you plan to use `backend="aiter"` outside the prebuilt Docker image: +Unless you are using the prebuilt docker image, AITER must also be installed on your system. You may follow one of the following ways to do so. 
+ +### Install AITER from source + +```bash +git clone --recursive https://github.com/ROCm/aiter.git +cd aiter +python3 setup.py develop +``` + +### Install AITER wheel package + +Wheel packages are available from AMD's PyPI index: [pypi.amd.com/simple](https://pypi.amd.com/simple/). ```bash pip install amd-aiter --index-url https://pypi.amd.com/simple/ -# or: git clone --recursive https://github.com/ROCm/aiter.git && cd aiter && python3 setup.py develop ``` -Example: +### Known Limitations + +The AITER backend has the following constraints. With `backend="aiter"` the +call will error on the first group of conditions, or for the second group, +run but silently ignore the unsupported argument. + +**Conditions that fall back to `fa2` under `backend="auto"`:** + +* GPU is not gfx942 or gfx950 +* `kv_layout` is not `NHD` +* a custom attention mask tensor is supplied +* `q_dtype` is not `float16` / `bfloat16` (no fp32, fp8, or int8) +* `q_dtype != kv_dtype` (mixed-precision Q/KV is unsupported) +* `head_dim_qk != head_dim_vo` (e.g. DeepSeek-style MLA with 192/128 head dims) +* the `aiter` Python package is not importable + +**Features silently ignored on the AITER path** (the kwargs are accepted by +the FlashInfer wrapper but not forwarded to AITER, which can produce wrong +results — pass `backend="fa2"` explicitly if you need any of these): + +* ALiBi slopes (`maybe_alibi_slopes`) +* in-kernel positional encoding modes (`pos_encoding_mode`, `rope_scale`, + `rope_theta`) +* attention sinks (`sinks`) +* multi-modal / prefix-cache helpers (`maybe_prefix_len_ptr`, + `maybe_token_pos_in_items_ptr`, `maybe_max_item_len_ptr`) +* FP8 dequant scales (`scale_q` / `scale_k` / `scale_v`) +* `use_fp16_qk_reduction`, `enable_pdl` + +**Other notes:** + +* Batch prefill: AITER's CK FMHA kernels natively support page sizes + `{16, 1024}` (or `{128, 256, 1024}` on `amd-aiter==0.1.10`). 
Other page + sizes still work but go through an extra GPU gather to flatten paged KV + before the AITER call. +* Ragged (non-paged) batch prefill via AITER is supported through + `BatchPrefillWithRaggedKVCacheWrapper`. The wrapper auto-routes to + AITER under `backend="auto"` when the standard AITER compatibility + conditions are met and falls back to `fa2` otherwise. +* MLA on ROCm currently supports only `bfloat16` and `page_size=1` + through the AITER backend. + +### Single Prefill Example + +This section provides an example on how to use Single Prefill with AITER. ```python -o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, backend="aiter") -``` +import torch +import flashinfer -## Contributing and License +# Configuration +seq_len = 1024 # Prompt length +num_qo_heads = 32 # Number of query/output heads +num_kv_heads = 8 # Number of KV heads (GQA with 4:1 ratio) +head_dim = 128 + +# Create Q, K, V tensors (NHD layout: sequence, heads, dimension) +q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda") +k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda") +v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda") + +# Run single prefill attention with causal masking. +# On gfx942/gfx950, backend="auto" (default) routes to AITER automatically. +# Pass backend="aiter" to require AITER explicitly, or backend="fa2" to skip it. +output = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, backend="auto") +``` -See [CONTRIBUTING.md](CONTRIBUTING.md). Run `pre-commit run -a` and `pytest` before opening a PR. +## License and Acknowledgements -Upstream project: [flashinfer-ai/flashinfer](https://github.com/flashinfer-ai/flashinfer). Released under the Apache-2.0 License — [LICENSE](LICENSE), [NOTICE](NOTICE). +FlashInfer+ROCm is released under the Apache-2.0 License — see +[LICENSE](LICENSE) and [NOTICE](NOTICE). 
Upstream project: +[flashinfer-ai/flashinfer](https://github.com/flashinfer-ai/flashinfer). +Run `pre-commit run -a` and `pytest` before opening a PR. From 1256e66613a71d66dca2ca26aa850d9854e33220 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 12:52:50 +0000 Subject: [PATCH 03/20] docs: clarify HIP vs AITER backends and auto-routing in feature matrix Split the single "AITER backend" column into HIP and AITER columns plus a new `backend="auto"` column that spells out the exact conditions under which auto-routing selects AITER vs. HIP for each kernel. MLA is flagged as AITER-only (no HIP fallback); RMSNorm auto stays on HIP even though AITER is available (opt-in only). Co-Authored-By: Claude Sonnet 4.6 --- README.md | 45 ++++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 2adf5277c1..5f040ff5cf 100644 --- a/README.md +++ b/README.md @@ -57,23 +57,38 @@ Jupyter tutorial that walks through the full public API on ROCm. 
## Feature Support Matrix -| Kernel Type | FP16 / BF16 | FP8 (E4M3, E5M2) | Has AITER backend | Notes | -| :--- | :---: | :---: | :---: | :--- | -| **Single / Batch Decode Attention** | ✅ | ✅ (E4M3FNUZ KV-cache) | ✅ (batch paged, fp16/bf16) | MHA, GQA, MQA; sliding-window on the AITER path; CUDA-graph support | -| **Single / Batch Prefill Attention** | ✅ | WIP | ✅ (single, batch-paged, batch-ragged) | MHA, GQA, MQA | -| **Cascade Attention** | ✅ | — | No | Two-level shared-prefix attention | -| **MLA (Multi-Latent Attention)** | ✅ (bf16, `page_size=1`) | — | ✅ (AITER-only path) | DeepSeek-style 192/128 head-dim split; **requires AITER** on ROCm | -| **POD Attention** | TBD | TBD | No | Code present; **not yet validated on ROCm** | -| **RoPE (Positional Encoding)** | ✅ | — | No | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + paged-KV append | -| **Paged KV-Cache Append** | ✅ | ✅ | ✅ (opt-in) | `append_paged_kv_cache` | -| **Sampling** | ✅ | — | No | Top-K / Top-P / Min-P / OnlineSoftmax / SamplingFromLogits | -| **Logits Processor** | ✅ | — | No | Composable processor pipeline (cap, mask, temperature, …) | -| **Normalization** | ✅ | — | ✅ (RMSNorm only) | RMSNorm, LayerNorm, Gemma RMSNorm | -| **Activation** | ✅ | — | No | SiLU / GELU with fused gating | -| **Quantization** | ✅ | — | No | `packbits`, `segment_packbits` | -| **`torch.compile`** | ✅ (opt-in) | — | n/a | Enabled via the `FLASHINFER_ENABLE_TORCH_COMPILE` env flag | +Most kernels ship with an in-tree HIP implementation. A subset also has +an [AITER](https://github.com/ROCm/aiter) backend; for those, the +`backend="auto"` default picks AITER when its compatibility conditions +hold and transparently falls back to HIP otherwise. AITER-only kernels +(currently MLA) require an explicit `backend="aiter"`. + +Legend: **HIP** = in-tree FlashInfer+ROCm kernel (the historical `fa2` +HIP port, or the `native` JIT kernel for non-attention ops). **AITER** = +ROCm AITER backend. 
+ +| Kernel | HIP | AITER | `backend="auto"` resolves to | Notes | +| :--- | :---: | :---: | :--- | :--- | +| **Single decode attention** | ✅ `fa2` | — | HIP | MHA / GQA / MQA; fp16, bf16, fp8 (E4M3FNUZ KV-cache) | +| **Batch decode attention (paged)** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no CUDA-graph + `use_tensor_cores=False`; else **HIP** | Sliding-window supported on the AITER path; CUDA-graph auto-routes back to HIP | +| **Single prefill attention** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no custom mask + equal Q/KV dtypes & head dims + `pos_encoding_mode="NONE"`; else **HIP** | MHA / GQA / MQA; fp8 prefill WIP | +| **Batch prefill attention (paged + ragged)** | ✅ `fa2` | ✅ | Same auto criteria as single prefill | AITER native page sizes: `{16, 1024}` (`{128, 256, 1024}` on `amd-aiter==0.1.10`); other sizes go through a gather on the AITER path | +| **Cascade attention** | ✅ | — | HIP | Two-level shared-prefix attention | +| **MLA (Multi-Latent Attention)** | — | ✅ | **AITER only** (no HIP fallback) | DeepSeek-style 192/128 head-dim split; bf16 + `page_size=1`; must pass `backend="aiter"` explicitly | +| **POD attention** | TBD | — | n/a | Code present; **not yet validated on ROCm** | +| **RoPE (positional encoding)** | ✅ | — | HIP | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + paged-KV append | +| **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + AITER importable; else **HIP `native`** | `append_paged_kv_cache` | +| **RMSNorm** | ✅ `native` | ✅ | **HIP `native`** (auto stays on HIP — AITER is opt-in via `backend="aiter"`) | AITER path is fp16/bf16, 2-D only; slightly lower precision at `hidden_size >= 1024` | +| **LayerNorm / Gemma RMSNorm** | ✅ | — | HIP | | +| **Sampling** | ✅ | — | HIP | Top-K / Top-P / Min-P / OnlineSoftmax / SamplingFromLogits | +| **Logits processor** | ✅ | — | HIP | Composable processor pipeline (cap, mask, temperature, …) | +| **Activation** | ✅ | — | HIP | SiLU / GELU 
with fused gating | +| **Quantization** | ✅ | — | HIP | `packbits`, `segment_packbits` | +| **`torch.compile`** | ✅ (opt-in flag) | n/a | n/a | Enabled via the `FLASHINFER_ENABLE_TORCH_COMPILE` env flag | Every ✅ row above is exercised by a matching `tests/rocm_tests/test_*_hip.py`. +The full set of conditions that cause AITER auto-routing to fall back to +HIP is documented in [Known Limitations](#known-limitations) below. ## GPU, ROCm, and PyTorch Support From efb7ee12768483fef52d748c39081820afbf42d6 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 12:54:10 +0000 Subject: [PATCH 04/20] docs: fold fp8 status into per-row notes in feature matrix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous matrix listed dtypes inconsistently — single-decode named fp16/bf16/fp8 explicitly while sibling rows didn't. Drop the implicit fp16/bf16 enumeration (already covered by the ✅ HIP marker) and call out fp8 only where it's actually supported: batch decode KV-cache (E4M3FNUZ), RoPE fused quant+append (E4M3FNUZ + E5M2FNUZ), paged KV-cache append HIP path. Prefill rows mark fp8 as WIP. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 5f040ff5cf..b39c04e3db 100644 --- a/README.md +++ b/README.md @@ -69,15 +69,15 @@ ROCm AITER backend. 
| Kernel | HIP | AITER | `backend="auto"` resolves to | Notes | | :--- | :---: | :---: | :--- | :--- | -| **Single decode attention** | ✅ `fa2` | — | HIP | MHA / GQA / MQA; fp16, bf16, fp8 (E4M3FNUZ KV-cache) | -| **Batch decode attention (paged)** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no CUDA-graph + `use_tensor_cores=False`; else **HIP** | Sliding-window supported on the AITER path; CUDA-graph auto-routes back to HIP | -| **Single prefill attention** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no custom mask + equal Q/KV dtypes & head dims + `pos_encoding_mode="NONE"`; else **HIP** | MHA / GQA / MQA; fp8 prefill WIP | -| **Batch prefill attention (paged + ragged)** | ✅ `fa2` | ✅ | Same auto criteria as single prefill | AITER native page sizes: `{16, 1024}` (`{128, 256, 1024}` on `amd-aiter==0.1.10`); other sizes go through a gather on the AITER path | +| **Single decode attention** | ✅ `fa2` | — | HIP | MHA / GQA / MQA | +| **Batch decode attention (paged)** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no CUDA-graph + `use_tensor_cores=False`; else **HIP** | MHA / GQA / MQA; **fp8 KV-cache (E4M3FNUZ)** on the HIP path; sliding-window on the AITER path; CUDA-graph auto-routes back to HIP | +| **Single prefill attention** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no custom mask + equal Q/KV dtypes & head dims + `pos_encoding_mode="NONE"`; else **HIP** | MHA / GQA / MQA; fp8 WIP | +| **Batch prefill attention (paged + ragged)** | ✅ `fa2` | ✅ | Same auto criteria as single prefill | MHA / GQA / MQA; fp8 WIP. 
AITER native page sizes: `{16, 1024}` (`{128, 256, 1024}` on `amd-aiter==0.1.10`); other sizes go through a gather on the AITER path | | **Cascade attention** | ✅ | — | HIP | Two-level shared-prefix attention | | **MLA (Multi-Latent Attention)** | — | ✅ | **AITER only** (no HIP fallback) | DeepSeek-style 192/128 head-dim split; bf16 + `page_size=1`; must pass `backend="aiter"` explicitly | | **POD attention** | TBD | — | n/a | Code present; **not yet validated on ROCm** | -| **RoPE (positional encoding)** | ✅ | — | HIP | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + paged-KV append | -| **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + AITER importable; else **HIP `native`** | `append_paged_kv_cache` | +| **RoPE (positional encoding)** | ✅ | — | HIP | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ) | +| **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + AITER importable; else **HIP `native`** | `append_paged_kv_cache`; fp8 KV-cache supported on the HIP path | | **RMSNorm** | ✅ `native` | ✅ | **HIP `native`** (auto stays on HIP — AITER is opt-in via `backend="aiter"`) | AITER path is fp16/bf16, 2-D only; slightly lower precision at `hidden_size >= 1024` | | **LayerNorm / Gemma RMSNorm** | ✅ | — | HIP | | | **Sampling** | ✅ | — | HIP | Top-K / Top-P / Min-P / OnlineSoftmax / SamplingFromLogits | From 4650cf5f07a6a012a62b5f35cca430e518381791 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 13:47:31 +0000 Subject: [PATCH 05/20] docs: fix torch.compile env var name and document runtime env vars / helpers The matrix referenced a nonexistent FLASHINFER_ENABLE_TORCH_COMPILE; the actual gate is FLASHINFER_USE_TORCH_CUSTOM_OPS=1 (must be set before importing flashinfer, requires PyTorch >= 2.4). 
While here, add an Environment Variables section covering the runtime knobs that aren't already in CLAUDE.md (FLASHINFER_HIP_FUSED_CASCADE, FLASHINFER_LOGGING_LEVEL, FLASHINFER_DISABLE_JIT, ROCM_PATH/ROCM_HOME) and a Runtime Helpers section pointing at is_aiter_supported and check_torch_rocm_compatibility. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 42 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b39c04e3db..f0bcf16bc5 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,8 @@ each FlashInfer+ROCm release to its corresponding upstream tag (e.g. * [Install AITER wheel package](#install-aiter-wheel-package) * [Known Limitations](#known-limitations) * [Single Prefill Example](#single-prefill-example) +* [Environment Variables](#environment-variables) +* [Runtime Helpers](#runtime-helpers) * [License and Acknowledgements](#license-and-acknowledgements) ## Basic Usage @@ -73,7 +75,7 @@ ROCm AITER backend. | **Batch decode attention (paged)** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no CUDA-graph + `use_tensor_cores=False`; else **HIP** | MHA / GQA / MQA; **fp8 KV-cache (E4M3FNUZ)** on the HIP path; sliding-window on the AITER path; CUDA-graph auto-routes back to HIP | | **Single prefill attention** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no custom mask + equal Q/KV dtypes & head dims + `pos_encoding_mode="NONE"`; else **HIP** | MHA / GQA / MQA; fp8 WIP | | **Batch prefill attention (paged + ragged)** | ✅ `fa2` | ✅ | Same auto criteria as single prefill | MHA / GQA / MQA; fp8 WIP. 
AITER native page sizes: `{16, 1024}` (`{128, 256, 1024}` on `amd-aiter==0.1.10`); other sizes go through a gather on the AITER path | -| **Cascade attention** | ✅ | — | HIP | Two-level shared-prefix attention | +| **Cascade attention** | ✅ | — | HIP | Two-level shared-prefix attention; a fused single-kernel HIP variant is gated behind `FLASHINFER_HIP_FUSED_CASCADE=1` | | **MLA (Multi-Latent Attention)** | — | ✅ | **AITER only** (no HIP fallback) | DeepSeek-style 192/128 head-dim split; bf16 + `page_size=1`; must pass `backend="aiter"` explicitly | | **POD attention** | TBD | — | n/a | Code present; **not yet validated on ROCm** | | **RoPE (positional encoding)** | ✅ | — | HIP | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ) | @@ -84,7 +86,7 @@ ROCm AITER backend. | **Logits processor** | ✅ | — | HIP | Composable processor pipeline (cap, mask, temperature, …) | | **Activation** | ✅ | — | HIP | SiLU / GELU with fused gating | | **Quantization** | ✅ | — | HIP | `packbits`, `segment_packbits` | -| **`torch.compile`** | ✅ (opt-in flag) | n/a | n/a | Enabled via the `FLASHINFER_ENABLE_TORCH_COMPILE` env flag | +| **`torch.compile`** | ✅ (opt-in) | n/a | n/a | Set `FLASHINFER_USE_TORCH_CUSTOM_OPS=1` **before** importing `flashinfer`; requires PyTorch ≥ 2.4. Without it, `torch.compile` raises a clear error if it traces into a flashinfer op | Every ✅ row above is exercised by a matching `tests/rocm_tests/test_*_hip.py`. The full set of conditions that cause AITER auto-routing to fall back to @@ -418,6 +420,42 @@ v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cu output = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, backend="auto") ``` +## Environment Variables + +FlashInfer+ROCm reads the following environment variables at runtime +or import time. 
Build-time variables (`FLASHINFER_ROCM_ARCH_LIST`, +`FLASHINFER_JIT_VERBOSE`, `FLASHINFER_JIT_DEBUG`, `MAX_JOBS`, …) are +documented in [CLAUDE.md](CLAUDE.md). + +| Variable | Default | Purpose | +| :--- | :--- | :--- | +| `FLASHINFER_USE_TORCH_CUSTOM_OPS` | `0` | Set to `1` **before** importing `flashinfer` to wrap kernels in `torch.library.custom_op` so `torch.compile` / Dynamo can trace them. Requires PyTorch ≥ 2.4. Adds a small per-call dispatch overhead. | +| `FLASHINFER_HIP_FUSED_CASCADE` | `0` | Set to `1` to use a fused single-kernel HIP cascade attention path instead of the default two-level merge-based path. Experimental on ROCm. | +| `FLASHINFER_LOGGING_LEVEL` | `INFO` | Logger verbosity (e.g. `DEBUG`, `INFO`, `WARNING`). Affects AITER auto-fallback warnings and JIT build messages. | +| `FLASHINFER_DISABLE_JIT` | unset | Set to any non-empty value to skip JIT compilation. Useful when running an AOT-built wheel and you want to fail loudly on missing kernels rather than trigger a build. | +| `ROCM_PATH` / `ROCM_HOME` | `/opt/rocm` | Used by `flashinfer.hip_utils` to locate the ROCm install. Override only for non-standard ROCm layouts. | + +## Runtime Helpers + +`flashinfer` ships a few ROCm-specific helpers that are useful when +guarding code paths or diagnosing setup issues: + +```python +from flashinfer.aiter_utils import is_aiter_supported +from flashinfer.hip_utils import ( + check_torch_rocm_compatibility, + validate_flashinfer_rocm_arch, +) + +# Returns True only on gfx942/gfx950 with the aiter package importable. +if is_aiter_supported(torch.device("cuda")): + ... + +# Raises a clear error if PyTorch + ROCm versions are incompatible +# (e.g. a CPU-only torch wheel was picked up from PyPI). 
+check_torch_rocm_compatibility() +``` + ## License and Acknowledgements FlashInfer+ROCm is released under the Apache-2.0 License — see From 9f650bf6a8c9f330d1cfaed85c9df66340f99478 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 13:50:19 +0000 Subject: [PATCH 06/20] docs: collapse docker image table to the latest tag Keep only the current validated rocm/flashinfer image and point readers at hub.docker.com/r/rocm/flashinfer/tags for older ROCm/PyTorch combos. The full table goes stale on every release. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index f0bcf16bc5..7da330fc80 100644 --- a/README.md +++ b/README.md @@ -116,15 +116,16 @@ available wheels. ### Option 1: Get a Pre-built Docker Image -AMD validates and publishes [FlashInfer images](https://hub.docker.com/r/rocm/flashinfer/tags) -with ROCm backends on Docker Hub. The following Docker image tags -represent the latest available FlashInfer+ROCm releases: +AMD validates and publishes FlashInfer images with ROCm backends on +Docker Hub. The latest validated tag is: | Docker image | ROCm | FlashInfer | PyTorch | Ubuntu | Python | GPU | | ------------ | ---- | ---------- | ------- | ------ | ------ | --- | | `rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1` | 7.2.0 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355X, MI325X, MI300X | -| `rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.0.2_ubuntu24.04_py3.12_pytorch2.9.1` | 7.0.2 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355X, MI325X, MI300X | -| `rocm/flashinfer:flashinfer-0.2.5.amd2_rocm7.1.1_ubuntu24.04_py3.12_pytorch2.8` | 7.1.1 | v0.2.5 | 2.8.0 | 24.04 | 3.12 | MI325X, MI300X | + +For older releases (earlier ROCm / PyTorch / FlashInfer combinations), +see the full tag list at +. 
**Start a container:** From 7c122a677c073ad2bd9d7fa38a064b59b3a00381 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 13:51:15 +0000 Subject: [PATCH 07/20] docs: drop manual micromamba activate from docker verify step The base environment is activated on shell start inside the rocm/flashinfer images, so the explicit `micromamba activate base` call was misleading. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 7da330fc80..a7daf0eef3 100644 --- a/README.md +++ b/README.md @@ -135,17 +135,15 @@ docker run -it --privileged --network=host --device=/dev/kfd --device=/dev/dri \ --ipc=host --shm-size 128G --name= ``` -**Activate the environment and verify:** +**Verify the installation:** ```bash -# Activate the micromamba environment (env name may vary based on the image) -micromamba activate base - -# Verify installation python -c "import flashinfer; print(flashinfer.__version__)" ``` Expected output: `0.5.3+amd.1` (with a possible JIT backend message). +The container's micromamba environment is activated automatically on +shell start — no manual `micromamba activate` is required. ### Option 2: Install from a Wheel Package From 012c053dc63f4e1c0132c4d15f05acd3057a3fe8 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 13:51:51 +0000 Subject: [PATCH 08/20] docs: use concrete image tag and container name in docker run Replace / placeholders with the flashinfer-rocm container name and the actual latest image tag so the snippet is copy-pasteable. 
Co-Authored-By: Claude Sonnet 4.6 --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a7daf0eef3..45c7f0ec55 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,8 @@ see the full tag list at ```bash docker run -it --privileged --network=host --device=/dev/kfd --device=/dev/dri \ --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ - --ipc=host --shm-size 128G --name= + --ipc=host --shm-size 128G --name=flashinfer-rocm \ + rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1 ``` **Verify the installation:** From 4cdf823506d466d10e173c4c00572a66b22859b4 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 13:59:33 +0000 Subject: [PATCH 09/20] docs(readme): simplify "Trying the Examples" to point at examples/ Replace the wget-based download steps with a brief pointer to the examples/ directory and a single run command. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 45c7f0ec55..53aace3e1e 100644 --- a/README.md +++ b/README.md @@ -165,30 +165,15 @@ to a CPU-only PyPI wheel. ### Trying the Examples -Download and run example scripts from the repository: +Runnable scripts live in the [`examples/`](examples/) directory of this +repository (single/batch prefill, batch decode, plus an +`amd_flashinfer_rocm_tutorial.ipynb` Jupyter notebook). 
After cloning, +run any of them directly, for example: ```bash -# Download a single example -wget https://raw.githubusercontent.com/ROCm/flashinfer/amd-integration/examples/single_prefill_example.py -python single_prefill_example.py - -# Download all examples -for example in single_prefill_example.py batch_prefill_example.py batch_decode_example.py; do - wget https://raw.githubusercontent.com/ROCm/flashinfer/amd-integration/examples/$example -done +python examples/single_prefill_example.py ``` -**Available examples:** - -* `single_prefill_example.py` — single-sequence prefill attention -* `batch_prefill_example.py` — batched prefill attention -* `batch_decode_example.py` — batched decode attention -* `examples/amd_flashinfer_rocm_tutorial.ipynb` — Jupyter tutorial: - environment verification (`hip_utils`), AITER-backed prefill examples, - and `logits_processor` on ROCm -* `examples/run_jupyter_server.sh` — start JupyterLab from the repo root - (run inside your ROCm/FlashInfer environment or Docker container) - ## Build from Source ### Setting up a Development Environment From 7da63b9630b2f53cc0f2b93d0761866d44370e35 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 14:00:52 +0000 Subject: [PATCH 10/20] docs(readme): drop redundant Single Prefill Example from AITER section The Basic Usage snippet at the top of the README already shows the same call pattern; the AITER-section duplicate added no extra information. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/README.md b/README.md index 53aace3e1e..51096a7b44 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,6 @@ each FlashInfer+ROCm release to its corresponding upstream tag (e.g. 
* [Install AITER from source](#install-aiter-from-source) * [Install AITER wheel package](#install-aiter-wheel-package) * [Known Limitations](#known-limitations) - * [Single Prefill Example](#single-prefill-example) * [Environment Variables](#environment-variables) * [Runtime Helpers](#runtime-helpers) * [License and Acknowledgements](#license-and-acknowledgements) @@ -380,31 +379,6 @@ results — pass `backend="fa2"` explicitly if you need any of these): * MLA on ROCm currently supports only `bfloat16` and `page_size=1` through the AITER backend. -### Single Prefill Example - -This section provides an example on how to use Single Prefill with AITER. - -```python -import torch -import flashinfer - -# Configuration -seq_len = 1024 # Prompt length -num_qo_heads = 32 # Number of query/output heads -num_kv_heads = 8 # Number of KV heads (GQA with 4:1 ratio) -head_dim = 128 - -# Create Q, K, V tensors (NHD layout: sequence, heads, dimension) -q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda") -k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda") -v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda") - -# Run single prefill attention with causal masking. -# On gfx942/gfx950, backend="auto" (default) routes to AITER automatically. -# Pass backend="aiter" to require AITER explicitly, or backend="fa2" to skip it. 
-output = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, backend="auto") -``` - ## Environment Variables FlashInfer+ROCm reads the following environment variables at runtime From 7b3bcdd881df47aafd952c925e74896658266c2b Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 14:02:32 +0000 Subject: [PATCH 11/20] docs(readme): rename "Build from Source" to "Install from Source" Co-Authored-By: Claude Sonnet 4.6 --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 51096a7b44..0521e48ef8 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ each FlashInfer+ROCm release to its corresponding upstream tag (e.g. * [Option 1: Get a Pre-built Docker Image](#option-1-get-a-pre-built-docker-image) * [Option 2: Install from a Wheel Package](#option-2-install-from-a-wheel-package) * [Trying the Examples](#trying-the-examples) -* [Build from Source](#build-from-source) +* [Install from Source](#install-from-source) * [Setting up a Development Environment](#setting-up-a-development-environment) * [Building and Installing a Wheel Package](#building-and-installing-a-wheel-package) * [Running Tests](#running-tests) @@ -173,7 +173,7 @@ run any of them directly, for example: python examples/single_prefill_example.py ``` -## Build from Source +## Install from Source ### Setting up a Development Environment From 9f13ae182479e85e09a15bb5e26b119c9154f0b2 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 15:25:49 +0000 Subject: [PATCH 12/20] docs(readme): link CDNA3 / CDNA4 to their architecture references Hyperlink the first mention of CDNA3 to the ROCm MI300 microarchitecture docs and CDNA4 to AMD's MI350 product page. 
Co-Authored-By: Claude Sonnet 4.6 --- README.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 0521e48ef8..f47c909c4a 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,13 @@ # FlashInfer+ROCm: An AMD ROCm port of FlashInfer FlashInfer+ROCm is an AMD ROCm port of the -[FlashInfer](https://github.com/flashinfer-ai/flashinfer) attention, -RoPE, normalization, sampling, and logits-processor kernels for LLM -inference on AMD Instinct GPUs. The port targets CDNA3 (gfx942 — -MI300X / MI325X) and CDNA4 (gfx950 — MI355X), and is aimed at developers -embedding FlashInfer kernels into their own training or serving stack. +[FlashInfer](https://github.com/flashinfer-ai/flashinfer) library for LLM +inference on AMD Instinct GPUs. The port targets +[CDNA3](https://rocm.docs.amd.com/en/latest/conceptual/gpu-arch/mi300.html) +(gfx942 — MI300X / MI325X) and +[CDNA4](https://www.amd.com/en/products/accelerators/instinct/mi350.html) +(gfx950 — MI355X), and is aimed at developers embedding FlashInfer +kernels into their own training or serving stack. The project is in active development with the primary focus on attention (single and batch prefill / decode) and the surrounding KV-cache, RoPE, From 86448ff30d488e9aceeb7897140a53a2d7a48a45 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 15:27:02 +0000 Subject: [PATCH 13/20] docs(readme): point CDNA3 / CDNA4 links to official whitepapers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the GPU product / ROCm doc links with the AMD CDNA3 and CDNA4 architecture whitepapers — the right reference for the architectures themselves rather than the cards. 
Co-Authored-By: Claude Sonnet 4.6 --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f47c909c4a..c12f462c0f 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,9 @@ FlashInfer+ROCm is an AMD ROCm port of the [FlashInfer](https://github.com/flashinfer-ai/flashinfer) library for LLM inference on AMD Instinct GPUs. The port targets -[CDNA3](https://rocm.docs.amd.com/en/latest/conceptual/gpu-arch/mi300.html) +[CDNA3](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/white-papers/amd-cdna-3-white-paper.pdf) (gfx942 — MI300X / MI325X) and -[CDNA4](https://www.amd.com/en/products/accelerators/instinct/mi350.html) +[CDNA4](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/white-papers/amd-cdna-4-architecture-whitepaper.pdf) (gfx950 — MI355X), and is aimed at developers embedding FlashInfer kernels into their own training or serving stack. From 57c8cc8d2e2749ab7ee26917c796966d7a7d0730 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 15:37:57 +0000 Subject: [PATCH 14/20] docs(readme): tighten intro and call out the HIP + AITER split - Drop the "AMD ROCm port" redundancy with the title and lead with what ships in-tree (the HIP kernel set) versus what dispatches to AITER. - Cross-link the Feature Support Matrix and AITER from the first paragraph so readers landing on the README see the structure immediately. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index c12f462c0f..30e9e2ebcd 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,20 @@ # FlashInfer+ROCm: An AMD ROCm port of FlashInfer -FlashInfer+ROCm is an AMD ROCm port of the -[FlashInfer](https://github.com/flashinfer-ai/flashinfer) library for LLM -inference on AMD Instinct GPUs. 
The port targets +FlashInfer+ROCm brings the +[FlashInfer](https://github.com/flashinfer-ai/flashinfer) inference +kernel library to AMD Instinct GPUs — currently [CDNA3](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/white-papers/amd-cdna-3-white-paper.pdf) (gfx942 — MI300X / MI325X) and [CDNA4](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/white-papers/amd-cdna-4-architecture-whitepaper.pdf) -(gfx950 — MI355X), and is aimed at developers embedding FlashInfer -kernels into their own training or serving stack. - -The project is in active development with the primary focus on attention -(single and batch prefill / decode) and the surrounding KV-cache, RoPE, -and normalization kernels. See [CHANGELOG.md](CHANGELOG.md) for the -full release history. +(gfx950 — MI355X). It ships in-tree HIP ports of the attention, +KV-cache, RoPE, normalization, sampling, and logits-processor kernels, +and transparently dispatches a subset of attention paths to AMD's +[AITER](https://github.com/ROCm/aiter) backend when its compatibility +conditions hold (see [Feature Support Matrix](#feature-support-matrix)). + +The port is in active development and is aimed at developers embedding +FlashInfer kernels into their own training or serving stack. See +[CHANGELOG.md](CHANGELOG.md) for the full release history. **Versioning:** The release tag format `+amd.` ties each FlashInfer+ROCm release to its corresponding upstream tag (e.g. From df1cb841600f4792253f57ebdf48d037f2843a8d Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 15:40:50 +0000 Subject: [PATCH 15/20] docs(readme): move Basic Usage to the end of the README Install / Feature Matrix / Build / AITER are what a new reader needs first; the code snippet reads better as a closing example. 
Co-Authored-By: Claude Sonnet 4.6 --- README.md | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 30e9e2ebcd..f1b19eda92 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,6 @@ each FlashInfer+ROCm release to its corresponding upstream tag (e.g. ## Table of Contents -* [Basic Usage](#basic-usage) * [Feature Support Matrix](#feature-support-matrix) * [GPU, ROCm, and PyTorch Support](#gpu-rocm-and-pytorch-support) * [Getting Started](#getting-started) @@ -39,27 +38,9 @@ each FlashInfer+ROCm release to its corresponding upstream tag (e.g. * [Known Limitations](#known-limitations) * [Environment Variables](#environment-variables) * [Runtime Helpers](#runtime-helpers) +* [Basic Usage](#basic-usage) * [License and Acknowledgements](#license-and-acknowledgements) -## Basic Usage - -```python -import torch -import flashinfer - -# PyTorch+ROCm still uses device="cuda" for AMD GPUs. -q = torch.randn(1024, 32, 128, dtype=torch.float16, device="cuda") -k = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda") # GQA 4:1 -v = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda") - -# backend="auto" (default) routes to AITER when supported on gfx942/gfx950 -# and falls back to the in-tree fa2 HIP kernel otherwise. -output = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) -``` - -See [`examples/`](examples/) for batch prefill, batch decode, and a -Jupyter tutorial that walks through the full public API on ROCm. - ## Feature Support Matrix Most kernels ship with an in-tree HIP implementation. A subset also has @@ -419,6 +400,25 @@ if is_aiter_supported(torch.device("cuda")): check_torch_rocm_compatibility() ``` +## Basic Usage + +```python +import torch +import flashinfer + +# PyTorch+ROCm still uses device="cuda" for AMD GPUs. 
+q = torch.randn(1024, 32, 128, dtype=torch.float16, device="cuda") +k = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda") # GQA 4:1 +v = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda") + +# backend="auto" (default) routes to AITER when supported on gfx942/gfx950 +# and falls back to the in-tree fa2 HIP kernel otherwise. +output = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) +``` + +See [`examples/`](examples/) for batch prefill, batch decode, and a +Jupyter tutorial that walks through the full public API on ROCm. + ## License and Acknowledgements FlashInfer+ROCm is released under the Apache-2.0 License — see From dfa7a0f784aec8a327d6839affeb2062aad7787b Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 15:44:08 +0000 Subject: [PATCH 16/20] docs(readme): proofread, dedupe, and clarify after fact-check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Tighten Feature Matrix preamble and Legend; drop the duplicate AITER link (already in the intro). - Collapse the AITER Support intro that overlapped with the matrix; cross-link Known Limitations instead of re-listing the conditions. - Rewrite Known Limitations preamble to call out the two-group split (hard errors vs. silently-ignored kwargs) more directly. - Split the dense CPX-mode pytest notes into labelled bullets. - Drop the unused validate_flashinfer_rocm_arch import from the runtime helpers snippet and note (separately) that it's a build-time validator, not a runtime helper. - Move the pre-commit / pytest contributing reminder out of the License paragraph into its own line. - Fix "Python tests suite" → "Python test suite". 
Verified against the codebase: env var names + defaults (FLASHINFER_USE_TORCH_CUSTOM_OPS, FLASHINFER_HIP_FUSED_CASCADE, FLASHINFER_LOGGING_LEVEL, FLASHINFER_DISABLE_JIT, ROCM_PATH/ROCM_HOME), hip_utils / aiter_utils helper signatures, attention_reference.py path, and the MI308X CPX-mode reference (decode.cuh:707). Co-Authored-By: Claude Sonnet 4.6 --- README.md | 90 ++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 56 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index f1b19eda92..388fa0cb9a 100644 --- a/README.md +++ b/README.md @@ -44,14 +44,12 @@ each FlashInfer+ROCm release to its corresponding upstream tag (e.g. ## Feature Support Matrix Most kernels ship with an in-tree HIP implementation. A subset also has -an [AITER](https://github.com/ROCm/aiter) backend; for those, the -`backend="auto"` default picks AITER when its compatibility conditions -hold and transparently falls back to HIP otherwise. AITER-only kernels -(currently MLA) require an explicit `backend="aiter"`. +an AITER backend; for those, `backend="auto"` picks AITER when its +compatibility conditions hold and falls back to HIP otherwise. The one +AITER-only kernel today (MLA) requires an explicit `backend="aiter"`. -Legend: **HIP** = in-tree FlashInfer+ROCm kernel (the historical `fa2` -HIP port, or the `native` JIT kernel for non-attention ops). **AITER** = -ROCm AITER backend. +Legend: **HIP** = in-tree kernel (`fa2` for attention, `native` JIT +kernel for non-attention ops). **AITER** = ROCm AITER backend. | Kernel | HIP | AITER | `backend="auto"` resolves to | Notes | | :--- | :---: | :---: | :--- | :--- | @@ -244,7 +242,7 @@ and use more disk space but avoid JIT compilation at runtime. 
### Running Tests -The Python tests suite can be run with pytest: +Run the Python test suite with pytest: ```bash # Run default tests (configured in pyproject.toml) @@ -286,25 +284,43 @@ pytest -n auto --reruns 2 -m "slow" **Notes** -* `pytest -n auto` for the `tests/rocm_tests/` suite spawns **half as many xdist workers as physical AMD cards** (e.g. 4 workers on a CPX-mode 8-card MI308X / MI325X host). One worker per physical card was tried first but produced sporadic failures across rope, single_prefill, and logits_cap under residual concurrent load; halving the count produces reliable green runs. Each worker is pinned to its card via `HIP_VISIBLE_DEVICES`. On non-CPX systems the helper applies the same halving; users who want every device used can pass an explicit `-n N`. -* `--reruns 2` (from `pytest-rerunfailures`) absorbs the residual ~0.01 % of transient HIP runtime crashes (HSA exceptions, HIPBLAS handle-pool exhaustion, intermittent generator non-determinism) that worker pinning cannot fully eliminate. Successful tests are not duplicated; only failed tests are retried. -* The `slow` marker is registered in [pyproject.toml](pyproject.toml). It tags the 1M-trial sampling-frequency tests, the 4 GB-tensor speculative-sampling cases, and the entire `TestLogitsPipeCompilationHIP` class (every test there runs the sampling kernel twice per case for compile=True/False). -* The reference attention helper in `tests/attention_reference.py` wraps `torch.matmul` in a `_hipblas_safe_matmul` retry helper that catches `HIPBLAS_STATUS_ALLOC_FAILED` and retries with a short back-off — needed under heavy concurrent xdist load. +* **Worker count.** `pytest -n auto` for the `tests/rocm_tests/` suite + spawns **half as many xdist workers as physical AMD cards** (e.g. 4 + workers on a CPX-mode 8-card MI308X / MI325X host) and pins each + worker to its card via `HIP_VISIBLE_DEVICES`. 
One worker per physical + card was tried first but produced sporadic failures across rope, + single_prefill, and logits_cap under residual concurrent load. + Pass an explicit `-n N` to override the halving. +* **Reruns.** `--reruns 2` (from `pytest-rerunfailures`) absorbs the + residual ~0.01 % of transient HIP runtime crashes (HSA exceptions, + HIPBLAS handle-pool exhaustion, intermittent generator + non-determinism) that worker pinning cannot fully eliminate. Only + failed tests are retried. +* **`slow` marker.** Registered in [pyproject.toml](pyproject.toml). It + tags the 1M-trial sampling-frequency tests, the 4 GB-tensor + speculative-sampling cases, and the entire `TestLogitsPipeCompilationHIP` + class (each test runs the sampling kernel twice for compile=True/False). +* **HIPBLAS retry.** The reference attention helper in + `tests/attention_reference.py` wraps `torch.matmul` in a + `_hipblas_safe_matmul` retry that catches `HIPBLAS_STATUS_ALLOC_FAILED` + and retries with a short back-off — needed under heavy concurrent + xdist load. ## AITER Support -FlashInfer+ROCm supports the use of [AITER](https://github.com/ROCm/aiter) as a -backend. The `aiter` backend is enabled for the `single_prefill`, -`batch_prefill` (paged and ragged), `batch_decode`, `append_paged_kv_cache`, -`rmsnorm`, and `MLA` paths. MLA on ROCm is **only** available via AITER — -there is no in-tree HIP MLA kernel yet. +FlashInfer+ROCm can dispatch the `single_prefill`, `batch_prefill` +(paged and ragged), `batch_decode`, `append_paged_kv_cache`, `rmsnorm`, +and `MLA` paths to [AITER](https://github.com/ROCm/aiter). MLA on ROCm +is **AITER-only** — there is no in-tree HIP MLA kernel yet. -**On gfx942/gfx950 GPUs, `backend="auto"` (the default) automatically selects the AITER backend** -when the call parameters are compatible (fp16/bf16, NHD layout, no custom mask, equal Q/K/V -dtypes and head dims, `pos_encoding_mode="NONE"`). 
It falls back to `fa2` with a one-time -`logger.warning` when any condition is not met. You can also pass `backend="aiter"` explicitly. +On gfx942/gfx950, `backend="auto"` (the default) selects AITER when the +call is compatible (see [Known Limitations](#known-limitations) for the +full list) and otherwise falls back to the in-tree `fa2` HIP kernel, +emitting a one-time `logger.warning`. Pass `backend="aiter"` to require +AITER explicitly, or `backend="fa2"` to skip it. -Unless you are using the prebuilt docker image, AITER must also be installed on your system. You may follow one of the following ways to do so. +Unless you are using the prebuilt Docker image, install AITER separately +via one of the options below. ### Install AITER from source @@ -324,9 +340,11 @@ pip install amd-aiter --index-url https://pypi.amd.com/simple/ ### Known Limitations -The AITER backend has the following constraints. With `backend="aiter"` the -call will error on the first group of conditions, or for the second group, -run but silently ignore the unsupported argument. +AITER constraints fall into two groups: hard incompatibilities (the call +errors with `backend="aiter"` and triggers fallback under +`backend="auto"`), and silently-ignored kwargs (the call runs but the +flag has no effect on AITER — pass `backend="fa2"` explicitly if you +need any of them). **Conditions that fall back to `fa2` under `backend="auto"`:** @@ -338,9 +356,9 @@ run but silently ignore the unsupported argument. * `head_dim_qk != head_dim_vo` (e.g. 
DeepSeek-style MLA with 192/128 head dims) * the `aiter` Python package is not importable -**Features silently ignored on the AITER path** (the kwargs are accepted by -the FlashInfer wrapper but not forwarded to AITER, which can produce wrong -results — pass `backend="fa2"` explicitly if you need any of these): +**Features silently ignored on the AITER path** (kwargs are accepted by +the FlashInfer wrapper but not forwarded to AITER, which can produce +wrong results): * ALiBi slopes (`maybe_alibi_slopes`) * in-kernel positional encoding modes (`pos_encoding_mode`, `rope_scale`, @@ -386,12 +404,9 @@ guarding code paths or diagnosing setup issues: ```python from flashinfer.aiter_utils import is_aiter_supported -from flashinfer.hip_utils import ( - check_torch_rocm_compatibility, - validate_flashinfer_rocm_arch, -) +from flashinfer.hip_utils import check_torch_rocm_compatibility -# Returns True only on gfx942/gfx950 with the aiter package importable. +# True only on gfx942/gfx950 with the aiter package importable. if is_aiter_supported(torch.device("cuda")): ... @@ -400,6 +415,11 @@ if is_aiter_supported(torch.device("cuda")): check_torch_rocm_compatibility() ``` +`flashinfer.hip_utils.validate_flashinfer_rocm_arch` is a related +build-time validator used by `setup.py` to cross-check +`FLASHINFER_ROCM_ARCH_LIST` against ROCm and PyTorch — not typically +called from application code. + ## Basic Usage ```python @@ -424,4 +444,6 @@ Jupyter tutorial that walks through the full public API on ROCm. FlashInfer+ROCm is released under the Apache-2.0 License — see [LICENSE](LICENSE) and [NOTICE](NOTICE). Upstream project: [flashinfer-ai/flashinfer](https://github.com/flashinfer-ai/flashinfer). -Run `pre-commit run -a` and `pytest` before opening a PR. + +Contributions are welcome. Please run `pre-commit run -a` and the +relevant `pytest` selection before opening a PR. 
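The backend rules described in the Known Limitations text of the patch above can be sketched in plain Python. This is only an illustration of the documented behaviour, not the dispatcher that ships in `flashinfer`: `PrefillCall`, `aiter_blockers`, and `resolve_backend` are hypothetical names, and the fields are a simplified stand-in for the real call arguments.

```python
from dataclasses import dataclass

# Architectures and dtypes the README lists as AITER-capable.
AITER_ARCHS = {"gfx942", "gfx950"}
AITER_DTYPES = {"float16", "bfloat16"}


@dataclass
class PrefillCall:
    """Hypothetical summary of one prefill call (illustration only)."""

    arch: str
    q_dtype: str
    kv_dtype: str
    kv_layout: str = "NHD"
    has_custom_mask: bool = False
    head_dim_qk: int = 128
    head_dim_vo: int = 128
    pos_encoding_mode: str = "NONE"
    aiter_importable: bool = True


def aiter_blockers(call: PrefillCall) -> list:
    """Return the documented reasons why AITER cannot serve this call."""
    reasons = []
    if call.arch not in AITER_ARCHS:
        reasons.append("unsupported arch " + call.arch)
    if not call.aiter_importable:
        reasons.append("aiter package not importable")
    if call.q_dtype not in AITER_DTYPES:
        reasons.append("dtype is not fp16/bf16")
    if call.q_dtype != call.kv_dtype:
        reasons.append("Q and KV dtypes differ")
    if call.kv_layout != "NHD":
        reasons.append("layout is not NHD")
    if call.has_custom_mask:
        reasons.append("custom mask")
    if call.head_dim_qk != call.head_dim_vo:
        reasons.append("head_dim_qk != head_dim_vo")
    if call.pos_encoding_mode != "NONE":
        reasons.append("in-kernel positional encoding")
    return reasons


def resolve_backend(requested: str, call: PrefillCall) -> str:
    """'auto' falls back to fa2; an explicit 'aiter' request errors instead."""
    if requested == "fa2":
        return "fa2"
    if requested not in ("auto", "aiter"):
        raise ValueError("unknown backend " + repr(requested))
    reasons = aiter_blockers(call)
    if not reasons:
        return "aiter"
    if requested == "aiter":
        raise ValueError("AITER cannot serve this call: " + "; ".join(reasons))
    return "fa2"  # the real library also logs a one-time warning here
```

Keeping the blockers as a list of reasons makes the two documented outcomes share one check: `backend="auto"` quietly falls back to `fa2`, while an explicit `backend="aiter"` surfaces the same reasons as an error. Kwargs that AITER silently ignores (for example ALiBi slopes) are deliberately not modelled here.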
From 2d1d0f9b205b17d4b934e725fa6c710b65704763 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 16:06:21 +0000 Subject: [PATCH 17/20] feat(mla): accept backend="auto" on ROCm as an alias for "aiter" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MLA on ROCm previously forced the user to pass backend="aiter" explicitly: the wrapper's __init__ raised ValueError on anything other than "aiter", including the auto value used by every other ROCm kernel. That left MLA as the odd one out in the public API even though it has exactly one implementation to choose from on ROCm. Accept both "auto" and "aiter" (default is now "auto" to match the rest of the ROCm wrappers); any other value still raises with an updated message. The behaviour is unchanged for callers who already pass "aiter". ### Test plan - New parametrized test covering backend="auto" / "aiter" construction. - New test that backend="fa2" still raises ValueError (runs anywhere, no GPU required since the check fires before the AITER probe). - Full tests/rocm_tests/test_mla_aiter_hip.py — 11 passed. - pre-commit run -a — passed. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 9 ++++--- flashinfer/mla_rocm.py | 11 +++++---- tests/rocm_tests/test_mla_aiter_hip.py | 33 ++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 388fa0cb9a..a3bb84c621 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,8 @@ each FlashInfer+ROCm release to its corresponding upstream tag (e.g. Most kernels ship with an in-tree HIP implementation. A subset also has an AITER backend; for those, `backend="auto"` picks AITER when its compatibility conditions hold and falls back to HIP otherwise. The one -AITER-only kernel today (MLA) requires an explicit `backend="aiter"`. +AITER-only kernel today (MLA) has no HIP path — `backend="auto"` +resolves directly to `"aiter"`. 
Legend: **HIP** = in-tree kernel (`fa2` for attention, `native` JIT kernel for non-attention ops). **AITER** = ROCm AITER backend. @@ -58,7 +59,7 @@ kernel for non-attention ops). **AITER** = ROCm AITER backend. | **Single prefill attention** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no custom mask + equal Q/KV dtypes & head dims + `pos_encoding_mode="NONE"`; else **HIP** | MHA / GQA / MQA; fp8 WIP | | **Batch prefill attention (paged + ragged)** | ✅ `fa2` | ✅ | Same auto criteria as single prefill | MHA / GQA / MQA; fp8 WIP. AITER native page sizes: `{16, 1024}` (`{128, 256, 1024}` on `amd-aiter==0.1.10`); other sizes go through a gather on the AITER path | | **Cascade attention** | ✅ | — | HIP | Two-level shared-prefix attention; a fused single-kernel HIP variant is gated behind `FLASHINFER_HIP_FUSED_CASCADE=1` | -| **MLA (Multi-Latent Attention)** | — | ✅ | **AITER only** (no HIP fallback) | DeepSeek-style 192/128 head-dim split; bf16 + `page_size=1`; must pass `backend="aiter"` explicitly | +| **MLA (Multi-Latent Attention)** | — | ✅ | **AITER** (no HIP fallback) | DeepSeek-style 192/128 head-dim split; bf16 + `page_size=1`; `backend="auto"` (default) resolves to `"aiter"` | | **POD attention** | TBD | — | n/a | Code present; **not yet validated on ROCm** | | **RoPE (positional encoding)** | ✅ | — | HIP | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ) | | **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + AITER importable; else **HIP `native`** | `append_paged_kv_cache`; fp8 KV-cache supported on the HIP path | @@ -311,7 +312,9 @@ pytest -n auto --reruns 2 -m "slow" FlashInfer+ROCm can dispatch the `single_prefill`, `batch_prefill` (paged and ragged), `batch_decode`, `append_paged_kv_cache`, `rmsnorm`, and `MLA` paths to [AITER](https://github.com/ROCm/aiter). MLA on ROCm -is **AITER-only** — there is no in-tree HIP MLA kernel yet. 
+is **AITER-only** — there is no in-tree HIP MLA kernel yet, so +`backend="auto"` (the default for the MLA wrapper) resolves directly +to `"aiter"`. On gfx942/gfx950, `backend="auto"` (the default) selects AITER when the call is compatible (see [Known Limitations](#known-limitations) for the diff --git a/flashinfer/mla_rocm.py b/flashinfer/mla_rocm.py index ec00503ae2..5cb3e1e211 100644 --- a/flashinfer/mla_rocm.py +++ b/flashinfer/mla_rocm.py @@ -99,18 +99,21 @@ class BatchMLAPagedAttentionWrapper: float_workspace_buffer : torch.Tensor Reserved workspace. Size is ignored; only the device is used. backend : str - Must be ``"aiter"`` (only supported backend on ROCm). + Either ``"auto"`` (the default, resolves to ``"aiter"`` on ROCm) + or ``"aiter"``. Any other value raises ``ValueError``. """ def __init__( self, float_workspace_buffer: torch.Tensor, - backend: str = "aiter", + backend: str = "auto", ) -> None: - if backend != "aiter": + if backend not in ("auto", "aiter"): raise ValueError( - f"Only backend='aiter' is supported on ROCm; got {backend!r}." + f"Only backend='aiter' (or 'auto', which resolves to " + f"'aiter') is supported on ROCm; got {backend!r}." ) + backend = "aiter" self.device = float_workspace_buffer.device _require_aiter_mla(self.device) diff --git a/tests/rocm_tests/test_mla_aiter_hip.py b/tests/rocm_tests/test_mla_aiter_hip.py index 7f2f90785c..b0b84c6111 100644 --- a/tests/rocm_tests/test_mla_aiter_hip.py +++ b/tests/rocm_tests/test_mla_aiter_hip.py @@ -326,3 +326,36 @@ def test_mla_run_before_plan_raises(): torch.zeros(4, 16, 512, dtype=torch.float16, device=device), torch.zeros(4, 16, 64, dtype=torch.float16, device=device), ) + + +@pytest.mark.skipif( + not is_aiter_supported(torch.device("cuda:0")), + reason="AITER backend requires gfx942/gfx950", +) +@pytest.mark.parametrize("backend", ["auto", "aiter"]) +def test_mla_backend_accepts_auto_and_aiter(backend): + """The ROCm MLA wrapper accepts both 'auto' (default) and 'aiter'. 
+ + 'auto' resolves to 'aiter' since there is no HIP MLA kernel. + """ + from flashinfer.mla_rocm import BatchMLAPagedAttentionWrapper + + device = torch.device("cuda:0") + ws = torch.empty(1, dtype=torch.float32, device=device) + BatchMLAPagedAttentionWrapper(ws, backend=backend) + + +def test_mla_backend_rejects_unsupported(): + """Any backend other than 'auto'/'aiter' raises ValueError. + + The check fires before the AITER-availability probe, so this test + runs on any host (no GPU / no AITER required). + """ + from flashinfer.mla_rocm import BatchMLAPagedAttentionWrapper + + device = ( + torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + ) + ws = torch.empty(1, dtype=torch.float32, device=device) + with pytest.raises(ValueError, match="aiter.*auto"): + BatchMLAPagedAttentionWrapper(ws, backend="fa2") From 713aeca30b76c3a0e4d17643c2839d801f01a6a4 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 16:08:11 +0000 Subject: [PATCH 18/20] docs(readme): clarify the dev-container "Environment name" note MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The original "Environment name varies …" wording was ambiguous in context (the surrounding section is about the Docker image tag, not a shell or micromamba env). Rewrite to spell out that it's the Docker image tag that encodes the versions, and that the -t tag and the tag passed to docker run must match. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a3bb84c621..2fc1bca041 100644 --- a/README.md +++ b/README.md @@ -219,8 +219,10 @@ docker run -it \ -**Note:** Environment name varies based on Python, PyTorch, and ROCm -versions. +**Note:** The Docker image tag encodes the ROCm, Python, and PyTorch +versions (e.g. `flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1`). 
+If you change any of the `--build-arg` values in `docker build`, update +the `-t` tag accordingly and pass the matching tag to `docker run`. ### Building and Installing a Wheel Package From 16afd1808cd0d237cf7d658b80931d7b909384b3 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 16:11:20 +0000 Subject: [PATCH 19/20] docs(readme): drop redundant Docker-tag note from dev-container section The build/run blocks already show the matching -t tag and the docker run image tag side-by-side; the extra explanatory note added noise without new information. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/README.md b/README.md index 2fc1bca041..1fca25d02a 100644 --- a/README.md +++ b/README.md @@ -219,11 +219,6 @@ docker run -it \ -**Note:** The Docker image tag encodes the ROCm, Python, and PyTorch -versions (e.g. `flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1`). -If you change any of the `--build-arg` values in `docker build`, update -the `-t` tag accordingly and pass the matching tag to `docker run`. - ### Building and Installing a Wheel Package **Build with JIT (Just-in-Time) compilation only:** From 6140da446a8a871bdf700e5a7e55aaf29184aea8 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Thu, 21 May 2026 16:20:29 +0000 Subject: [PATCH 20/20] docs(readme): address Copilot review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Feature matrix: add `pos_encoding_mode="NONE"` to batch decode AITER auto-routing criteria; add gfx942/gfx950 arch gate to the `append_paged_kv_cache` row. - AITER Support: clarify the in-tree backend strings per-op (`fa2` for attention wrappers vs `native` for `append_paged_kv_cache` / `rmsnorm`) and call out the two backend-specific quirks (`rmsnorm` auto stays on HIP, batch decode auto avoids CUDA-graph / tensor cores). 
- Known Limitations: promote `pos_encoding_mode != "NONE"` and batch decode's `use_cuda_graph` / `use_tensor_cores` from the silently-ignored group to the hard-error / fallback group; the AITER attention paths reject them outright. - Runtime Helpers: add the missing `import torch` to the snippet and correct the `is_aiter_supported` comment — the function only checks ROCm build + GPU arch, not whether the `aiter` Python package can actually be imported. - CLAUDE.md: update the README anchor link to follow the renamed "GPU, ROCm, and PyTorch Support" section so cross-references stay live. Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 7 +++++-- README.md | 42 +++++++++++++++++++++++++++++++----------- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 0a622166ad..05383b2c70 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -27,8 +27,9 @@ pip install torch== \ --index-url https://repo.radeon.com/rocm/manylinux/rocm-rel-/ ``` -See the [GPU and ROCm Support](README.md#gpu-and-rocm-support) table in -`README.md` for current `` and `` values. +See the [GPU, ROCm, and PyTorch Support](README.md#gpu-rocm-and-pytorch-support) +table in `README.md` for current `` and `` +values. ## Non-Obvious Gotchas @@ -90,6 +91,8 @@ gh api repos/ROCm/flashinfer/pulls/ --method PATCH --field body="" gh api repos/ROCm/flashinfer/pulls/ --method PATCH --field body="$(cat /tmp/pr_body.md)" ``` +Ask to push to remote. + ## PR Description **Body** — include sections that apply, skip the rest: diff --git a/README.md b/README.md index 1fca25d02a..30f7fa4096 100644 --- a/README.md +++ b/README.md @@ -55,14 +55,14 @@ kernel for non-attention ops). **AITER** = ROCm AITER backend. 
 | Kernel | HIP | AITER | `backend="auto"` resolves to | Notes |
 | :--- | :---: | :---: | :--- | :--- |
 | **Single decode attention** | ✅ `fa2` | — | HIP | MHA / GQA / MQA |
-| **Batch decode attention (paged)** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no CUDA-graph + `use_tensor_cores=False`; else **HIP** | MHA / GQA / MQA; **fp8 KV-cache (E4M3FNUZ)** on the HIP path; sliding-window on the AITER path; CUDA-graph auto-routes back to HIP |
+| **Batch decode attention (paged)** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + `pos_encoding_mode="NONE"` + no CUDA-graph + `use_tensor_cores=False`; else **HIP** | MHA / GQA / MQA; **fp8 KV-cache (E4M3FNUZ)** on the HIP path; sliding-window on the AITER path; CUDA-graph auto-routes back to HIP |
 | **Single prefill attention** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no custom mask + equal Q/KV dtypes & head dims + `pos_encoding_mode="NONE"`; else **HIP** | MHA / GQA / MQA; fp8 WIP |
 | **Batch prefill attention (paged + ragged)** | ✅ `fa2` | ✅ | Same auto criteria as single prefill | MHA / GQA / MQA; fp8 WIP. AITER native page sizes: `{16, 1024}` (`{128, 256, 1024}` on `amd-aiter==0.1.10`); other sizes go through a gather on the AITER path |
 | **Cascade attention** | ✅ | — | HIP | Two-level shared-prefix attention; a fused single-kernel HIP variant is gated behind `FLASHINFER_HIP_FUSED_CASCADE=1` |
 | **MLA (Multi-Latent Attention)** | — | ✅ | **AITER** (no HIP fallback) | DeepSeek-style 192/128 head-dim split; bf16 + `page_size=1`; `backend="auto"` (default) resolves to `"aiter"` |
 | **POD attention** | TBD | — | n/a | Code present; **not yet validated on ROCm** |
 | **RoPE (positional encoding)** | ✅ | — | HIP | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ) |
-| **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + AITER importable; else **HIP `native`** | `append_paged_kv_cache`; fp8 KV-cache supported on the HIP path |
+| **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + gfx942/gfx950 + AITER importable; else **HIP `native`** | `append_paged_kv_cache`; fp8 KV-cache supported on the HIP path |
 | **RMSNorm** | ✅ `native` | ✅ | **HIP `native`** (auto stays on HIP — AITER is opt-in via `backend="aiter"`) | AITER path is fp16/bf16, 2-D only; slightly lower precision at `hidden_size >= 1024` |
 | **LayerNorm / Gemma RMSNorm** | ✅ | — | HIP | |
 | **Sampling** | ✅ | — | HIP | Top-K / Top-P / Min-P / OnlineSoftmax / SamplingFromLogits |
@@ -315,9 +315,19 @@ to `"aiter"`.

 On gfx942/gfx950, `backend="auto"` (the default) selects AITER when the
 call is compatible (see [Known Limitations](#known-limitations) for the
-full list) and otherwise falls back to the in-tree `fa2` HIP kernel,
-emitting a one-time `logger.warning`. Pass `backend="aiter"` to require
-AITER explicitly, or `backend="fa2"` to skip it.
+full list) and otherwise falls back to the in-tree HIP kernel, emitting
+a one-time `logger.warning`. Pass `backend="aiter"` to require AITER
+explicitly, or pass the in-tree backend string to skip it:
+`backend="fa2"` for the attention wrappers (single/batch
+prefill/decode), `backend="native"` for non-attention ops
+(`append_paged_kv_cache`, `rmsnorm`). Two backend-specific exceptions
+to "auto picks AITER when supported":
+
+* `rmsnorm`: `backend="auto"` stays on the HIP `native` kernel; the
+  AITER path is opt-in via `backend="aiter"`.
+* `batch_decode`: `use_cuda_graph=True` or `use_tensor_cores=True`
+  force `auto` back to `fa2` (AITER decode does not support either),
+  and `pos_encoding_mode != "NONE"` raises under `backend="aiter"`.

 Unless you are using the prebuilt Docker image, install AITER separately
 via one of the options below.
@@ -343,10 +353,12 @@ pip install amd-aiter --index-url https://pypi.amd.com/simple/
 AITER constraints fall into two groups: hard incompatibilities (the call
 errors with `backend="aiter"` and triggers fallback under
 `backend="auto"`), and silently-ignored kwargs (the call runs but the
-flag has no effect on AITER — pass `backend="fa2"` explicitly if you
-need any of them).
+flag has no effect on AITER — pass the in-tree backend explicitly if
+you need any of them: `backend="fa2"` for attention wrappers, or
+`backend="native"` for `append_paged_kv_cache` / `rmsnorm`).

-**Conditions that fall back to `fa2` under `backend="auto"`:**
+**Conditions that fall back to the in-tree HIP kernel under
+`backend="auto"`** (and raise under `backend="aiter"`):

 * GPU is not gfx942 or gfx950
 * `kv_layout` is not `NHD`
@@ -354,6 +366,8 @@ need any of them).
 * `q_dtype` is not `float16` / `bfloat16` (no fp32, fp8, or int8)
 * `q_dtype != kv_dtype` (mixed-precision Q/KV is unsupported)
 * `head_dim_qk != head_dim_vo` (e.g. DeepSeek-style MLA with 192/128 head dims)
+* `pos_encoding_mode != "NONE"` (AITER attention paths only support `"NONE"`)
+* batch decode: `use_cuda_graph=True` or `use_tensor_cores=True`
 * the `aiter` Python package is not importable

 **Features silently ignored on the AITER path** (kwargs are accepted by
@@ -361,8 +375,10 @@ the FlashInfer wrapper but not forwarded to AITER, which can produce
 wrong results):

 * ALiBi slopes (`maybe_alibi_slopes`)
-* in-kernel positional encoding modes (`pos_encoding_mode`, `rope_scale`,
-  `rope_theta`)
+* RoPE scaling kwargs (`rope_scale`, `rope_theta`) — these are only
+  consumed alongside `pos_encoding_mode != "NONE"`, which AITER
+  attention rejects outright; the kwargs themselves pass through
+  silently when the mode is `"NONE"`
 * attention sinks (`sinks`)
 * multi-modal / prefix-cache helpers (`maybe_prefix_len_ptr`,
   `maybe_token_pos_in_items_ptr`, `maybe_max_item_len_ptr`)
@@ -403,10 +419,14 @@ documented in [CLAUDE.md](CLAUDE.md).
 guarding code paths or diagnosing setup issues:

 ```python
+import torch
+
 from flashinfer.aiter_utils import is_aiter_supported
 from flashinfer.hip_utils import check_torch_rocm_compatibility

-# True only on gfx942/gfx950 with the aiter package importable.
+# True on gfx942/gfx950 (a ROCm build + supported GPU arch). Does *not*
+# verify the `aiter` Python package is importable — wrap the actual
+# AITER call in a try/except ImportError if you need that guarantee.
 if is_aiter_supported(torch.device("cuda")):
     ...