Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions .claude/skills/add-rocm-kernel/SKILL.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
---
name: add-rocm-kernel
description: Step-by-step tutorial for adding new HIP kernels to FlashInfer+ROCm (amd-flashinfer)
---

# Adding a New Kernel to FlashInfer+ROCm

For a complete worked example to copy, read these together:
[`norm.cu`](../../../flashinfer/csrc_rocm/norm.cu) +
[`flashinfer_norm_binding.cu`](../../../flashinfer/csrc_rocm/flashinfer_norm_binding.cu) +
[`jit/norm.py`](../../../flashinfer/jit/norm.py) +
[`norm.py`](../../../flashinfer/norm.py). For plan-run / multi-backend / FP8 see
[`batch_prefill.cu`](../../../flashinfer/csrc_rocm/batch_prefill.cu) +
[`prefill_rocm.py`](../../../flashinfer/prefill_rocm.py).

## File touchpoints (every new op needs each row, in order)

| Step | File | Purpose |
| --- | --- | --- |
| 1 | `include/flashinfer/<op>.cuh` | Framework-agnostic kernel + launcher template. **No `<torch/...>` includes here.** |
| 2 | `flashinfer/csrc_rocm/<op>.cu` | PyTorch launcher: `at::Tensor` in, `at::hip::getCurrentHIPStream()`, `TORCH_CHECK`, `DISPATCH_PYTORCH_DTYPE_*`. |
| 3 | `flashinfer/csrc_rocm/flashinfer_<op>_binding.cu` | `TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { m.def("<op>", <op>); }`. |
| 4 (opt) | `flashinfer/csrc_rocm/<op>_customize_config.jinja` | Compile-time type specialization. Skip if runtime dispatch is enough. |
| 5 | `flashinfer/jit/<op>.py` | `gen_<op>_module() -> JitSpec` via `gen_jit_spec(...)`. |
| 6 | `flashinfer/<op>.py` | Python API: `@functools.cache` module loader, destination-passing (`out=`). |
| 7 | `tests/rocm_tests/test_<op>_hip.py` | Correctness tests; FP32 reference math, loose BF16 tolerances. |
| 8 | `flashinfer/jit/__init__.py` (`IS_HIP` branch) | `from .<op> import gen_<op>_module as gen_<op>_module`. |
| 9 | `flashinfer/__init__.py` (`IS_HIP` branch) | `from .<op> import <op> as <op>`. |
| 10 (opt) | `flashinfer/aot_hip.py` | Register `gen_<op>_module` for pre-compiled wheels. |

**Forgetting steps 8 and 9 is the most common bug** — the module compiles but is invisible from `import flashinfer`.

## CUDA → ROCm porting cheat sheet

When porting an upstream kernel, mechanically rewrite:

| Upstream CUDA | This fork |
| --- | --- |
| `csrc/<op>.cu` | `flashinfer/csrc_rocm/<op>.cu` |
| `#include "tvm_ffi_utils.h"` | `#include "pytorch_extension_utils.h"` |
| `tvm::ffi::TensorView` | `at::Tensor` |
| `TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, op)` | `TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { m.def("op", op); }` |
| `TVM_FFI_THROW(ValueError) << "..."` | `TORCH_CHECK(cond, "...")` |
| `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16` | `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16` |
| `get_stream(tensor.device())` | `at::hip::getCurrentHIPStream()` |
| `c10::cuda::OptionalCUDAGuard` | `c10::hip::OptionalHIPGuardMasqueradingAsCUDA` |
| `nvcc` flags via `extra_cuda_cflags=[...]` | **Same kwarg name** (`extra_cuda_cflags`) — internally routed to `hipcc`. |
| `flashinfer/aot.py` registration | `flashinfer/aot_hip.py` |
| `tests/test_op.py` | `tests/rocm_tests/test_op_hip.py` |
| `supported_major_versions=[9, 10]` | No analogue. Guard at Python layer via `FLASHINFER_SUPPORTED_ROCM_ARCHS`. |
| `csrc/` (hardcoded) | `jit_env.FLASHINFER_CSRC_DIR` resolves to `flashinfer/csrc_rocm/` on HIP. **Never hardcode `csrc/`.** |
| `PYBIND11_MODULE(...)` | **Don't.** Use `TORCH_LIBRARY_FRAGMENT` (integrates with `torch.compile`). |

## Non-obvious gotchas

- **PyTorch's ROCm masquerade.** `input.device.type == "cuda"` even on AMD. Never check for `"hip"`. PyTorch's HIP namespaces are reachable via `at::hip::...` and `c10::hip::OptionalHIPGuardMasqueradingAsCUDA` (literally the type name).
- **`gpu_iface` over duplication.** If a primitive (MMA intrinsic, cross-lane shuffle, dtype container, warp reduction) differs between CUDA and HIP, add it under [`include/gpu_iface/backend/{cuda,hip}/`](../../../include/gpu_iface) and expose a common name from the top-level `gpu_iface/` header. Don't fork the kernel into `csrc_rocm/`. Existing HIP backends: `mma_hip.h`, `memory_ops_hip.h`, `math_hip.h`, `vec_dtypes_hip.h`.
- **`-ffast-math` adds `-ffinite-math-only` on clang/hipcc.** [`jit/core.py`](../../../flashinfer/jit/core.py) explicitly re-adds `-fno-finite-math-only` so kernels that use `-inf` as a sentinel (online-softmax Map+Reduce) keep working. CUDA's `-use_fast_math` does *not* enable finite-math-only — divergence to be aware of when porting.
- **`gen_jit_spec` auto-injects `--offload-arch=gfxNNN`** for every target arch plus `COMMON_HIPCC_FLAGS` (`-DFLASHINFER_ENABLE_HIP`, FP8 enables, etc.). Don't add `--offload-arch` by hand.
- **Validation macros** live in [`pytorch_extension_utils.h`](../../../flashinfer/csrc_rocm/pytorch_extension_utils.h): `CHECK_INPUT` (GPU + contiguous), `CHECK_LAST_DIM_CONTIGUOUS_INPUT`, `CHECK_EQ`, `CHECK_DIM`, `CHECK_GE`, `CHECK_SHAPE`. Dispatch macros: `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16` (FP16+BF16), `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8` (E4M3+E5M2, both `_fnuz` on CDNA3/4), and the unsuffixed `DISPATCH_PYTORCH_DTYPE_TO_CTYPE` (FP16+BF16+FP8 combined). There is **no** `_FP16_FP32` variant — if you need FP32, dispatch manually.
- **The `_jit_pybind.cu` naming pattern** (e.g. `batch_decode_jit_pybind.cu`) is used by newer AITER-integrated bindings; the older `flashinfer_<op>_binding.cu` pattern is used by everything else. Both work — match the neighbors.

## CDNA3 (`gfx942`) vs CDNA4 (`gfx950`)

- **Wavefront = 64 on both.** Anything ported from CUDA assuming warp = 32 is wrong. Use `warpSize` for portability.
- **FP8** is `__hip_fp8_e4m3_fnuz` / `__hip_fp8_e5m2_fnuz` on both. PyTorch dtype is `torch.float8_e4m3fnuz` (not `torch.float8_e4m3fn`, which is NVIDIA OCP FP8). Bit-exact parity with NVIDIA FP8 is not guaranteed — calibrate scale factors separately.
- **MFMA intrinsics:** CDNA4 has additional FP8 MFMA shapes not on CDNA3. Guard arch-specific intrinsics with `__gfx942__` / `__gfx950__` or compute-capability dispatch at the Python layer.
- **LDS / register / occupancy budgets differ.** Don't hard-code tile sizes — parameterize (Jinja) or query via `torch.cuda.get_device_properties(dev)` at plan time.

## Quick checklist before commit

- [ ] No `<torch/...>` under `include/`.
- [ ] Launcher uses `at::hip::getCurrentHIPStream()` + `OptionalHIPGuardMasqueradingAsCUDA`.
- [ ] Binding registered via `TORCH_LIBRARY_FRAGMENT`.
- [ ] JIT generator uses `jit_env.FLASHINFER_CSRC_DIR` (not hardcoded `csrc/`).
- [ ] Both `flashinfer/jit/__init__.py` and `flashinfer/__init__.py` IS_HIP branches updated.
- [ ] Test file under `tests/rocm_tests/` named `test_*_hip.py`.
- [ ] `pre-commit run -a` clean.
82 changes: 82 additions & 0 deletions .claude/skills/benchmark-kernel/SKILL.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
---
name: benchmark-kernel
description: Guide for benchmarking FlashInfer+ROCm kernels on AMD Instinct (CDNA3/CDNA4)
---

# Benchmarking FlashInfer+ROCm Kernels

For a real driver script to copy, see
[`benchmarks/rocm_benchmarks/bench_fa2_prefill.py`](../../../benchmarks/rocm_benchmarks/bench_fa2_prefill.py) and [`benchmarks/rocm_benchmarks/bench_aiter_prefill.py`](../../../benchmarks/rocm_benchmarks/bench_aiter_prefill.py)
For the in-repo profiler wrapper, see [`rocm_profiler/rocm_profiler.py`](../../../rocm_profiler/rocm_profiler.py).

## Timing method matrix

| Method | When | How |
| --- | --- | --- |
| `flashinfer.testing.bench_gpu_time` | Quick in-loop check (kernels ≳ 50 µs) | Falls through to PyTorch `torch.cuda.Event` (HIP events under ROCm) automatically. |
| `rocm_profiler` (`RocmProfiler`) | Anything you intend to optimize | Two-phase: in-process median timing, then re-execs the same script under `rocprofv3` (sentinel: `_ROCM_PROFILER_INTERNAL`) for hardware counters. Produces roofline PNG. |
| `rocprofv3` directly | Full control over counter set | `rocprofv3 --stats --kernel-trace -- python script.py`; or `-i pmc.txt` for custom counters. |
| `omnitrace` | Host + device timeline when Python overhead is suspect | Installed separately. |

## Non-obvious gotchas

- **CUPTI is NVIDIA-only — `enable_cupti=True` on ROCm warns and falls back.** [`flashinfer/testing/utils.py:1010`](../../../flashinfer/testing/utils.py) routes through `bench_gpu_time_with_cupti`, which `try/except`s the `cupti` import, emits a `UserWarning`, and reverts to CUDA/HIP event timing. No functional benefit on ROCm; just leave `enable_cupti=False` (the default) so `bench_gpu_time` uses `torch.cuda.Event` (HIP events) directly without the warning.
- **AITER backend constraints, accurately:**
- Explicit `backend="aiter"` + `kv_layout != "NHD"` → `ValueError` at `plan()` time. Raised in the prefill wrapper, e.g. [`prefill_rocm.py:1978`](../../../flashinfer/prefill_rocm.py) (single/paged) and the batch-paged wrapper around line 2920. Not raised by auto-selection — that path silently falls back to `fa2`.
- Explicit `backend="aiter"` on non-gfx942/gfx950 → `RuntimeError`.
- `amd-aiter` not importable → `ImportError`.
- **"Native" page sizes** (no flat-gather): `{128, 256, 1024}` for `amd-aiter >= 0.1.10`, else `{16, 1024}` — see `_aiter_native_page_sizes()` in [`prefill_rocm.py:59`](../../../flashinfer/prefill_rocm.py). **Non-native page sizes are NOT rejected** — they go through a flat-gather code path. So the "{1, 16, 1024}" guidance from older docs is wrong.
- Auto-selection (no explicit `backend=`) silently falls back to `fa2` for any of: `kv_layout != "NHD"`, custom mask, dtype not in `{fp16, bf16}`, `dtype_q != dtype_kv`, `head_dim_qk != head_dim_vo`, `pos_encoding_mode != "NONE"`, or `amd-aiter` not importable. See `_auto_select_prefill_backend()` in [`prefill_rocm.py:311`](../../../flashinfer/prefill_rocm.py) for the authoritative list.
- **Always verify numerical parity before trusting perf numbers.** Compare default-HIP vs AITER outputs with `torch.testing.assert_close(rtol=1e-2, atol=1e-2)` for BF16/FP16 first.
- **`gcnArchName` is the unambiguous arch marker.** Device strings show `cuda:0` on AMD too. Record `torch.cuda.get_device_properties(0).gcnArchName` and `torch.version.hip` alongside every number — a `gfx942` / ROCm 7.2 result is not comparable to a `gfx950` / ROCm 7.0.2 result.

## What can actually be benchmarked on ROCm

Only the APIs in the `IS_HIP` branch of [`flashinfer/__init__.py`](../../../flashinfer/__init__.py) are callable. **Not** available: MLA, cascade, POD, FP4, MoE, cuDNN backends. Don't try to import them.

AITER backend available for: single prefill, batch prefill (paged + ragged) — opt in via `backend="aiter"`. Not available for decode, norm, rope, sampling, etc.

## `rocm_profiler` counter presets

Pass via `RocmProfiler(counters=...)` or `--counters` on the driver script.

| Preset | What it shows | Use for |
| --- | --- | --- |
| `roofline` (default) | `FetchSize`, `WriteSize`, MFMA ops, TCC DRAM requests | "Am I compute- or memory-bound?" |
| `compute` | MFMA ops + cycle counters | Matrix-core throughput |
| `memory` | L2 + DRAM breakdown | L2 hit-rate, HBM traffic |
| `occupancy` | `SQ_WAVES`, `SQ_BUSY_CYCLES`, `SQ_VALU_MFMA_BUSY_CYCLES`, `SQ_INSTS_LDS` | Wavefront density |
| `stall` | `SQ_WAIT_INST_VMEM`, `SQ_WAIT_INST_LDS` | Diagnose memory stalls |
| `basic` | `FetchSize` / `WriteSize` | Minimal baseline |

Or pass a path to a `rocprofv3`-native YAML for a custom counter set.

Driver script flags: `--timing-only` (skip rocprofv3), `--skip-roofline`, `--replot` (regen PNG from existing CSVs, no GPU), `--list-presets`.

Output (under `benchmarks/rocm_benchmarks/`, gitignored):

```text
<label>_timing.csv # median + std per config
<label>_counter_collection.csv # raw counters
<label>_roofline.png # only for counters=roofline
```

## Reproducibility checklist

1. **Warm up.** `dry_run_iters >= 5`; raise to 10–20 if std is high. First call includes JIT compile.
2. **Pin clocks** for sub-100-µs kernels:

```bash
rocm-smi --showclocks
sudo rocm-smi --setsclk 7
sudo rocm-smi --setmclk 3
```

3. **Record arch + ROCm version** in the log: `print(props.name, props.gcnArchName, torch.version.hip)`.
4. **Isolate the GPU:** `HIP_VISIBLE_DEVICES=N` (or `ROCR_VISIBLE_DEVICES=N`, one layer deeper).

## Troubleshooting `rocm_profiler`

- **Empty `_counter_collection.csv`:** `kernel_name_regex` doesn't match the mangled name. Run `rocprofv3 --stats --kernel-trace -- python my_bench.py` first and copy the prefix from `*_kernel_stats.csv`.
- **Hang or no output:** confirm `which rocprofv3` is on `PATH`; the wrapper uses script `print()` output as a heartbeat — make sure the `if __name__ == "__main__":` block prints something.
- **Use `--timing-only` first** to verify the kernel path works before involving `rocprofv3`.
82 changes: 82 additions & 0 deletions .claude/skills/debug-rocm-crash/SKILL.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
---
name: debug-rocm-crash
description: Tutorial for debugging HIP kernel crashes in FlashInfer+ROCm using HIP/ROCm runtime tooling
---

# Debugging ROCm Crashes in FlashInfer+ROCm

> **Note:** earlier revisions of this skill (and CLAUDE.md) described a `@flashinfer_api`
> decorator with `FLASHINFER_LOGLEVEL` / `FLASHINFER_LOGDEST` env vars. **That machinery does
> not exist in this fork** — no matches in code (`git grep` under `flashinfer/` and `include/`
> returns nothing; the only hits are this disclaimer). Don't try to set those env vars —
> use the HIP/ROCm tooling below instead.

## The magic env-var combo

For an unknown HIP fault, set these **before** running so the traceback points at the actual faulting kernel:

```bash
export AMD_SERIALIZE_KERNEL=3 # pins fault to the actual faulting kernel
export HIP_LAUNCH_BLOCKING=1 # synchronous launches; tracebacks point at the right line
```

Both are near-zero-overhead and reasonable to leave on while iterating on a new kernel.

For an in-script view of what's being passed, wrap the suspect call with `print(input.shape, input.dtype, input.device, input.is_contiguous())` and `torch.cuda.synchronize()` immediately before the FlashInfer call — this gives you the same info `@flashinfer_api` would have, manually.

## Per-error recipe

| Symptom | First check |
| --- | --- |
| `Memory access fault by GPU node-N` / `hipErrorIllegalAddress` / "CUDA error: illegal memory access" (PyTorch's ROCm reports HIP errors as "CUDA" errors) | Run with the env combo above. Print tensor shapes/dtypes/strides just before the call. Verify: `is_contiguous()` where required, all tensors on the same `cuda:N`, `kv_indices` within `[0, num_pages)`, `head_dim_qk` matches between Q and KV. |
| `backend="aiter"` `ValueError` before launch | `kv_layout != "NHD"` (only NHD is allowed — raised in the prefill wrapper's `plan()`, e.g. [`prefill_rocm.py:1978`](../../../flashinfer/prefill_rocm.py)). |
| `backend="aiter"` `RuntimeError` | Non-gfx942/gfx950 GPU. |
| `backend="aiter"` `ImportError` | `amd-aiter` not installed (`pip install amd-aiter --index-url https://pypi.amd.com/simple/`). |
| `backend="aiter"` hard GPU fault mid-kernel | `amd-aiter` version mismatch vs. ROCm. Reinstall matching your ROCm version. Try the default HIP backend to confirm the bug is in AITER, not our side. |
| NaN / Inf in outputs | Insert `torch.isnan(t).any()` / `torch.isinf(t).any()` checks around the call. On CDNA3/4: `_fnuz` FP8 has different representable range than NVIDIA OCP FP8 — scale factors calibrated against NVIDIA refs overflow. Or `-inf` from a previous op fed into `exp`. Or `torch.empty` vs `torch.zeros`. |
| `HIP out of memory` | `rocm-smi --showmeminfo vram --showpids` — kill zombies. JIT-compile spike → `MAX_JOBS=1`. Other tenant → `HIP_VISIBLE_DEVICES=N`. |
| `expected scalar type X but found Y` (FP8 callsites) | PyTorch dtype for `_fnuz` FP8 is `torch.float8_e4m3fnuz` / `torch.float8_e5m2fnuz`, **not** `torch.float8_e4m3fn` (which is NVIDIA OCP FP8). A callsite expecting `e4m3fn` mis-dispatches on ROCm. |

## ROCm-specific tooling

| Tool | Use |
| --- | --- |
| `rocgdb --args python my_script.py` | CUDA-GDB equivalent. Inside: `catch throw`, `run`, `bt`, `info agents`, `info wavefronts`. |
| `ROCM_DEBUG_WAIT_FOR_DEBUGGER=1` | Process blocks at first GPU API call until `rocgdb -p <pid>` attaches. |
| `AMD_LOG_LEVEL=3` (or `4`) | HIP API + stream trace. Linear under `HIP_LAUNCH_BLOCKING=1`, so each Python call correlates 1:1 with HIP launches. |
| `HSA_ENABLE_DEBUG=1` | HSA layer trace (one below HIP — queues, agents). |
| `sudo dmesg -T \| grep -iE 'amdgpu\|kfd\|vm_fault'` | `VM_CONTEXT1_PROTECTION_FAULT_STATUS` gives page-fault class, access type, offending address — useful when Python only says `hipErrorIllegalAddress`. |
| `watch -n 1 'rocm-smi --showuse --showmeminfo vram --showpids'` | Hang diagnosis: 100% GPU + no SQ activity = looping kernel; VRAM still pinned after exit = another process holds it. |

`compute-sanitizer` / `cuda-gdb` have **no direct ROCm equivalent.** Closest workflow is the env-var combo above plus `rocgdb`.

## AMD-specific gotchas

- **PyTorch's ROCm masquerade.** Device strings show `cuda:0` on AMD; "CUDA error" messages may be HIP errors. The unambiguous arch field is `torch.cuda.get_device_properties(0).gcnArchName`.
- **Wavefront = 64**, not 32. Any representative-thread `printf` ported from CUDA needs `threadIdx.x % 64 == 0` (or use the `warpSize` builtin).
- **`FLASHINFER_JIT_DEBUG=1` is wired on the CUDA path only.** On HIP it does nothing for debug build flags (no `-O0 -g`). Add `-g` via `extra_cuda_cflags` in the JIT generator for the op being debugged, clear `~/.cache/flashinfer/`, retry. See CLAUDE.md "Non-Obvious Gotchas".
- **HIP installs are stripped.** `rocgdb` exits with `no symbol table loaded` unless you rebuild with `-g` (see previous bullet).
- **Device `printf` flushes on `torch.cuda.synchronize()`** — works the same as CUDA.
- **`HIP_VISIBLE_DEVICES`** is the canonical AMD scoping env var (`ROCR_VISIBLE_DEVICES` works one layer deeper). `CUDA_VISIBLE_DEVICES` may also be honored by PyTorch.

## Quick recipes

```bash
# Hard GPU fault
export AMD_SERIALIZE_KERNEL=3
export HIP_LAUNCH_BLOCKING=1
python my_script.py
# Python traceback now points at the right call. Also: sudo dmesg -T | tail -50

# Step into a kernel
export AMD_SERIALIZE_KERNEL=3
export HIP_LAUNCH_BLOCKING=1
rocgdb --args python my_script.py
# (rocgdb) catch throw
# (rocgdb) run
# (rocgdb) bt

# HIP API trace
AMD_LOG_LEVEL=3 HIP_LAUNCH_BLOCKING=1 python my_script.py 2> hip.trace
# grep hipLaunchKernel / hipMemcpy / error in hip.trace
```
Loading
Loading