From 171dd6292ad5e5a72765257cb91719156f3ee702 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 5 May 2026 13:08:44 +0000 Subject: [PATCH 01/19] fix(hip): guard cuda.h includes in mla_params.cuh and profiler.cuh Both headers unconditionally included , which is not available in the HIP compilation environment. Add !defined(__HIPCC__) && !defined(__HIP__) guards around the include in each file. profiler.cuh: stub get_timestamp() to return 0 on HIP; the PTX globaltimer register is NVIDIA-specific and has no HIP equivalent. Profiling via FLASHINFER_ENABLE_PROFILER will produce zero timestamps on ROCm. Co-Authored-By: Claude Sonnet 4.6 --- include/flashinfer/attention/mla_params.cuh | 2 ++ include/flashinfer/profiler.cuh | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/include/flashinfer/attention/mla_params.cuh b/include/flashinfer/attention/mla_params.cuh index 6da1ed7f53..9187cfc687 100644 --- a/include/flashinfer/attention/mla_params.cuh +++ b/include/flashinfer/attention/mla_params.cuh @@ -15,7 +15,9 @@ */ #ifndef FLASHINFER_MLA_PARAMS_CUH_ #define FLASHINFER_MLA_PARAMS_CUH_ +#if !defined(__HIPCC__) && !defined(__HIP__) #include +#endif #include "../fastdiv.cuh" #include "../profiler.cuh" diff --git a/include/flashinfer/profiler.cuh b/include/flashinfer/profiler.cuh index 7a4be6ee41..008c2e4989 100644 --- a/include/flashinfer/profiler.cuh +++ b/include/flashinfer/profiler.cuh @@ -16,7 +16,9 @@ #ifndef FLASHINFER_PROFILER_CUH_ #define FLASHINFER_PROFILER_CUH_ +#if !defined(__HIPCC__) && !defined(__HIP__) #include +#endif namespace flashinfer { @@ -55,7 +57,11 @@ __device__ __forceinline__ uint32_t encode_tag(uint32_t sm_id, uint32_t block_id __device__ __forceinline__ uint32_t get_timestamp() { volatile uint32_t ret; +#if !defined(__HIPCC__) && !defined(__HIP__) asm volatile("mov.u32 %0, %globaltimer_lo;" : "=r"(ret)); +#else + ret = 0; +#endif return ret; } From cf542f1a63425b67a3d160765d0fe5c7cf1eef50 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 5 May 2026 13:08:52 +0000 Subject: [PATCH 02/19] feat(hip): route MLA backend to HIP path in utils and mla.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit utils.py: determine_mla_backend returns "hip" on HIP/ROCm instead of falling through to the CUDA fa3/fa2 selection logic. mla.py: guard the jit.mla import (CUDA-only Cutlass module) behind if not IS_HIP, and raise RuntimeError in get_mla_module() on HIP so callers get a clear error rather than an ImportError at load time. Import gen_batch_mla_module unconditionally — it is exported by jit/__init__.py in the IS_HIP block. Co-Authored-By: Claude Sonnet 4.6 --- flashinfer/mla.py | 7 ++++++- flashinfer/utils.py | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/flashinfer/mla.py b/flashinfer/mla.py index da57d94e6b..0669d66d7e 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -19,10 +19,13 @@ import torch +from .device_utils import IS_HIP from .jit import gen_batch_mla_module -from .jit.mla import gen_mla_module from .utils import MaskMode, check_shape_dtype_device, determine_mla_backend +if not IS_HIP: + from .jit.mla import gen_mla_module + def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table): if q_nope_pe.ndim != 3: @@ -55,6 +58,8 @@ def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table): @functools.cache def get_mla_module(): + if IS_HIP: + raise RuntimeError("Cutlass MLA backend is not supported on HIP/ROCm") return gen_mla_module().build_and_load() diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 6f9fe7e8b9..d59707e5c2 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -640,6 +640,10 @@ def is_sm121a_supported(device: torch.device) -> bool: def determine_mla_backend(device: torch.device) -> str: + from .device_utils import IS_HIP + + if IS_HIP: + return "hip" return "fa3" if is_sm90a_supported(device) else "fa2" From bc739dadb39e90ee33ff9406e6ab5f15d55ee1e2 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 5 May 2026 13:09:00 +0000 Subject: [PATCH 03/19] feat(hip): add JIT module generator for batch MLA attention modules_hip.py: add get_batch_mla_uri and gen_batch_mla_module following the same pattern as gen_batch_decode_module / gen_batch_prefill_module. The generator renders batch_mla_customize_config.jinja into a per-dtype config.inc, copies batch_mla.cu and batch_mla_jit_pybind.cu into the gen directory, and returns a JitSpec for on-demand compilation. jit/attention/__init__.py, jit/__init__.py: re-export the two new symbols in the IS_HIP block so they are importable via flashinfer.jit. Co-Authored-By: Claude Sonnet 4.6 --- flashinfer/jit/__init__.py | 2 + flashinfer/jit/attention/__init__.py | 2 + flashinfer/jit/attention/modules_hip.py | 80 +++++++++++++++++++++++++ 3 files changed, 84 insertions(+) diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 3be99e7bc0..97db277e0a 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -108,6 +108,8 @@ def get_cudnn_fmha_gen_module(): from .activation import gen_act_and_mul_module as gen_act_and_mul_module from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str from .attention import gen_batch_decode_module as gen_batch_decode_module + from .attention import gen_batch_mla_module as gen_batch_mla_module + from .attention import get_batch_mla_uri as get_batch_mla_uri from .attention import gen_batch_prefill_module as gen_batch_prefill_module from .attention import ( gen_customize_batch_decode_module as gen_customize_batch_decode_module, diff --git a/flashinfer/jit/attention/__init__.py b/flashinfer/jit/attention/__init__.py index b749d5d012..9ffcf219a9 100644 --- a/flashinfer/jit/attention/__init__.py +++ b/flashinfer/jit/attention/__init__.py @@ -72,6 +72,8 @@ ) from .modules_hip import gen_single_decode_module as gen_single_decode_module from .modules_hip import gen_single_prefill_module as gen_single_prefill_module + from .modules_hip import gen_batch_mla_module as gen_batch_mla_module + from .modules_hip import get_batch_mla_uri as get_batch_mla_uri from .modules_hip import get_batch_decode_uri as get_batch_decode_uri from .modules_hip import get_batch_prefill_uri as get_batch_prefill_uri from .modules_hip import get_single_decode_uri as get_single_decode_uri diff --git a/flashinfer/jit/attention/modules_hip.py b/flashinfer/jit/attention/modules_hip.py index 28df48ef01..c93ae7528b 100644 --- a/flashinfer/jit/attention/modules_hip.py +++ b/flashinfer/jit/attention/modules_hip.py @@ -943,3 +943,83 @@ def gen_customize_batch_prefill_module( raise ValueError("FA3 backend not currently supported for ROCm") else: raise ValueError(f"Invalid backend: {backend}") + + +def get_batch_mla_uri( + backend: str, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim_ckv: int, + head_dim_kpe: int, + use_profiler: bool, +) -> str: + return ( + f"batch_mla_attention_dtype_q_{filename_safe_dtype_map[dtype_q]}_" + f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" + f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" + f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" + f"head_dim_ckv_{head_dim_ckv}_" + f"head_dim_kpe_{head_dim_kpe}_" + f"profiler_{use_profiler}_hip" + ) + + +def gen_batch_mla_module( + backend: str, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim_ckv: int, + head_dim_kpe: int, + use_profiler: bool, +) -> JitSpec: + if backend == "auto": + raise ValueError("backend should not be auto when jit_args is provided") + uri = get_batch_mla_uri( + backend, + dtype_q, + dtype_kv, + dtype_o, + dtype_idx, + head_dim_ckv, + head_dim_kpe, + use_profiler, + ) + gen_directory = FLASHINFER_GEN_SRC_DIR / uri + os.makedirs(gen_directory, exist_ok=True) + + with open(FLASHINFER_CSRC_DIR / "batch_mla_customize_config.jinja") as f: + config_templ = jinja2.Template(f.read()) + + generated_inc_str = config_templ.render( + dtype_q=dtype_map_hip[dtype_q], + dtype_kv=dtype_map_hip[dtype_kv], + dtype_o=dtype_map_hip[dtype_o], + dtype_idx=dtype_map_hip[dtype_idx], + head_dim_ckv=head_dim_ckv, + head_dim_kpe=head_dim_kpe, + ) + + source_paths = [] + for filename in [ + "batch_mla.cu", + "batch_mla_jit_pybind.cu", + ]: + src_path = FLASHINFER_CSRC_DIR / filename + dest_path = gen_directory / filename + source_paths.append(dest_path) + with open(src_path, "r") as f: + source = f.read() + write_if_different(dest_path, source) + + generated_config_path = gen_directory / "batch_mla_config.inc" + write_if_different(generated_config_path, generated_inc_str) + + extra_cuda_cflags = [] + if use_profiler: + extra_cuda_cflags += ["-DFLASHINFER_ENABLE_PROFILER"] + + return gen_jit_spec(uri, source_paths, extra_cuda_cflags=extra_cuda_cflags) From 8df86cc17ad58509ab493120354ff2dabb2e78d5 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 5 May 2026 13:09:21 +0000 Subject: [PATCH 04/19] feat(hip): add batch MLA kernel source files (Phase 1: Plan only) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit batch_mla_customize_config.jinja: Jinja2 template rendered by gen_batch_mla_module into batch_mla_config.inc. Defines DTypeQ/KV/O, IdType, HEAD_DIM_CKV/KPE, and the DISPATCH_context macro that will be used by the run kernel in Phase 2. batch_mla.cu: implements BatchMLAPagedAttentionPlan by calling MLAPlan from flashinfer/attention/generic/scheduler.cuh. BatchMLAPagedAttentionRun is a stub that raises at call time — the run kernel is Phase 2. batch_mla_jit_pybind.cu: torch library fragment declaring the plan and run ops under the extension name, included by the JIT build system. Co-Authored-By: Claude Sonnet 4.6 --- flashinfer/csrc_rocm/batch_mla.cu | 51 +++++++++++++++++++ .../batch_mla_customize_config.jinja | 31 +++++++++++ flashinfer/csrc_rocm/batch_mla_jit_pybind.cu | 23 +++++++++ 3 files changed, 105 insertions(+) create mode 100644 flashinfer/csrc_rocm/batch_mla.cu create mode 100644 flashinfer/csrc_rocm/batch_mla_customize_config.jinja create mode 100644 flashinfer/csrc_rocm/batch_mla_jit_pybind.cu diff --git a/flashinfer/csrc_rocm/batch_mla.cu b/flashinfer/csrc_rocm/batch_mla.cu new file mode 100644 index 0000000000..f3f30bf51f --- /dev/null +++ b/flashinfer/csrc_rocm/batch_mla.cu @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 +#include + +#include + +#include "batch_mla_config.inc" +#include "pytorch_conversion_utils.h" +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + at::Tensor pin_memory_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, + at::Tensor kv_len_arr, int64_t num_heads, int64_t head_dim_o, + bool causal) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + MLAPlanInfo plan_info; + int64_t batch_size = kv_len_arr.size(0); + + const c10::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(float_workspace_buffer.device()); + const hipStream_t stream = c10::hip::getCurrentHIPStream(); + + hipError_t status = + MLAPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), pin_memory_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, static_cast(qo_indptr.data_ptr()), + static_cast(kv_indptr.data_ptr()), + static_cast(kv_len_arr.data_ptr()), static_cast(batch_size), + static_cast(num_heads), static_cast(head_dim_o), causal, stream); + + TORCH_CHECK(status == hipSuccess, + "BatchMLAPagedAttentionPlan failed with error: ", hipGetErrorString(status)); + + return vec_to_tensor(plan_info.ToVector()); +} + +void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe, + at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices, + at::Tensor o, std::optional maybe_lse, + int64_t mask_mode_code, int64_t num_heads, int64_t page_size, + double sm_scale, bool return_lse_base_on_e ADDITIONAL_FUNC_PARAMS) { + TORCH_CHECK(false, "BatchMLAPagedAttentionRun: not yet implemented on HIP/ROCm"); +} diff --git a/flashinfer/csrc_rocm/batch_mla_customize_config.jinja b/flashinfer/csrc_rocm/batch_mla_customize_config.jinja new file mode 100644 index 0000000000..c654a2f1a7 --- /dev/null +++ b/flashinfer/csrc_rocm/batch_mla_customize_config.jinja @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// NOTE: This file is generated from batch_mla_customize_config.jinja. +// Do not edit manually. +#pragma once +#include + +using namespace flashinfer; + +#ifdef FLASHINFER_ENABLE_PROFILER +#define ADDITIONAL_FUNC_PARAMS , at::Tensor profiler_buffer +#define ADDITIONAL_PARAMS_SETTER \ + params.profiler_buffer = static_cast(profiler_buffer.data_ptr()); +#else +#define ADDITIONAL_FUNC_PARAMS +#define ADDITIONAL_PARAMS_SETTER +#endif + +using DTypeQ = {{ dtype_q }}; +using DTypeKV = {{ dtype_kv }}; +using DTypeO = {{ dtype_o }}; +using IdType = {{ dtype_idx }}; +constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }}; +constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }}; + +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \ + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ + using Params = MLAParams; \ + __VA_ARGS__(); \ + }) diff --git a/flashinfer/csrc_rocm/batch_mla_jit_pybind.cu b/flashinfer/csrc_rocm/batch_mla_jit_pybind.cu new file mode 100644 index 0000000000..dc8c1d1302 --- /dev/null +++ b/flashinfer/csrc_rocm/batch_mla_jit_pybind.cu @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 +#include "batch_mla_config.inc" +#include "pytorch_extension_utils.h" + +at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + at::Tensor pin_memory_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, + at::Tensor kv_len_arr, int64_t num_heads, int64_t head_dim_o, + bool causal); + +void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe, + at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices, + at::Tensor o, std::optional maybe_lse, + int64_t mask_mode_code, int64_t num_heads, int64_t page_size, + double sm_scale, bool return_lse_base_on_e ADDITIONAL_FUNC_PARAMS); + +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + m.def("plan", BatchMLAPagedAttentionPlan); + m.def("run", BatchMLAPagedAttentionRun); +} From d7a0e3f6017f045b4f7c98468d5d853ababc1396 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 12 May 2026 06:17:11 +0000 Subject: [PATCH 05/19] test(hip): add test_mla_hip.py for Phase 1 MLA plan validation - test_determine_mla_backend: verify determine_mla_backend returns "hip" on ROCm - test_batch_mla_plan: verify BatchMLAPagedAttentionPlan succeeds for typical DeepSeek-v2/v3 dims (head_dim_ckv=512, head_dim_kpe=64) across batch sizes, kv lengths, page sizes, and causal settings - test_batch_mla_run_raises: verify BatchMLAPagedAttentionRun raises RuntimeError with "not yet implemented on HIP" until Phase 2 is complete 39 tests, all passing. Co-Authored-By: Claude Sonnet 4.6 --- tests/rocm_tests/test_mla_hip.py | 122 +++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 tests/rocm_tests/test_mla_hip.py diff --git a/tests/rocm_tests/test_mla_hip.py b/tests/rocm_tests/test_mla_hip.py new file mode 100644 index 0000000000..a60f3f5399 --- /dev/null +++ b/tests/rocm_tests/test_mla_hip.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +# HIP/ROCm tests for MLA (Multi-head Latent Attention) batch paged attention. +# +# Phase 1 coverage (plan-only): +# - test_determine_mla_backend: verify determine_mla_backend returns "hip" +# - test_batch_mla_plan: verify BatchMLAPagedAttentionPlan runs without error +# for typical DeepSeek-v2/v3 head dims across batch sizes and dtypes +# - test_batch_mla_run_raises: verify BatchMLAPagedAttentionRun raises +# RuntimeError with a clear "not yet implemented" message (Phase 2 stub) + +import math + +import pytest +import torch + +import flashinfer +import flashinfer.mla +from flashinfer.jit import build_jit_specs, gen_batch_mla_module +from flashinfer.utils import determine_mla_backend + +HEAD_DIM_CKV = 512 +HEAD_DIM_KPE = 64 + + +@pytest.fixture(autouse=True, scope="module") +def warmup_jit(): + build_jit_specs( + [ + gen_batch_mla_module( + "hip", + torch.float16, + torch.float16, + torch.float16, + torch.int32, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + False, + ) + ], + verbose=False, + ) + yield + + +def _make_wrapper(batch_size, dtype): + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda") + return flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace, backend="auto") + + +def _plan(wrapper, batch_size, kv_len, page_size, causal, dtype): + qo_len = 1 # decode: one query token per request + pages_per_req = math.ceil(kv_len / page_size) + q_indptr = ( + torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * qo_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") + * pages_per_req + ) + kv_indices = torch.arange( + 0, batch_size * pages_per_req, dtype=torch.int32, device="cuda" + ) + kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda") + sm_scale = 1.0 / ((128 + 64) ** 0.5) + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_lens, + 16, # num_heads + HEAD_DIM_CKV, + HEAD_DIM_KPE, + page_size, + causal, + sm_scale, + dtype, + dtype, + ) + return kv_lens, pages_per_req + + +def test_determine_mla_backend(): + device = torch.device("cuda") + backend = determine_mla_backend(device) + assert backend == "hip", f"expected 'hip', got {backend!r}" + + +@pytest.mark.parametrize("batch_size", [1, 4, 16]) +@pytest.mark.parametrize("kv_len", [1, 64, 512]) +@pytest.mark.parametrize("page_size", [1, 16]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_batch_mla_plan(batch_size, kv_len, page_size, causal, dtype): + if causal and kv_len == 0: + pytest.skip("causal with kv_len=0 unsupported") + wrapper = _make_wrapper(batch_size, dtype) + # plan must not raise + _plan(wrapper, batch_size, kv_len, page_size, causal, dtype) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("kv_len", [64]) +@pytest.mark.parametrize("page_size", [16]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_batch_mla_run_raises(batch_size, kv_len, page_size, dtype): + wrapper = _make_wrapper(batch_size, dtype) + kv_lens, pages_per_req = _plan(wrapper, batch_size, kv_len, page_size, False, dtype) + + q_nope = torch.randn(batch_size, 16, HEAD_DIM_CKV, dtype=dtype, device="cuda") + q_pe = torch.randn(batch_size, 16, HEAD_DIM_KPE, dtype=dtype, device="cuda") + ckv = torch.randn( + batch_size * pages_per_req, page_size, HEAD_DIM_CKV, dtype=dtype, device="cuda" + ) + kpe = torch.randn( + batch_size * pages_per_req, page_size, HEAD_DIM_KPE, dtype=dtype, device="cuda" + ) + + with pytest.raises(RuntimeError, match="not yet implemented on HIP"): + wrapper.run(q_nope, q_pe, ckv, kpe) From b40b4355981985a536a2d0a100a04661c7d0ccce Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 12 May 2026 14:30:13 +0000 Subject: [PATCH 06/19] feat(hip): implement MLA attention Run path with CDNA3 MFMA correctness MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wires up the BatchMLAPagedAttentionRun host stub to a new correctness-first HIP MLA kernel and expands the test to cover end-to-end output, not just plan. Phase 1 kernel (mla_hip.cuh) is FA2-style: 256 threads / 4 wavefronts, each wave owns one head_dim shard, KV streamed through LDS one tile at a time (no swizzle, no pipelining — those are Phase 2). Two CDNA3 MFMA subtleties are load-bearing: 1. mfma_f32_16x16x16f16 produces D in column-major per-thread layout (thread t holds D[m=(t/16)*4+r][n=t%16]). To keep s_frag/o_frag in the more natural row-major form (1 row, 4 cols per thread), we swap the A/B operands in both MFMAs (A=K,B=Q for QK; A=V,B=P for PV) — this makes the result of D = A*B equivalent to the desired matmul transposed, which combined with the column-major output layout yields row-major per-thread fragments without the per-iteration in-register transpose that the generic prefill path uses. 2. partial_lse must be log2(d) + m * sm_scale_log2 (a true log-sum-exp in log2 base), not log2(d) + m. The downstream merge weights partials by 2^(lse_p - global_max), which is only proportional to the true softmax weight d_p * exp(m_p * sm_scale) when m is in the same scaled space. Mirrors the CUDA path's finalize_m step. Co-Authored-By: Claude Sonnet 4.6 --- flashinfer/csrc_rocm/batch_mla.cu | 89 +++- include/flashinfer/attention/mla_hip.cuh | 566 +++++++++++++++++++++++ tests/rocm_tests/test_mla_hip.py | 76 ++- 3 files changed, 715 insertions(+), 16 deletions(-) create mode 100644 include/flashinfer/attention/mla_hip.cuh diff --git a/flashinfer/csrc_rocm/batch_mla.cu b/flashinfer/csrc_rocm/batch_mla.cu index f3f30bf51f..a83ea7c7fb 100644 --- a/flashinfer/csrc_rocm/batch_mla.cu +++ b/flashinfer/csrc_rocm/batch_mla.cu @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 #include -#include +#include #include "batch_mla_config.inc" #include "pytorch_conversion_utils.h" @@ -47,5 +47,90 @@ void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int at::Tensor o, std::optional maybe_lse, int64_t mask_mode_code, int64_t num_heads, int64_t page_size, double sm_scale, bool return_lse_base_on_e ADDITIONAL_FUNC_PARAMS) { - TORCH_CHECK(false, "BatchMLAPagedAttentionRun: not yet implemented on HIP/ROCm"); + MLAPlanInfo plan_info; + plan_info.FromVector(tensor_to_vec(plan_info_vec)); + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + const MaskMode mask_mode = static_cast(mask_mode_code); + + const c10::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(q_nope.device()); + const hipStream_t stream = c10::hip::getCurrentHIPStream(); + + unsigned int q_nope_stride_n = q_nope.stride(0); + unsigned int q_nope_stride_h = q_nope.stride(1); + unsigned int q_pe_stride_n = q_pe.stride(0); + unsigned int q_pe_stride_h = q_pe.stride(1); + unsigned int ckv_stride_page = ckv_cache.stride(0); + unsigned int ckv_stride_n = ckv_cache.stride(1); + unsigned int kpe_stride_page = kpe_cache.stride(0); + unsigned int kpe_stride_n = kpe_cache.stride(1); + unsigned int o_stride_n = o.stride(0); + unsigned int o_stride_h = o.stride(1); + + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] { + Params params; + + params.q_nope = static_cast(q_nope.data_ptr()); + params.q_pe = static_cast(q_pe.data_ptr()); + params.ckv = static_cast(ckv_cache.data_ptr()); + params.kpe = static_cast(kpe_cache.data_ptr()); + + params.q_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.partial_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.partial_indptr_offset); + params.kv_indices = static_cast(kv_indices.data_ptr()); + params.q_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_len_offset); + params.kv_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.q_start = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_start_offset); + params.kv_start = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_start_offset); + params.kv_end = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_end_offset); + params.work_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + params.merge_packed_offset_start = GetPtrFromBaseOffset( + int_buffer_ptr, plan_info.merge_packed_offset_start_offset); + params.merge_packed_offset_end = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_packed_offset_end_offset); + params.merge_partial_packed_offset_start = GetPtrFromBaseOffset( + int_buffer_ptr, plan_info.merge_partial_packed_offset_start_offset); + params.merge_partial_packed_offset_end = GetPtrFromBaseOffset( + int_buffer_ptr, plan_info.merge_partial_packed_offset_end_offset); + params.merge_partial_stride = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_partial_stride_offset); + + params.final_o = static_cast(o.data_ptr()); + params.final_lse = + maybe_lse.has_value() ? static_cast(maybe_lse.value().data_ptr()) : nullptr; + params.partial_o = + GetPtrFromBaseOffset(float_buffer_ptr, plan_info.partial_o_offset); + params.partial_lse = + GetPtrFromBaseOffset(float_buffer_ptr, plan_info.partial_lse_offset); + + params.num_heads = uint_fastdiv(num_heads); + params.block_size = uint_fastdiv(page_size); + + params.q_nope_stride_n = q_nope_stride_n; + params.q_nope_stride_h = q_nope_stride_h; + params.q_pe_stride_n = q_pe_stride_n; + params.q_pe_stride_h = q_pe_stride_h; + params.ckv_stride_page = ckv_stride_page; + params.ckv_stride_n = ckv_stride_n; + params.kpe_stride_page = kpe_stride_page; + params.kpe_stride_n = kpe_stride_n; + params.o_stride_n = o_stride_n; + params.o_stride_h = o_stride_h; + + params.sm_scale = static_cast(sm_scale); + params.return_lse_base_on_e = return_lse_base_on_e; + + ADDITIONAL_PARAMS_SETTER + + hipError_t status = mla::BatchMLAPagedAttentionHIP( + params, plan_info.num_blks_x, plan_info.num_blks_y, stream); + TORCH_CHECK(status == hipSuccess, + "BatchMLAPagedAttentionRun failed: ", hipGetErrorString(status)); + }); } diff --git a/include/flashinfer/attention/mla_hip.cuh b/include/flashinfer/attention/mla_hip.cuh new file mode 100644 index 0000000000..e9f86fae30 --- /dev/null +++ b/include/flashinfer/attention/mla_hip.cuh @@ -0,0 +1,566 @@ +// SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// Phase 1 correctness-first HIP MLA kernel for CDNA3 (MI300X). +// +// CTA config: dim3(64,4,1) — 4 wavefronts, 256 threads. +// All 4 wavefronts replicate QK independently; wave w owns +// head_dim shard [w * (HEAD_DIM_CKV/4) .. (w+1) * (HEAD_DIM_CKV/4)). +// Q is re-read from global memory each mma_d step to avoid the 64 KB LDS limit. +// +// MFMA layout note: CDNA3 mfma_f32_16x16x16f16 produces D in column-major +// per-thread layout (each thread holds D[m=(t/16)*4+r][n=t%16] for r=0..3). +// To get a more natural row-major s_frag/o_frag (each thread holds 1 row, 4 +// cols), we swap the A and B operands in both MFMAs: +// QK: A=K, B=Q → D[m=kv][n=q] = S[q][kv]; thread t's r-th reg = S[q=t%16][kv=col_group*4+r] +// PV: A=V, B=P → D[m=d][n=q] = O[q][d]; thread t's r-th reg = +// O[q=t%16][d=mma_d_abs*16+col_group*4+r] +// This avoids per-iteration in-register transposes of the s_frag tile. +#ifndef FLASHINFER_MLA_HIP_CUH_ +#define FLASHINFER_MLA_HIP_CUH_ + +#include + +#include "../fastdiv.cuh" +#include "gpu_iface/cooperative_groups.h" +#include "gpu_iface/enums.hpp" +#include "gpu_iface/math_ops.hpp" +#include "gpu_iface/mma_ops.hpp" +#include "gpu_iface/mma_types.hpp" +#include "gpu_iface/utils.cuh" +#include "mla_params.cuh" + +namespace flashinfer { +namespace mla { + +constexpr uint32_t MLA_HIP_WAVE_SIZE = 64; +constexpr uint32_t MLA_HIP_NUM_WAVES = 4; +constexpr uint32_t MLA_HIP_NUM_THREADS = MLA_HIP_WAVE_SIZE * MLA_HIP_NUM_WAVES; // 256 +constexpr uint32_t MLA_HIP_CTA_TILE_Q = 16; +constexpr uint32_t MLA_HIP_CTA_TILE_KV = 16; +// Sentinel for the running max in online softmax. Matches the CUDA path +// (state_t::init uses -math::inf, which has the same numeric value 5e4f). +constexpr float MLA_HIP_NEG_INF = -5e4f; + +// One KV tile in LDS, no swizzle. Phase 2 will introduce swizzling/pipelining. +template +struct SharedStorageMLAHIP { + DTypeKV ckv_smem[MLA_HIP_CTA_TILE_KV][HEAD_DIM_CKV]; + DTypeKV kpe_smem[MLA_HIP_CTA_TILE_KV][HEAD_DIM_KPE]; +}; + +// --------------------------------------------------------------------------- +// load_kv_hip: cooperative KV tile load, all 256 threads participate. +// packed_kv_tile_base = kv_indptr * block_size + kv_tile_abs_start +// packed_kv_bound = kv_indptr * block_size + kv_len +// --------------------------------------------------------------------------- +template +__device__ __forceinline__ void load_kv_hip( + DTypeKV* __restrict__ ckv_smem, DTypeKV* __restrict__ kpe_smem, + const DTypeKV* __restrict__ ckv_global, const DTypeKV* __restrict__ kpe_global, + const IdType* __restrict__ kv_indices, uint32_t ckv_stride_page, uint32_t ckv_stride_n, + uint32_t kpe_stride_page, uint32_t kpe_stride_n, uint32_t packed_kv_tile_base, + uint32_t packed_kv_bound, const uint_fastdiv& block_size, uint32_t tid) { + // CKV: CTA_TILE_KV * HEAD_DIM_CKV fp16 = CTA_TILE_KV * (HEAD_DIM_CKV/4) uint2 + constexpr uint32_t CKV_U2_PER_ROW = HEAD_DIM_CKV / 4; + constexpr uint32_t CKV_TOTAL_U2 = MLA_HIP_CTA_TILE_KV * CKV_U2_PER_ROW; + constexpr uint32_t CKV_U2_PER_THREAD = CKV_TOTAL_U2 / MLA_HIP_NUM_THREADS; + static_assert(CKV_TOTAL_U2 % MLA_HIP_NUM_THREADS == 0, "CKV load not evenly divisible"); + + uint2* smem_ckv = reinterpret_cast(ckv_smem); +#pragma unroll + for (uint32_t i = 0; i < CKV_U2_PER_THREAD; ++i) { + uint32_t u2_idx = tid * CKV_U2_PER_THREAD + i; + uint32_t kv_row = u2_idx / CKV_U2_PER_ROW; + uint32_t col_u2 = u2_idx % CKV_U2_PER_ROW; + uint32_t packed = packed_kv_tile_base + kv_row; + uint32_t page_idx, row_in_page; + block_size.divmod(packed, page_idx, row_in_page); + bool valid = (packed < packed_kv_bound); + const DTypeKV* gptr = ckv_global + + (valid ? kv_indices[page_idx] : IdType(0)) * ckv_stride_page + + row_in_page * ckv_stride_n + col_u2 * 4; + smem_ckv[u2_idx] = valid ? *reinterpret_cast(gptr) : uint2{0u, 0u}; + } + + // KPE: CTA_TILE_KV * HEAD_DIM_KPE fp16 = CTA_TILE_KV * (HEAD_DIM_KPE/4) uint2 + constexpr uint32_t KPE_U2_PER_ROW = HEAD_DIM_KPE / 4; + constexpr uint32_t KPE_TOTAL_U2 = MLA_HIP_CTA_TILE_KV * KPE_U2_PER_ROW; + constexpr uint32_t KPE_U2_PER_THREAD = KPE_TOTAL_U2 / MLA_HIP_NUM_THREADS; + static_assert(KPE_TOTAL_U2 % MLA_HIP_NUM_THREADS == 0, "KPE load not evenly divisible"); + + uint2* smem_kpe = reinterpret_cast(kpe_smem); +#pragma unroll + for (uint32_t i = 0; i < KPE_U2_PER_THREAD; ++i) { + uint32_t u2_idx = tid * KPE_U2_PER_THREAD + i; + uint32_t kv_row = u2_idx / KPE_U2_PER_ROW; + uint32_t col_u2 = u2_idx % KPE_U2_PER_ROW; + uint32_t packed = packed_kv_tile_base + kv_row; + uint32_t page_idx, row_in_page; + block_size.divmod(packed, page_idx, row_in_page); + bool valid = (packed < packed_kv_bound); + const DTypeKV* gptr = kpe_global + + (valid ? kv_indices[page_idx] : IdType(0)) * kpe_stride_page + + row_in_page * kpe_stride_n + col_u2 * 4; + smem_kpe[u2_idx] = valid ? *reinterpret_cast(gptr) : uint2{0u, 0u}; + } +} + +// --------------------------------------------------------------------------- +// compute_qk_hip: accumulate s_frag = Q_pe*KPE^T + Q_nope*CKV^T +// +// CDNA3 MFMA computes D = A*B where the D output is column-major per-thread: +// thread t holds D[m=(t/16)*4+r][n=t%16]. To get the natural row-major s_frag +// layout (thread t holds s_frag[r] = S[q=t%16][kv=col_group*4+r]) we swap the +// A and B operands: A=K, B=Q. The math: +// D[m=kv][n=q] = sum_d K[kv=m][d=k] * Q[q=n][d=k] = (K@Q^T)[kv][q] = S[q][kv]. +// Combined with the column-major D layout, thread t's r-th register = +// S[q=t%16][kv=(t/16)*4+r]. +// --------------------------------------------------------------------------- +template +__device__ __forceinline__ void compute_qk_hip(float* s_frag, const DTypeQ* q_nope, + const DTypeQ* q_pe, uint32_t q_nope_stride_n, + uint32_t q_nope_stride_h, uint32_t q_pe_stride_n, + uint32_t q_pe_stride_h, uint32_t qo_packed_idx_base, + const uint_fastdiv& num_heads, uint32_t q_len, + const DTypeKV* ckv_smem, const DTypeKV* kpe_smem, + uint32_t lane_idx) { + constexpr uint32_t NUM_MMA_D_CKV = HEAD_DIM_CKV / 16; + constexpr uint32_t NUM_MMA_D_KPE = HEAD_DIM_KPE / 16; + constexpr uint32_t CKV_U2_PER_ROW = HEAD_DIM_CKV / 4; + constexpr uint32_t KPE_U2_PER_ROW = HEAD_DIM_KPE / 4; + + uint32_t q_row = lane_idx % 16; + uint32_t col_group = lane_idx / 16; + uint32_t q_packed = qo_packed_idx_base + q_row; + uint32_t batch_idx, head_idx; + num_heads.divmod(q_packed, batch_idx, head_idx); + bool q_valid = (batch_idx < q_len); + + const uint2* smem_ckv = reinterpret_cast(ckv_smem); + const uint2* smem_kpe = reinterpret_cast(kpe_smem); + + // KPE tiles run first — kInit on tile 0 zeros the accumulator before CKV accumulates. +#pragma unroll + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_KPE; ++mma_d) { + uint32_t q_frag[2] = {0u, 0u}; + if (q_valid) { + const DTypeQ* ptr = + q_pe + batch_idx * q_pe_stride_n + head_idx * q_pe_stride_h + mma_d * 16 + col_group * 4; + *reinterpret_cast(q_frag) = *reinterpret_cast(ptr); + } + uint32_t k_frag[2]; + *reinterpret_cast(k_frag) = smem_kpe[q_row * KPE_U2_PER_ROW + mma_d * 4 + col_group]; + + if (mma_d == 0) { + gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32( + s_frag, k_frag, q_frag); + } else { + gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32(s_frag, k_frag, q_frag); + } + } + +#pragma unroll + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV; ++mma_d) { + uint32_t q_frag[2] = {0u, 0u}; + if (q_valid) { + const DTypeQ* ptr = q_nope + batch_idx * q_nope_stride_n + head_idx * q_nope_stride_h + + mma_d * 16 + col_group * 4; + *reinterpret_cast(q_frag) = *reinterpret_cast(ptr); + } + uint32_t k_frag[2]; + *reinterpret_cast(k_frag) = smem_ckv[q_row * CKV_U2_PER_ROW + mma_d * 4 + col_group]; + gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32(s_frag, k_frag, q_frag); + } +} + +// --------------------------------------------------------------------------- +// logits_mask_hip: zero out scores for out-of-bounds or causally-masked positions. +// s_frag[r] = S[q_row=lane%16][kv_col = kv_idx_base + (lane/16)*4 + r] +// --------------------------------------------------------------------------- +template +__device__ __forceinline__ void logits_mask_hip(float* s_frag, uint32_t qo_packed_idx_base, + uint32_t kv_idx_base, uint32_t q_len, + uint32_t kv_len, uint32_t kv_end, + const uint_fastdiv& num_heads, uint32_t lane_idx) { + uint32_t q_idx = (qo_packed_idx_base + lane_idx % 16) / num_heads; + uint32_t col_group = lane_idx / 16; + +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) { + uint32_t kv_idx = kv_idx_base + col_group * 4 + r; + bool out_of_bound = (kv_idx >= kv_end); + bool causal_masked = CAUSAL && (kv_idx + q_len > kv_len + q_idx); + if (out_of_bound || causal_masked) s_frag[r] = MLA_HIP_NEG_INF; + } +} + +// --------------------------------------------------------------------------- +// update_mdo_states_hip: online softmax — rescale (m, d, o), then exp(s - m) +// in-place into s_frag and accumulate the row sum into d_scalar. +// +// Layout: each thread holds 4 kv values for one q_row (s_frag[r] = S[t%16][col_group*4+r]). +// Row max/sum reduce within the thread, then butterfly-XOR over bits 4 and 5 +// (the col_group dimension) to combine across all 4 threads with the same q_row. +// --------------------------------------------------------------------------- +template +__device__ __forceinline__ void update_mdo_states_hip(float* s_frag, float (*o_frag)[4], + float& m_val, float& d_scalar, + float sm_scale_log2) { + float m_local = fmaxf(fmaxf(s_frag[0], s_frag[1]), fmaxf(s_frag[2], s_frag[3])); + m_local = fmaxf(m_local, math::shfl_xor_sync(m_local, 0x10)); + m_local = fmaxf(m_local, math::shfl_xor_sync(m_local, 0x20)); + + float m_prev = m_val; + m_val = fmaxf(m_prev, m_local); + + float o_scale = math::ptx_exp2((m_prev - m_val) * sm_scale_log2); + d_scalar *= o_scale; + +#pragma unroll + for (uint32_t d = 0; d < NUM_MMA_D_PER_WAVE; ++d) { +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) o_frag[d][r] *= o_scale; + } + + const float m_scaled = m_val * sm_scale_log2; + float partial_d = 0.f; +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) { + float p = math::ptx_exp2(s_frag[r] * sm_scale_log2 - m_scaled); + s_frag[r] = p; + partial_d += p; + } + partial_d += math::shfl_xor_sync(partial_d, 0x10); + partial_d += math::shfl_xor_sync(partial_d, 0x20); + d_scalar += partial_d; +} + +// --------------------------------------------------------------------------- +// compute_pv_hip: accumulate P*V into o_frag for this wave's head_dim shard. +// s_frag must already hold exp-scaled values (output of update_mdo_states_hip). +// +// As with compute_qk_hip, we swap A/B so the column-major D output gives a +// row-major o_frag: thread t's r-th register = O[q=t%16][d=mma_d_abs*16+col_group*4+r]. +// We pass A=V (loaded strided so f16x4[j]=V[kv=col_group*4+j][d=mma_d_abs*16+t%16]) +// and B=P (s_frag, with f16x4[r]=P[q=t%16][kv=col_group*4+r]). +// The math: D[m=d_local][n=q] = sum_kv V[kv][d] * P[q][kv] = O[q][d_global]. +// --------------------------------------------------------------------------- +template +__device__ __forceinline__ void compute_pv_hip(float (*o_frag)[4], const float* s_frag, + const DTypeKV* ckv_smem, uint32_t wave_idx, + uint32_t lane_idx) { + uint32_t q_row = lane_idx % 16; + uint32_t col_group = lane_idx / 16; + uint32_t d_start = wave_idx * NUM_MMA_D_PER_WAVE; + + // Pack s_frag (4 floats) into 2 uint32_t in fp16 format for MFMA B input + DTypeKV p_f16[4]; +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) p_f16[r] = static_cast(s_frag[r]); + uint32_t p_frag[2]; + p_frag[0] = *reinterpret_cast(&p_f16[0]); + p_frag[1] = *reinterpret_cast(&p_f16[2]); + +#pragma unroll + for (uint32_t wave_mma_d = 0; wave_mma_d < NUM_MMA_D_PER_WAVE; ++wave_mma_d) { + uint32_t mma_d_abs = d_start + wave_mma_d; + // Load V via 4 strided scalar reads: V[kv=col_group*4+j][d=mma_d_abs*16+q_row]. + // Used as A operand: A[m=t%16][k=col_group*4+j] = V[kv=k][d=mma_d_abs*16+m]. + uint32_t d_col = mma_d_abs * 16 + q_row; + DTypeKV v_vals[4]; +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { + v_vals[j] = ckv_smem[(col_group * 4 + j) * HEAD_DIM_CKV + d_col]; + } + uint32_t v_frag[2]; + v_frag[0] = (uint32_t)(*reinterpret_cast(&v_vals[0])) | + ((uint32_t)(*reinterpret_cast(&v_vals[1])) << 16); + v_frag[1] = (uint32_t)(*reinterpret_cast(&v_vals[2])) | + ((uint32_t)(*reinterpret_cast(&v_vals[3])) << 16); + gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32(o_frag[wave_mma_d], v_frag, + p_frag); + } +} + +// --------------------------------------------------------------------------- +// normalize_d_hip: divide o_frag by the softmax denominator. +// --------------------------------------------------------------------------- +template +__device__ __forceinline__ void normalize_d_hip(float (*o_frag)[4], float m_val, float d_scalar) { + float d_rcp = (m_val != MLA_HIP_NEG_INF) ? math::ptx_rcp(d_scalar) : 0.f; +#pragma unroll + for (uint32_t d = 0; d < NUM_MMA_D_PER_WAVE; ++d) { +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) o_frag[d][r] *= d_rcp; + } +} + +// --------------------------------------------------------------------------- +// write_o_hip: store results to final_o or partial_o. +// +// Each thread writes for q_row = lane_idx%16 and its wave's head_dim shard. +// Only wave_idx==0 && col_group==0 (lane_idx<16) writes the LSE scalar. +// +// partial_o / partial_lse are the GLOBAL arrays; partial_indptr is the start +// row index into these arrays for this CTA's work item. Pass -1 if no partial. +// --------------------------------------------------------------------------- +template +__device__ __forceinline__ void write_o_hip( + float (*o_frag)[4], float m_val, float d_scalar, DTypeO* final_o, float* final_lse, + DTypeO* partial_o, float* partial_lse, uint32_t o_stride_n, uint32_t o_stride_h, uint32_t q_len, + uint32_t qo_packed_idx_base, int32_t partial_indptr, const uint_fastdiv& num_heads, + bool return_lse_base_on_e, uint32_t lane_idx, uint32_t wave_idx) { + uint32_t q_row = lane_idx % 16; + uint32_t col_group = lane_idx / 16; + uint32_t q_packed = qo_packed_idx_base + q_row; + uint32_t batch_idx, head_idx; + num_heads.divmod(q_packed, batch_idx, head_idx); + bool q_valid = (batch_idx < q_len); + uint32_t d_start = wave_idx * NUM_MMA_D_PER_WAVE; + + if (partial_indptr >= 0) { + uint32_t partial_row = static_cast(partial_indptr) + q_row; + if (q_valid) { + // One thread per q_row writes the LSE (lane_idx < 16 means col_group == 0) + if (wave_idx == 0 && col_group == 0) { + partial_lse[partial_row] = math::ptx_log2(d_scalar) + m_val; + } +#pragma unroll + for (uint32_t wave_mma_d = 0; wave_mma_d < NUM_MMA_D_PER_WAVE; ++wave_mma_d) { + uint32_t head_dim_col = (d_start + wave_mma_d) * 16 + col_group * 4; + DTypeO* ptr = partial_o + partial_row * HEAD_DIM_CKV + head_dim_col; +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) ptr[r] = static_cast(o_frag[wave_mma_d][r]); + } + } + } else { + if (q_valid) { + if (final_lse && wave_idx == 0 && col_group == 0) { + float lse = math::ptx_log2(d_scalar) + m_val; + if (return_lse_base_on_e) lse *= math::loge2; + final_lse[batch_idx * (uint32_t)num_heads + head_idx] = lse; + } +#pragma unroll + for (uint32_t wave_mma_d = 0; wave_mma_d < NUM_MMA_D_PER_WAVE; ++wave_mma_d) { + uint32_t head_dim_col = (d_start + wave_mma_d) * 16 + col_group * 4; + DTypeO* ptr = final_o + batch_idx * o_stride_n + head_idx * o_stride_h + head_dim_col; +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) ptr[r] = static_cast(o_frag[wave_mma_d][r]); + } + } + } +} + +// --------------------------------------------------------------------------- +// DevicePersistentMergeStatesHIP: reduce partial outputs → final output. +// +// partial_o[row * HEAD_DIM_CKV + col] stores DTypeO (normalized by each partial's d). +// partial_lse[row] = log2(d) + m for each partial block. +// Merge treats partial_lse as the combined log-sum-exp weight for each partial. +// --------------------------------------------------------------------------- +template +__device__ void DevicePersistentMergeStatesHIP( + const IdType* merge_packed_offset_start, const IdType* merge_packed_offset_end, + const IdType* merge_partial_packed_offset_start, const IdType* merge_partial_packed_offset_end, + const IdType* merge_partial_stride, const DTypeO* partial_o, const float* partial_lse, + DTypeO* final_o, float* final_lse, uint32_t o_stride_n, uint32_t o_stride_h, + const uint_fastdiv& num_heads, bool return_lse_base_on_e) { + constexpr uint32_t VEC_SIZE = 8; + constexpr uint32_t NUM_THRS_PER_ROW = HEAD_DIM_CKV / VEC_SIZE; + constexpr uint32_t ROWS_PER_ITER = MLA_HIP_NUM_THREADS / NUM_THRS_PER_ROW; + + uint32_t cta_idx = gridDim.x * blockIdx.y + blockIdx.x; + uint32_t thread_id = threadIdx.y * MLA_HIP_WAVE_SIZE + threadIdx.x; + + uint32_t offset_start = merge_packed_offset_start[cta_idx]; + uint32_t len = merge_packed_offset_end[cta_idx] - offset_start; + uint32_t partial_start = merge_partial_packed_offset_start[cta_idx]; + uint32_t partial_end = merge_partial_packed_offset_end[cta_idx]; + uint32_t stride = merge_partial_stride[cta_idx]; + + for (uint32_t local_row = thread_id / NUM_THRS_PER_ROW; local_row < len; + local_row += ROWS_PER_ITER) { + uint32_t global_packed = offset_start + local_row; + uint32_t q, r; + num_heads.divmod(global_packed, q, r); + uint32_t thr_col = thread_id % NUM_THRS_PER_ROW; + + // Accumulate state over partial blocks + // State representation: (st_m, st_d, st_o) where st_m tracks running max of partial_lse, + // st_d accumulates exp-weighted partial sums, st_o is the weighted output sum. + float st_o[VEC_SIZE]; +#pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) st_o[i] = 0.f; + float st_m = MLA_HIP_NEG_INF; + float st_d = 1.f; + + for (uint32_t pp = partial_start + local_row; pp < partial_end; pp += stride) { + float other_lse = partial_lse[pp]; // = log2(d_partial) + m_partial + const DTypeO* src = partial_o + (uint64_t)pp * HEAD_DIM_CKV + thr_col * VEC_SIZE; + + float new_m = fmaxf(st_m, other_lse); + float scale_st = math::ptx_exp2(st_m - new_m); + float scale_other = math::ptx_exp2(other_lse - new_m); + st_d = st_d * scale_st + scale_other; +#pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) + st_o[i] = st_o[i] * scale_st + static_cast(src[i]) * scale_other; + st_m = new_m; + } + + float d_rcp = (st_m != MLA_HIP_NEG_INF) ? math::ptx_rcp(st_d) : 0.f; + DTypeO* dst = final_o + q * o_stride_n + r * o_stride_h + thr_col * VEC_SIZE; +#pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) dst[i] = static_cast(st_o[i] * d_rcp); + + if (final_lse && thr_col == 0) { + float lse = st_m + math::ptx_log2(st_d); + if (return_lse_base_on_e) lse *= math::loge2; + final_lse[q * (uint32_t)num_heads + r] = lse; + } + } +} + +// --------------------------------------------------------------------------- +// BatchMLAPagedAttentionKernelHIP: main GPU kernel +// --------------------------------------------------------------------------- +template +__global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKernelHIP( + Params params) { + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + + constexpr uint32_t NUM_MMA_D_CKV = HEAD_DIM_CKV / 16; + constexpr uint32_t NUM_MMA_D_PER_WAVE = NUM_MMA_D_CKV / MLA_HIP_NUM_WAVES; + static_assert(NUM_MMA_D_CKV % MLA_HIP_NUM_WAVES == 0, + "HEAD_DIM_CKV must be divisible by 4 * 16 = 64"); + + uint32_t lane_idx = threadIdx.x; // 0..63 + uint32_t wave_idx = threadIdx.y; // 0..3 + uint32_t tid = wave_idx * MLA_HIP_WAVE_SIZE + lane_idx; // 0..255 + + extern __shared__ uint8_t smem[]; + using SharedStorage = SharedStorageMLAHIP; + SharedStorage& smem_storage = *reinterpret_cast(smem); + + const float sm_scale_log2 = params.sm_scale * math::log2e; + const uint_fastdiv& num_heads = params.num_heads; + const uint_fastdiv& block_size = params.block_size; + const uint32_t block_size_val = static_cast(block_size); + + float s_frag[4]; + float o_frag[NUM_MMA_D_PER_WAVE][4]; + float m_val, d_scalar; + + for (IdType work_idx = params.work_indptr[blockIdx.y]; + work_idx < params.work_indptr[blockIdx.y + 1]; ++work_idx) { + const uint32_t q_indptr = params.q_indptr[work_idx]; + const uint32_t kv_indptr = params.kv_indptr[work_idx]; + const int32_t partial_indptr = params.partial_indptr[work_idx]; + const uint32_t q_len = params.q_len[work_idx]; + const uint32_t kv_len = params.kv_len[work_idx]; + const uint32_t kv_start = params.kv_start[work_idx]; + const uint32_t kv_end = params.kv_end[work_idx]; + const uint32_t packed_qo_start = params.q_start[work_idx]; + + const uint32_t qo_packed_idx_base = packed_qo_start + blockIdx.x * MLA_HIP_CTA_TILE_Q; + + m_val = MLA_HIP_NEG_INF; + d_scalar = 1.f; +#pragma unroll + for (uint32_t d = 0; d < NUM_MMA_D_PER_WAVE; ++d) +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) o_frag[d][r] = 0.f; + + const uint32_t packed_kv_bound = kv_indptr * block_size_val + kv_len; + const int32_t num_kv_tiles = + static_cast(ceil_div(kv_end - kv_start, MLA_HIP_CTA_TILE_KV)); + + for (int32_t kv_tile_idx = num_kv_tiles - 1; kv_tile_idx >= 0; --kv_tile_idx) { + const uint32_t kv_tile_abs_start = + kv_start + static_cast(kv_tile_idx) * MLA_HIP_CTA_TILE_KV; + const uint32_t packed_kv_tile_base = kv_indptr * block_size_val + kv_tile_abs_start; + + load_kv_hip( + smem_storage.ckv_smem[0], smem_storage.kpe_smem[0], params.ckv, params.kpe, + params.kv_indices, params.ckv_stride_page, params.ckv_stride_n, params.kpe_stride_page, + params.kpe_stride_n, packed_kv_tile_base, packed_kv_bound, block_size, tid); + + __syncthreads(); + + compute_qk_hip( + s_frag, params.q_nope + q_indptr * params.q_nope_stride_n, + params.q_pe + q_indptr * params.q_pe_stride_n, params.q_nope_stride_n, + params.q_nope_stride_h, params.q_pe_stride_n, params.q_pe_stride_h, qo_packed_idx_base, + num_heads, q_len, smem_storage.ckv_smem[0], smem_storage.kpe_smem[0], lane_idx); + + logits_mask_hip(s_frag, qo_packed_idx_base, kv_tile_abs_start, q_len, kv_len, kv_end, + num_heads, lane_idx); + + update_mdo_states_hip(s_frag, o_frag, m_val, d_scalar, + sm_scale_log2); + + compute_pv_hip( + o_frag, s_frag, smem_storage.ckv_smem[0], wave_idx, lane_idx); + + __syncthreads(); + } + + normalize_d_hip(o_frag, m_val, d_scalar); + + // finalize_m: scale the running max into log2 space so that + // lse = log2(d_scaled) + m_scaled = log2(sum exp(s * sm_scale)) + // is a true log-sum-exp in log2 base. The downstream merge relies on this. + float m_scaled = (m_val != MLA_HIP_NEG_INF) ? m_val * sm_scale_log2 : m_val; + + write_o_hip( + o_frag, m_scaled, d_scalar, params.final_o + q_indptr * params.o_stride_n, + params.final_lse ? params.final_lse + q_indptr * (uint32_t)num_heads : nullptr, + params.partial_o, params.partial_lse, params.o_stride_n, params.o_stride_h, q_len, + qo_packed_idx_base, partial_indptr, num_heads, params.return_lse_base_on_e, lane_idx, + wave_idx); + } + + gpu_iface::cg::this_grid().sync(); + + DevicePersistentMergeStatesHIP( + params.merge_packed_offset_start, params.merge_packed_offset_end, + params.merge_partial_packed_offset_start, params.merge_partial_packed_offset_end, + params.merge_partial_stride, params.partial_o, params.partial_lse, params.final_o, + params.final_lse, params.o_stride_n, params.o_stride_h, num_heads, + params.return_lse_base_on_e); +} + +// --------------------------------------------------------------------------- +// BatchMLAPagedAttentionHIP: host-side launcher +// --------------------------------------------------------------------------- +template +hipError_t BatchMLAPagedAttentionHIP(Params params, uint32_t num_blks_x, uint32_t num_blks_y, + hipStream_t stream) { + if (MASK_MODE == MaskMode::kCustom) return hipErrorNotSupported; + constexpr bool CAUSAL = (MASK_MODE == MaskMode::kCausal); + + using SharedStorage = SharedStorageMLAHIP; + size_t smem_size = sizeof(SharedStorage); + + dim3 nblks(num_blks_x, num_blks_y); + dim3 nthrs(MLA_HIP_WAVE_SIZE, MLA_HIP_NUM_WAVES, 1); + + auto kernel = BatchMLAPagedAttentionKernelHIP; + hipError_t err = + hipFuncSetAttribute(reinterpret_cast(kernel), + hipFuncAttributeMaxDynamicSharedMemorySize, static_cast(smem_size)); + if (err != hipSuccess) return err; + + void* args[] = {reinterpret_cast(¶ms)}; + return hipLaunchCooperativeKernel(reinterpret_cast(kernel), nblks, nthrs, args, + smem_size, stream); +} + +} // namespace mla +} // namespace flashinfer + +#endif // FLASHINFER_MLA_HIP_CUH_ diff --git a/tests/rocm_tests/test_mla_hip.py b/tests/rocm_tests/test_mla_hip.py index a60f3f5399..2c24a8b185 100644 --- a/tests/rocm_tests/test_mla_hip.py +++ b/tests/rocm_tests/test_mla_hip.py @@ -3,13 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 # HIP/ROCm tests for MLA (Multi-head Latent Attention) batch paged attention. -# -# Phase 1 coverage (plan-only): -# - test_determine_mla_backend: verify determine_mla_backend returns "hip" -# - test_batch_mla_plan: verify BatchMLAPagedAttentionPlan runs without error -# for typical DeepSeek-v2/v3 head dims across batch sizes and dtypes -# - test_batch_mla_run_raises: verify BatchMLAPagedAttentionRun raises -# RuntimeError with a clear "not yet implemented" message (Phase 2 stub) import math @@ -64,7 +57,7 @@ def _plan(wrapper, batch_size, kv_len, page_size, causal, dtype): 0, batch_size * pages_per_req, dtype=torch.int32, device="cuda" ) kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda") - sm_scale = 1.0 / ((128 + 64) ** 0.5) + sm_scale = 1.0 / ((HEAD_DIM_CKV + HEAD_DIM_KPE) ** 0.5) wrapper.plan( q_indptr, kv_indptr, @@ -82,6 +75,41 @@ def _plan(wrapper, batch_size, kv_len, page_size, causal, dtype): return kv_lens, pages_per_req +def _mla_reference(q_nope, q_pe, ckv, kpe, kv_lens, page_size, sm_scale, causal): + """Pure-PyTorch MLA reference: S = q_pe @ kpe^T + q_nope @ ckv^T; O = softmax(S) @ ckv.""" + batch_size, num_heads, _ = q_nope.shape + dtype = q_nope.dtype + + # ckv/kpe: [num_pages, page_size, head_dim] + # flatten to [batch_size, kv_len, head_dim] using kv_lens + max_kv_len = int(kv_lens.max().item()) + # pages are laid out sequentially per request + pages_per_req = math.ceil(max_kv_len / page_size) + + ckv_flat = ckv.reshape(batch_size, pages_per_req * page_size, HEAD_DIM_CKV)[ + :, :max_kv_len, : + ] + kpe_flat = kpe.reshape(batch_size, pages_per_req * page_size, HEAD_DIM_KPE)[ + :, :max_kv_len, : + ] + + # q: [batch, heads, dim] K/V: [batch, kv_len, dim] scores: [batch, heads, kv_len] + # decode is qo_len=1 so causal masking is a no-op; only kv-length padding matters. + del causal + scores = torch.einsum("bhd,bsd->bhs", q_pe.float(), kpe_flat.float()) + scores += torch.einsum("bhd,bsd->bhs", q_nope.float(), ckv_flat.float()) + scores = scores * sm_scale + + for b in range(batch_size): + kl = int(kv_lens[b].item()) + if kl < max_kv_len: + scores[b, :, kl:] = float("-inf") + + weights = torch.softmax(scores, dim=-1) + out = torch.einsum("bhs,bsd->bhd", weights, ckv_flat.float()) + return out.to(dtype) + + def test_determine_mla_backend(): device = torch.device("cuda") backend = determine_mla_backend(device) @@ -97,17 +125,21 @@ def test_batch_mla_plan(batch_size, kv_len, page_size, causal, dtype): if causal and kv_len == 0: pytest.skip("causal with kv_len=0 unsupported") wrapper = _make_wrapper(batch_size, dtype) - # plan must not raise _plan(wrapper, batch_size, kv_len, page_size, causal, dtype) @pytest.mark.parametrize("batch_size", [1, 4]) -@pytest.mark.parametrize("kv_len", [64]) +@pytest.mark.parametrize("kv_len", [64, 256]) @pytest.mark.parametrize("page_size", [16]) +@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("dtype", [torch.float16]) -def test_batch_mla_run_raises(batch_size, kv_len, page_size, dtype): +def test_batch_mla_correctness(batch_size, kv_len, page_size, causal, dtype): + torch.manual_seed(42) wrapper = _make_wrapper(batch_size, dtype) - kv_lens, pages_per_req = _plan(wrapper, batch_size, kv_len, page_size, False, dtype) + sm_scale = 1.0 / ((HEAD_DIM_CKV + HEAD_DIM_KPE) ** 0.5) + kv_lens, pages_per_req = _plan( + wrapper, batch_size, kv_len, page_size, causal, dtype + ) q_nope = torch.randn(batch_size, 16, HEAD_DIM_CKV, dtype=dtype, device="cuda") q_pe = torch.randn(batch_size, 16, HEAD_DIM_KPE, dtype=dtype, device="cuda") @@ -118,5 +150,21 @@ def test_batch_mla_run_raises(batch_size, kv_len, page_size, dtype): batch_size * pages_per_req, page_size, HEAD_DIM_KPE, dtype=dtype, device="cuda" ) - with pytest.raises(RuntimeError, match="not yet implemented on HIP"): - wrapper.run(q_nope, q_pe, ckv, kpe) + out = wrapper.run(q_nope, q_pe, ckv, kpe) + + ref = _mla_reference( + q_nope, + q_pe, + ckv.reshape(batch_size, pages_per_req, page_size, HEAD_DIM_CKV).reshape( + batch_size * pages_per_req, page_size, HEAD_DIM_CKV + ), + kpe.reshape(batch_size, pages_per_req, page_size, HEAD_DIM_KPE).reshape( + batch_size * pages_per_req, page_size, HEAD_DIM_KPE + ), + kv_lens, + page_size, + sm_scale, + causal, + ) + + torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2) From 33eeb2a37c7f7f5bc11541914a0a2aaf1dfb4689 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 12 May 2026 14:50:28 +0000 Subject: [PATCH 07/19] =?UTF-8?q?perf(hip):=20MLA=20Phase=202a=20=E2=80=94?= =?UTF-8?q?=20uint4=20KV=20loads,=20LDS=20pad,=20hoisted=20Q=20&=20divmod?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three independent improvements over the Phase 1 correctness baseline: 1. Pad ckv_smem rows by 8 fp16. Without padding, the stride-HEAD_DIM_CKV reads in compute_pv_hip all land on LDS bank 0 (1024 bytes / 4 = 256, 256 % 32 banks = 0) — a 4-way bank conflict per thread. Padding spreads the per-thread reads across banks {b, b+8, b+16, b+24}. This was the dominant bottleneck. 2. Widen the global→LDS KV loads from uint2 (64-bit) to uint4 (128-bit) to halve the number of load instructions and improve DRAM throughput. HEAD_DIM_KPE=64 has fewer uint4s than threads; the helper falls back to a "first N threads load one each" path in that case. 3. Cache Q in registers across the kv_tile loop and hoist the per-q_packed divmod out of the loop — Q values and (batch_idx, head_idx) are loop- invariant. Skip the trailing __syncthreads on the last tile. Bench (deepseek MLA decode, num_heads=16, page_size=1): config Phase 1 Phase 2a speedup bs=64, seq=1024 384 GB/s 560 GB/s +46% bs=64, seq=2048 452 GB/s 744 GB/s +65% bs=16, seq=8192 441 GB/s 725 GB/s +64% bs=64, seq=8192 398 GB/s 801 GB/s +101% All 45 MLA correctness tests still pass. Phase 2b (async global→LDS via cp.async-equivalent, double-buffered KV pipeline, wave-specialized QK to remove 4× redundant work) remains TODO and will likely unlock the next factor of 2-3×. Co-Authored-By: Claude Sonnet 4.6 --- include/flashinfer/attention/mla_hip.cuh | 209 +++++++++++++++-------- 1 file changed, 137 insertions(+), 72 deletions(-) diff --git a/include/flashinfer/attention/mla_hip.cuh b/include/flashinfer/attention/mla_hip.cuh index e9f86fae30..64e1f78570 100644 --- a/include/flashinfer/attention/mla_hip.cuh +++ b/include/flashinfer/attention/mla_hip.cuh @@ -42,10 +42,16 @@ constexpr uint32_t MLA_HIP_CTA_TILE_KV = 16; // (state_t::init uses -math::inf, which has the same numeric value 5e4f). constexpr float MLA_HIP_NEG_INF = -5e4f; -// One KV tile in LDS, no swizzle. Phase 2 will introduce swizzling/pipelining. +// One KV tile in LDS. The CKV row is padded by 8 fp16 (one uint4) so that +// the stride-row reads in compute_pv_hip don't collide on the same LDS bank +// — without padding all 4 strided reads from one thread land on bank 0 +// (HEAD_DIM_CKV*2 bytes / 4 % 32 == 0), giving a 4-way bank conflict per +// thread. With +8 fp16 padding the per-thread reads spread across banks +// {b, b+8, b+16, b+24}. +constexpr uint32_t MLA_HIP_CKV_LDS_PAD = 8; template struct SharedStorageMLAHIP { - DTypeKV ckv_smem[MLA_HIP_CTA_TILE_KV][HEAD_DIM_CKV]; + DTypeKV ckv_smem[MLA_HIP_CTA_TILE_KV][HEAD_DIM_CKV + MLA_HIP_CKV_LDS_PAD]; DTypeKV kpe_smem[MLA_HIP_CTA_TILE_KV][HEAD_DIM_KPE]; }; @@ -61,48 +67,110 @@ __device__ __forceinline__ void load_kv_hip( const IdType* __restrict__ kv_indices, uint32_t ckv_stride_page, uint32_t ckv_stride_n, uint32_t kpe_stride_page, uint32_t kpe_stride_n, uint32_t packed_kv_tile_base, uint32_t packed_kv_bound, const uint_fastdiv& block_size, uint32_t tid) { - // CKV: CTA_TILE_KV * HEAD_DIM_CKV fp16 = CTA_TILE_KV * (HEAD_DIM_CKV/4) uint2 - constexpr uint32_t CKV_U2_PER_ROW = HEAD_DIM_CKV / 4; - constexpr uint32_t CKV_TOTAL_U2 = MLA_HIP_CTA_TILE_KV * CKV_U2_PER_ROW; - constexpr uint32_t CKV_U2_PER_THREAD = CKV_TOTAL_U2 / MLA_HIP_NUM_THREADS; - static_assert(CKV_TOTAL_U2 % MLA_HIP_NUM_THREADS == 0, "CKV load not evenly divisible"); - - uint2* smem_ckv = reinterpret_cast(ckv_smem); + // 128-bit (uint4 = 8 fp16) loads to maximize per-instruction DRAM throughput. + // ckv_smem rows are padded by MLA_HIP_CKV_LDS_PAD fp16 (= 1 uint4); load by + // (kv_row, col_u4) to handle the LDS stride correctly. + constexpr uint32_t CKV_U4_PER_ROW = HEAD_DIM_CKV / 8; + constexpr uint32_t CKV_LDS_U4_STRIDE = (HEAD_DIM_CKV + MLA_HIP_CKV_LDS_PAD) / 8; + constexpr uint32_t CKV_TOTAL_U4 = MLA_HIP_CTA_TILE_KV * CKV_U4_PER_ROW; + constexpr uint32_t CKV_U4_PER_THREAD = CKV_TOTAL_U4 / MLA_HIP_NUM_THREADS; + static_assert(CKV_TOTAL_U4 % MLA_HIP_NUM_THREADS == 0, "CKV load not evenly divisible"); + + uint4* smem_ckv = reinterpret_cast(ckv_smem); #pragma unroll - for (uint32_t i = 0; i < CKV_U2_PER_THREAD; ++i) { - uint32_t u2_idx = tid * CKV_U2_PER_THREAD + i; - uint32_t kv_row = u2_idx / CKV_U2_PER_ROW; - uint32_t col_u2 = u2_idx % CKV_U2_PER_ROW; + for (uint32_t i = 0; i < CKV_U4_PER_THREAD; ++i) { + uint32_t u4_idx = tid * CKV_U4_PER_THREAD + i; + uint32_t kv_row = u4_idx / CKV_U4_PER_ROW; + uint32_t col_u4 = u4_idx % CKV_U4_PER_ROW; uint32_t packed = packed_kv_tile_base + kv_row; uint32_t page_idx, row_in_page; block_size.divmod(packed, page_idx, row_in_page); bool valid = (packed < packed_kv_bound); const DTypeKV* gptr = ckv_global + (valid ? kv_indices[page_idx] : IdType(0)) * ckv_stride_page + - row_in_page * ckv_stride_n + col_u2 * 4; - smem_ckv[u2_idx] = valid ? *reinterpret_cast(gptr) : uint2{0u, 0u}; + row_in_page * ckv_stride_n + col_u4 * 8; + smem_ckv[kv_row * CKV_LDS_U4_STRIDE + col_u4] = + valid ? *reinterpret_cast(gptr) : uint4{0u, 0u, 0u, 0u}; } - // KPE: CTA_TILE_KV * HEAD_DIM_KPE fp16 = CTA_TILE_KV * (HEAD_DIM_KPE/4) uint2 - constexpr uint32_t KPE_U2_PER_ROW = HEAD_DIM_KPE / 4; - constexpr uint32_t KPE_TOTAL_U2 = MLA_HIP_CTA_TILE_KV * KPE_U2_PER_ROW; - constexpr uint32_t KPE_U2_PER_THREAD = KPE_TOTAL_U2 / MLA_HIP_NUM_THREADS; - static_assert(KPE_TOTAL_U2 % MLA_HIP_NUM_THREADS == 0, "KPE load not evenly divisible"); + constexpr uint32_t KPE_U4_PER_ROW = HEAD_DIM_KPE / 8; + constexpr uint32_t KPE_TOTAL_U4 = MLA_HIP_CTA_TILE_KV * KPE_U4_PER_ROW; + constexpr uint32_t KPE_U4_PER_THREAD = KPE_TOTAL_U4 / MLA_HIP_NUM_THREADS; + // KPE may be too small for one uint4 per thread; fall back to a per-thread loop only if work + // exists. For HEAD_DIM_KPE=64, CTA_TILE_KV=16: total=128 uint4, per-thread=0 (only 128 of 256 + // threads load). Handle as "first 128 threads do 1 load each". + uint4* smem_kpe = reinterpret_cast(kpe_smem); + if constexpr (KPE_U4_PER_THREAD >= 1) { +#pragma unroll + for (uint32_t i = 0; i < KPE_U4_PER_THREAD; ++i) { + uint32_t u4_idx = tid * KPE_U4_PER_THREAD + i; + uint32_t kv_row = u4_idx / KPE_U4_PER_ROW; + uint32_t col_u4 = u4_idx % KPE_U4_PER_ROW; + uint32_t packed = packed_kv_tile_base + kv_row; + uint32_t page_idx, row_in_page; + block_size.divmod(packed, page_idx, row_in_page); + bool valid = (packed < packed_kv_bound); + const DTypeKV* gptr = kpe_global + + (valid ? kv_indices[page_idx] : IdType(0)) * kpe_stride_page + + row_in_page * kpe_stride_n + col_u4 * 8; + smem_kpe[u4_idx] = valid ? *reinterpret_cast(gptr) : uint4{0u, 0u, 0u, 0u}; + } + } else { + // Less than 1 uint4 per thread: only the first KPE_TOTAL_U4 threads do a load. + if (tid < KPE_TOTAL_U4) { + uint32_t kv_row = tid / KPE_U4_PER_ROW; + uint32_t col_u4 = tid % KPE_U4_PER_ROW; + uint32_t packed = packed_kv_tile_base + kv_row; + uint32_t page_idx, row_in_page; + block_size.divmod(packed, page_idx, row_in_page); + bool valid = (packed < packed_kv_bound); + const DTypeKV* gptr = kpe_global + + (valid ? kv_indices[page_idx] : IdType(0)) * kpe_stride_page + + row_in_page * kpe_stride_n + col_u4 * 8; + smem_kpe[tid] = valid ? *reinterpret_cast(gptr) : uint4{0u, 0u, 0u, 0u}; + } + } +} + +// --------------------------------------------------------------------------- +// load_q_frags_hip: read this thread's Q fragments once from global memory. +// Q is loop-invariant across the kv_tile loop, so caching saves repeated +// global loads (per work_idx, save (num_kv_tiles - 1) reloads of all Q). +// +// q_pe_frag[mma_d][2] holds Q_pe[q=t%16][d=mma_d*16 + col_group*4 + 0..3] as 4 fp16. +// q_nope_frag is the same for Q_nope across NUM_MMA_D_CKV head_dim tiles. +// Out-of-range threads (batch_idx >= q_len) get zero-filled fragments. +// --------------------------------------------------------------------------- +template +__device__ __forceinline__ void load_q_frags_hip(uint32_t (*q_pe_frag)[2], + uint32_t (*q_nope_frag)[2], const DTypeQ* q_nope, + const DTypeQ* q_pe, uint32_t q_nope_stride_n, + uint32_t q_nope_stride_h, uint32_t q_pe_stride_n, + uint32_t q_pe_stride_h, uint32_t batch_idx, + uint32_t head_idx, bool q_valid, + uint32_t col_group) { + constexpr uint32_t NUM_MMA_D_CKV = HEAD_DIM_CKV / 16; + constexpr uint32_t NUM_MMA_D_KPE = HEAD_DIM_KPE / 16; - uint2* smem_kpe = reinterpret_cast(kpe_smem); #pragma unroll - for (uint32_t i = 0; i < KPE_U2_PER_THREAD; ++i) { - uint32_t u2_idx = tid * KPE_U2_PER_THREAD + i; - uint32_t kv_row = u2_idx / KPE_U2_PER_ROW; - uint32_t col_u2 = u2_idx % KPE_U2_PER_ROW; - uint32_t packed = packed_kv_tile_base + kv_row; - uint32_t page_idx, row_in_page; - block_size.divmod(packed, page_idx, row_in_page); - bool valid = (packed < packed_kv_bound); - const DTypeKV* gptr = kpe_global + - (valid ? kv_indices[page_idx] : IdType(0)) * kpe_stride_page + - row_in_page * kpe_stride_n + col_u2 * 4; - smem_kpe[u2_idx] = valid ? *reinterpret_cast(gptr) : uint2{0u, 0u}; + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_KPE; ++mma_d) { + q_pe_frag[mma_d][0] = 0u; + q_pe_frag[mma_d][1] = 0u; + if (q_valid) { + const DTypeQ* ptr = + q_pe + batch_idx * q_pe_stride_n + head_idx * q_pe_stride_h + mma_d * 16 + col_group * 4; + *reinterpret_cast(q_pe_frag[mma_d]) = *reinterpret_cast(ptr); + } + } +#pragma unroll + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV; ++mma_d) { + q_nope_frag[mma_d][0] = 0u; + q_nope_frag[mma_d][1] = 0u; + if (q_valid) { + const DTypeQ* ptr = q_nope + batch_idx * q_nope_stride_n + head_idx * q_nope_stride_h + + mma_d * 16 + col_group * 4; + *reinterpret_cast(q_nope_frag[mma_d]) = *reinterpret_cast(ptr); + } } } @@ -117,60 +185,39 @@ __device__ __forceinline__ void load_kv_hip( // Combined with the column-major D layout, thread t's r-th register = // S[q=t%16][kv=(t/16)*4+r]. // --------------------------------------------------------------------------- -template -__device__ __forceinline__ void compute_qk_hip(float* s_frag, const DTypeQ* q_nope, - const DTypeQ* q_pe, uint32_t q_nope_stride_n, - uint32_t q_nope_stride_h, uint32_t q_pe_stride_n, - uint32_t q_pe_stride_h, uint32_t qo_packed_idx_base, - const uint_fastdiv& num_heads, uint32_t q_len, +template +__device__ __forceinline__ void compute_qk_hip(float* s_frag, const uint32_t (*q_pe_frag)[2], + const uint32_t (*q_nope_frag)[2], const DTypeKV* ckv_smem, const DTypeKV* kpe_smem, - uint32_t lane_idx) { + uint32_t q_row, uint32_t col_group) { constexpr uint32_t NUM_MMA_D_CKV = HEAD_DIM_CKV / 16; constexpr uint32_t NUM_MMA_D_KPE = HEAD_DIM_KPE / 16; - constexpr uint32_t CKV_U2_PER_ROW = HEAD_DIM_CKV / 4; + constexpr uint32_t CKV_LDS_U2_STRIDE = (HEAD_DIM_CKV + MLA_HIP_CKV_LDS_PAD) / 4; constexpr uint32_t KPE_U2_PER_ROW = HEAD_DIM_KPE / 4; - uint32_t q_row = lane_idx % 16; - uint32_t col_group = lane_idx / 16; - uint32_t q_packed = qo_packed_idx_base + q_row; - uint32_t batch_idx, head_idx; - num_heads.divmod(q_packed, batch_idx, head_idx); - bool q_valid = (batch_idx < q_len); - const uint2* smem_ckv = reinterpret_cast(ckv_smem); const uint2* smem_kpe = reinterpret_cast(kpe_smem); // KPE tiles run first — kInit on tile 0 zeros the accumulator before CKV accumulates. #pragma unroll for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_KPE; ++mma_d) { - uint32_t q_frag[2] = {0u, 0u}; - if (q_valid) { - const DTypeQ* ptr = - q_pe + batch_idx * q_pe_stride_n + head_idx * q_pe_stride_h + mma_d * 16 + col_group * 4; - *reinterpret_cast(q_frag) = *reinterpret_cast(ptr); - } uint32_t k_frag[2]; *reinterpret_cast(k_frag) = smem_kpe[q_row * KPE_U2_PER_ROW + mma_d * 4 + col_group]; - if (mma_d == 0) { gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32( - s_frag, k_frag, q_frag); + s_frag, k_frag, const_cast(q_pe_frag[mma_d])); } else { - gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32(s_frag, k_frag, q_frag); + gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32( + s_frag, k_frag, const_cast(q_pe_frag[mma_d])); } } #pragma unroll for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV; ++mma_d) { - uint32_t q_frag[2] = {0u, 0u}; - if (q_valid) { - const DTypeQ* ptr = q_nope + batch_idx * q_nope_stride_n + head_idx * q_nope_stride_h + - mma_d * 16 + col_group * 4; - *reinterpret_cast(q_frag) = *reinterpret_cast(ptr); - } uint32_t k_frag[2]; - *reinterpret_cast(k_frag) = smem_ckv[q_row * CKV_U2_PER_ROW + mma_d * 4 + col_group]; - gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32(s_frag, k_frag, q_frag); + *reinterpret_cast(k_frag) = smem_ckv[q_row * CKV_LDS_U2_STRIDE + mma_d * 4 + col_group]; + gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32( + s_frag, k_frag, const_cast(q_nope_frag[mma_d])); } } @@ -271,7 +318,7 @@ __device__ __forceinline__ void compute_pv_hip(float (*o_frag)[4], const float* DTypeKV v_vals[4]; #pragma unroll for (uint32_t j = 0; j < 4; ++j) { - v_vals[j] = ckv_smem[(col_group * 4 + j) * HEAD_DIM_CKV + d_col]; + v_vals[j] = ckv_smem[(col_group * 4 + j) * (HEAD_DIM_CKV + MLA_HIP_CKV_LDS_PAD) + d_col]; } uint32_t v_frag[2]; v_frag[0] = (uint32_t)(*reinterpret_cast(&v_vals[0])) | @@ -434,6 +481,7 @@ __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKer using IdType = typename Params::IdType; constexpr uint32_t NUM_MMA_D_CKV = HEAD_DIM_CKV / 16; + constexpr uint32_t NUM_MMA_D_KPE = HEAD_DIM_KPE / 16; constexpr uint32_t NUM_MMA_D_PER_WAVE = NUM_MMA_D_CKV / MLA_HIP_NUM_WAVES; static_assert(NUM_MMA_D_CKV % MLA_HIP_NUM_WAVES == 0, "HEAD_DIM_CKV must be divisible by 4 * 16 = 64"); @@ -451,9 +499,14 @@ __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKer const uint_fastdiv& block_size = params.block_size; const uint32_t block_size_val = static_cast(block_size); + const uint32_t q_row = lane_idx % 16; + const uint32_t col_group = lane_idx / 16; + float s_frag[4]; float o_frag[NUM_MMA_D_PER_WAVE][4]; float m_val, d_scalar; + uint32_t q_pe_frag[NUM_MMA_D_KPE][2]; + uint32_t q_nope_frag[NUM_MMA_D_CKV][2]; for (IdType work_idx = params.work_indptr[blockIdx.y]; work_idx < params.work_indptr[blockIdx.y + 1]; ++work_idx) { @@ -468,6 +521,20 @@ __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKer const uint32_t qo_packed_idx_base = packed_qo_start + blockIdx.x * MLA_HIP_CTA_TILE_Q; + // Hoist the q_packed → (batch_idx, head_idx) divmod out of the kv-tile loop; + // it's the same for every iteration of every helper that consumed it. + const uint32_t q_packed = qo_packed_idx_base + q_row; + uint32_t batch_idx, head_idx; + num_heads.divmod(q_packed, batch_idx, head_idx); + const bool q_valid = (batch_idx < q_len); + + // Q is loop-invariant across kv-tiles; load all fragments once into registers. + load_q_frags_hip( + q_pe_frag, q_nope_frag, params.q_nope + q_indptr * params.q_nope_stride_n, + params.q_pe + q_indptr * params.q_pe_stride_n, params.q_nope_stride_n, + params.q_nope_stride_h, params.q_pe_stride_n, params.q_pe_stride_h, batch_idx, head_idx, + q_valid, col_group); + m_val = MLA_HIP_NEG_INF; d_scalar = 1.f; #pragma unroll @@ -491,11 +558,9 @@ __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKer __syncthreads(); - compute_qk_hip( - s_frag, params.q_nope + q_indptr * params.q_nope_stride_n, - params.q_pe + q_indptr * params.q_pe_stride_n, params.q_nope_stride_n, - params.q_nope_stride_h, params.q_pe_stride_n, params.q_pe_stride_h, qo_packed_idx_base, - num_heads, q_len, smem_storage.ckv_smem[0], smem_storage.kpe_smem[0], lane_idx); + compute_qk_hip( + s_frag, q_pe_frag, q_nope_frag, smem_storage.ckv_smem[0], smem_storage.kpe_smem[0], q_row, + col_group); logits_mask_hip(s_frag, qo_packed_idx_base, kv_tile_abs_start, q_len, kv_len, kv_end, num_heads, lane_idx); @@ -506,7 +571,7 @@ __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKer compute_pv_hip( o_frag, s_frag, smem_storage.ckv_smem[0], wave_idx, lane_idx); - __syncthreads(); + if (kv_tile_idx > 0) __syncthreads(); } normalize_d_hip(o_frag, m_val, d_scalar); From bf6629c09f7e7eed2a07f3486074af9a3cae39aa Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 12 May 2026 15:11:24 +0000 Subject: [PATCH 08/19] =?UTF-8?q?perf(hip):=20MLA=20Phase=202b=20=E2=80=94?= =?UTF-8?q?=20double-buffered=20KV=20pipeline?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two-stage LDS for the KV tile and an explicit prologue+inner-loop pipeline: each iteration N issues the global→LDS load for tile N-1 into the opposite stage before computing on tile N. Without cp.async-equivalent on HIP, the loads still block the wave at the LDS write, but the AMD compiler is then free to schedule the global VGPR loads ahead of the prior tile's MFMAs, hiding some load latency. Bench (deepseek MLA decode, num_heads=16, page_size=1): config Phase 2a Phase 2b delta bs=64, seq=1024 560 GB/s 582 GB/s +4% bs=64, seq=2048 744 GB/s 792 GB/s +6% bs=16, seq=8192 725 GB/s 768 GB/s +6% bs=64, seq=8192 801 GB/s 881 GB/s +10% Cumulative since Phase 1 baseline at bs=64,seq=8192: 398 → 881 GB/s (+121%). LDS doubles to 37 KB per CTA (was 18.7), which drops occupancy from a theoretical 3 to 1 CTA/CU on MI300X — the +10% measured win confirms the compiler-scheduling benefit outweighs the occupancy hit at long sequences. A truly async global→LDS path (cp.async-equivalent in memory_ops_hip.h) plus wave-specialized QK to reclaim 4× redundant work would be the next lever — both need backend infrastructure work to unlock. All 45 MLA correctness tests still pass. Co-Authored-By: Claude Sonnet 4.6 --- include/flashinfer/attention/mla_hip.cuh | 71 +++++++++++++++++------- 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/include/flashinfer/attention/mla_hip.cuh b/include/flashinfer/attention/mla_hip.cuh index 64e1f78570..e03d29135f 100644 --- a/include/flashinfer/attention/mla_hip.cuh +++ b/include/flashinfer/attention/mla_hip.cuh @@ -42,17 +42,18 @@ constexpr uint32_t MLA_HIP_CTA_TILE_KV = 16; // (state_t::init uses -math::inf, which has the same numeric value 5e4f). constexpr float MLA_HIP_NEG_INF = -5e4f; -// One KV tile in LDS. The CKV row is padded by 8 fp16 (one uint4) so that -// the stride-row reads in compute_pv_hip don't collide on the same LDS bank -// — without padding all 4 strided reads from one thread land on bank 0 -// (HEAD_DIM_CKV*2 bytes / 4 % 32 == 0), giving a 4-way bank conflict per -// thread. With +8 fp16 padding the per-thread reads spread across banks -// {b, b+8, b+16, b+24}. +// Two KV tiles in LDS to support double-buffered load/compute pipelining. +// CKV rows are padded by 8 fp16 (one uint4) so that the stride-row reads in +// compute_pv_hip don't collide on the same LDS bank — without padding all 4 +// strided reads from one thread land on bank 0 (HEAD_DIM_CKV*2/4 % 32 == 0), +// giving a 4-way bank conflict per thread. With +8 fp16 padding the per-thread +// reads spread across banks {b, b+8, b+16, b+24}. constexpr uint32_t MLA_HIP_CKV_LDS_PAD = 8; +constexpr uint32_t MLA_HIP_NUM_STAGES = 2; template struct SharedStorageMLAHIP { - DTypeKV ckv_smem[MLA_HIP_CTA_TILE_KV][HEAD_DIM_CKV + MLA_HIP_CKV_LDS_PAD]; - DTypeKV kpe_smem[MLA_HIP_CTA_TILE_KV][HEAD_DIM_KPE]; + DTypeKV ckv_smem[MLA_HIP_NUM_STAGES][MLA_HIP_CTA_TILE_KV][HEAD_DIM_CKV + MLA_HIP_CKV_LDS_PAD]; + DTypeKV kpe_smem[MLA_HIP_NUM_STAGES][MLA_HIP_CTA_TILE_KV][HEAD_DIM_KPE]; }; // --------------------------------------------------------------------------- @@ -546,32 +547,60 @@ __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKer const int32_t num_kv_tiles = static_cast(ceil_div(kv_end - kv_start, MLA_HIP_CTA_TILE_KV)); - for (int32_t kv_tile_idx = num_kv_tiles - 1; kv_tile_idx >= 0; --kv_tile_idx) { - const uint32_t kv_tile_abs_start = - kv_start + static_cast(kv_tile_idx) * MLA_HIP_CTA_TILE_KV; - const uint32_t packed_kv_tile_base = kv_indptr * block_size_val + kv_tile_abs_start; - + // Double-buffered pipeline: stage 0 loaded outside the loop, then each + // iteration N loads stage[(N+1)%2] while computing on stage[N%2]. The hope + // is the AMD compiler interleaves the global→VGPR→LDS load chain with the + // MFMAs on the previous tile's data, hiding global-load latency. + auto kv_tile_to_packed_base = [&](int32_t kv_tile_idx) { + return kv_indptr * block_size_val + kv_start + + static_cast(kv_tile_idx) * MLA_HIP_CTA_TILE_KV; + }; + auto kv_tile_to_abs_start = [&](int32_t kv_tile_idx) { + return kv_start + static_cast(kv_tile_idx) * MLA_HIP_CTA_TILE_KV; + }; + + // Prologue: load tile (num_kv_tiles - 1) into stage 0. + { + const int32_t first_tile = num_kv_tiles - 1; load_kv_hip( - smem_storage.ckv_smem[0], smem_storage.kpe_smem[0], params.ckv, params.kpe, + smem_storage.ckv_smem[0][0], smem_storage.kpe_smem[0][0], params.ckv, params.kpe, params.kv_indices, params.ckv_stride_page, params.ckv_stride_n, params.kpe_stride_page, - params.kpe_stride_n, packed_kv_tile_base, packed_kv_bound, block_size, tid); + params.kpe_stride_n, kv_tile_to_packed_base(first_tile), packed_kv_bound, block_size, + tid); + } + + uint32_t cur_stage = 0; + for (int32_t kv_tile_idx = num_kv_tiles - 1; kv_tile_idx >= 0; --kv_tile_idx) { + const uint32_t next_stage = 1 - cur_stage; + const bool has_next = (kv_tile_idx > 0); + + // Issue load for tile (kv_tile_idx - 1) into next_stage. Without + // cp.async, this still blocks the wave but lets the compiler schedule + // the global VGPR loads ahead of compute. + if (has_next) { + load_kv_hip( + smem_storage.ckv_smem[next_stage][0], smem_storage.kpe_smem[next_stage][0], params.ckv, + params.kpe, params.kv_indices, params.ckv_stride_page, params.ckv_stride_n, + params.kpe_stride_page, params.kpe_stride_n, kv_tile_to_packed_base(kv_tile_idx - 1), + packed_kv_bound, block_size, tid); + } __syncthreads(); compute_qk_hip( - s_frag, q_pe_frag, q_nope_frag, smem_storage.ckv_smem[0], smem_storage.kpe_smem[0], q_row, - col_group); + s_frag, q_pe_frag, q_nope_frag, smem_storage.ckv_smem[cur_stage][0], + smem_storage.kpe_smem[cur_stage][0], q_row, col_group); - logits_mask_hip(s_frag, qo_packed_idx_base, kv_tile_abs_start, q_len, kv_len, kv_end, - num_heads, lane_idx); + logits_mask_hip(s_frag, qo_packed_idx_base, kv_tile_to_abs_start(kv_tile_idx), q_len, + kv_len, kv_end, num_heads, lane_idx); update_mdo_states_hip(s_frag, o_frag, m_val, d_scalar, sm_scale_log2); compute_pv_hip( - o_frag, s_frag, smem_storage.ckv_smem[0], wave_idx, lane_idx); + o_frag, s_frag, smem_storage.ckv_smem[cur_stage][0], wave_idx, lane_idx); - if (kv_tile_idx > 0) __syncthreads(); + cur_stage = next_stage; } normalize_d_hip(o_frag, m_val, d_scalar); From ba10f705a5db4cba44c211a3cc14ffef014b8316 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 12 May 2026 15:59:43 +0000 Subject: [PATCH 09/19] =?UTF-8?q?feat(hip):=20MLA=20Phase=203=20=E2=80=94?= =?UTF-8?q?=20extended=20tests,=20AOT=20bake,=20README=20update?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tests/rocm_tests/test_deepseek_mla_hip.py — adds DeepSeek-shape coverage: fp16 + bf16, varying batch/kv_len/page_size, LSE return, pre-allocated output buffers, and a varlen-kv test where each request has a different kv_len. 132 new test cases, all passing. Restricted to qo_len = 1: the Phase-1/2 kernel uses CTA_TILE_Q = 16, while the shared MLA planner assumes 64. For decode (qo_len * num_heads ≤ 16) the planner generates one work item per batch and CTA_TILE_Q = 16 covers it exactly; prefill needs an inner Q-sub-tile loop in the kernel. Documented in the test as a TODO. * flashinfer/aot_hip.py — registers the HIP MLA module (head_dim_ckv=512, head_dim_kpe=64, fp16/bf16) so AOT bake produces a cached .so. Existing test_aot_hip.py tests still pass. * README.md — feature matrix now lists MLA as supported (decode-path on CDNA3) instead of "Not Yet Ported". Co-Authored-By: Claude Sonnet 4.6 --- README.md | 2 +- flashinfer/aot_hip.py | 15 ++ tests/rocm_tests/test_deepseek_mla_hip.py | 290 ++++++++++++++++++++++ 3 files changed, 306 insertions(+), 1 deletion(-) create mode 100644 tests/rocm_tests/test_deepseek_mla_hip.py diff --git a/README.md b/README.md index 1446c2f4d9..1e70336cdd 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ to its corresponding upstream tag (e.g., `0.2.5+amd.2` is second release of amd- | **Decode Attention** | ✅ | ✅ | No | Supports MHA, GQA, and MQA | | **Prefill Attention** | ✅ | WIP | ✅ | Supports MHA, GQA, and MQA | | **Cascade Attention** | TBD | TBD | No | Not Yet Ported | -| **MLA** | TBD | TBD | No | Not Yet Ported | +| **MLA** | ✅ | TBD | No | Decode-path (qo_len=1) on CDNA3; prefill TODO | | **POD** | TBD | TBD | No | Not Yet Ported | | **Positional Encoding** | TBD | TBD | No | Not Yet Ported | | **Sampling** | ✅ | TBD | No | Supports Top-K/Top-P Sampling/OnlineSoftmax/SamplingFromLogits | diff --git a/flashinfer/aot_hip.py b/flashinfer/aot_hip.py index 5912977f27..6e8ba20183 100644 --- a/flashinfer/aot_hip.py +++ b/flashinfer/aot_hip.py @@ -118,6 +118,21 @@ def gen_attention( use_logits_soft_cap=use_logits_soft_cap, ) + # MLA (DeepSeek shapes: head_dim_ckv=512, head_dim_kpe=64). + from .jit.attention import gen_batch_mla_module + + for dtype in f16_dtype_: + yield gen_batch_mla_module( + backend="hip", + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + dtype_idx=torch.int32, + head_dim_ckv=512, + head_dim_kpe=64, + use_profiler=False, + ) + def gen_all_modules( f16_dtype_: List[torch.dtype], diff --git a/tests/rocm_tests/test_deepseek_mla_hip.py b/tests/rocm_tests/test_deepseek_mla_hip.py new file mode 100644 index 0000000000..c1b31d5ab3 --- /dev/null +++ b/tests/rocm_tests/test_deepseek_mla_hip.py @@ -0,0 +1,290 @@ +# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# DeepSeek-style HIP/ROCm tests for MLA: prefill (qo_len > 1) + decode (qo_len = 1), +# LSE return, pre-allocated output buffers, varying kv_lens per batch, +# fp16 + bf16. Mirrors the core checks in tests/attention/test_deepseek_mla.py +# but uses the HIP backend and the head dimensions DeepSeek actually ships. + +import math + +import pytest +import torch + +import flashinfer +import flashinfer.mla +from flashinfer.jit import build_jit_specs, gen_batch_mla_module + +HEAD_DIM_CKV = 512 +HEAD_DIM_KPE = 64 + + +def _attention_ref(batch_size, q, k, v, causal, sm_scale): + qo_len = q.shape[0] // batch_size + kv_len = k.shape[0] // batch_size + num_heads = q.shape[1] + head_dim_qk = q.shape[2] + head_dim_vo = v.shape[2] + logits = ( + torch.einsum( + "bmhd,bnhd->bhmn", + q.view(batch_size, qo_len, num_heads, head_dim_qk).float(), + k.view(batch_size, kv_len, num_heads, head_dim_qk).float(), + ) + * sm_scale + ) + if causal: + mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze( + 1 + ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0) + logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) + lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2) + p = torch.softmax(logits, dim=-1) + o_ref = ( + torch.einsum( + "bhmn,bnhd->bmhd", + p, + v.view(batch_size, kv_len, num_heads, head_dim_vo).float(), + ) + .contiguous() + .view(batch_size * qo_len, num_heads, head_dim_vo) + .to(q) + ) + # convert lse from natural log to log2 to match the kernel's return_lse_base_on_e=False default + return o_ref, lse_ref * math.log2(math.e) + + +def _kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads): + bs_page_num, page_size, ckv_dim = ckv.shape + page_num = bs_page_num // batch_size + _, _, kpe_dim = kpe.shape + ckv = ckv.view(batch_size, page_num * page_size, ckv_dim)[:, :kv_len, :] + kpe = kpe.view(batch_size, page_num * page_size, kpe_dim)[:, :kv_len, :] + k = ( + torch.cat([ckv, kpe], dim=-1) + .unsqueeze(2) + .repeat(1, 1, num_heads, 1) + .view(batch_size * kv_len, num_heads, ckv_dim + kpe_dim) + ) + v = ( + ckv.unsqueeze(2) + .repeat(1, 1, num_heads, 1) + .view(batch_size * kv_len, num_heads, ckv_dim) + ) + return k, v + + +@pytest.fixture(autouse=True, scope="module") +def warmup_jit(): + build_jit_specs( + [ + gen_batch_mla_module( + "hip", + torch.float16, + torch.float16, + torch.float16, + torch.int32, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + False, + ), + gen_batch_mla_module( + "hip", + torch.bfloat16, + torch.bfloat16, + torch.bfloat16, + torch.int32, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + False, + ), + ], + verbose=False, + ) + yield + + +@pytest.mark.parametrize("batch_size", [1, 3, 7]) +@pytest.mark.parametrize("kv_len", [17, 33, 96, 514, 1024]) +@pytest.mark.parametrize("qo_len", [1]) +@pytest.mark.parametrize("num_heads", [16]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("page_size", [1, 16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_batch_mla_page_attention( + batch_size, kv_len, qo_len, num_heads, causal, page_size, dtype +): + # The Phase-1/2 HIP MLA kernel uses CTA_TILE_Q = 16, while the shared + # MLA planner assumes CTA_TILE_Q = 64. As long as a single batch's + # packed_qo_len = qo_len * num_heads <= 16 the planner generates one + # work item per batch and our CTA covers it exactly. Prefill paths + # (qo_len > 1) need an inner Q-sub-tile loop in the kernel — TODO. + assert qo_len * num_heads <= 16, "Prefill MLA on HIP not yet supported" + if causal and qo_len > kv_len: + pytest.skip("qo_len > kv_len not supported for causal attention") + device = torch.device("cuda") + torch.manual_seed(42) + + q_nope = torch.randn( + batch_size * qo_len, num_heads, HEAD_DIM_CKV, dtype=dtype, device=device + ) + q_pe = torch.randn( + batch_size * qo_len, num_heads, HEAD_DIM_KPE, dtype=dtype, device=device + ) + pages_per_req = math.ceil(kv_len / page_size) + ckv = torch.randn( + batch_size * pages_per_req, page_size, HEAD_DIM_CKV, dtype=dtype, device=device + ) + kpe = torch.randn( + batch_size * pages_per_req, page_size, HEAD_DIM_KPE, dtype=dtype, device=device + ) + + sm_scale = 1.0 / ((128 + 64) ** 0.5) # head dim before matrix absorption + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace, backend="auto") + + q_indptr = ( + torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) * qo_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) + * pages_per_req + ) + kv_indices = torch.arange( + 0, batch_size * pages_per_req, dtype=torch.int32, device=device + ) + kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device=device) + + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_lens, + num_heads, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + page_size, + causal, + sm_scale, + q_nope.dtype, + ckv.dtype, + ) + o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) + + k, v = _kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads) + q = torch.cat([q_nope, q_pe], dim=-1) + o_ref, lse_ref = _attention_ref(batch_size, q, k, v, causal, sm_scale) + lse_ref = lse_ref.flatten(0, 1) + + rtol, atol = (1e-3, 1e-3) if dtype == torch.float16 else (1e-2, 1e-2) + torch.testing.assert_close(o, o_ref, rtol=rtol, atol=atol) + torch.testing.assert_close(lse, lse_ref, rtol=rtol, atol=atol) + + # Pre-allocated output buffers must produce identical results. + o_buf = torch.empty_like(o) + lse_buf = torch.empty_like(lse) + wrapper.run(q_nope, q_pe, ckv, kpe, out=o_buf, lse=lse_buf) + torch.testing.assert_close(o, o_buf, rtol=rtol, atol=atol) + torch.testing.assert_close(lse, lse_buf, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("batch_size", [3, 5]) +@pytest.mark.parametrize( + "kv_lens_list", + [ + [17, 33, 79], + [96, 514, 2048], + [128, 256, 512, 1024, 2048], + ], + ids=lambda v: "x".join(str(x) for x in v), +) +@pytest.mark.parametrize("qo_len", [1]) +@pytest.mark.parametrize("causal", [False, True]) +def test_batch_mla_varlen_kv(batch_size, kv_lens_list, qo_len, causal): + """Each request in the batch has a different kv_len (still page_size=1).""" + if causal and qo_len > min(kv_lens_list): + pytest.skip("qo_len > min(kv_len) not supported for causal attention") + device = torch.device("cuda") + torch.manual_seed(0) + + num_heads = 16 + page_size = 1 + dtype = torch.float16 + # Repeat the kv_lens_list `batch_size` times along the request axis. + kv_lens_full = (kv_lens_list * batch_size)[: batch_size * len(kv_lens_list)] + n_req = len(kv_lens_full) + pages = [math.ceil(kv / page_size) for kv in kv_lens_full] + total_pages = sum(pages) + + q_nope = torch.randn( + n_req * qo_len, num_heads, HEAD_DIM_CKV, dtype=dtype, device=device + ) + q_pe = torch.randn( + n_req * qo_len, num_heads, HEAD_DIM_KPE, dtype=dtype, device=device + ) + ckv = torch.randn(total_pages, page_size, HEAD_DIM_CKV, dtype=dtype, device=device) + kpe = torch.randn(total_pages, page_size, HEAD_DIM_KPE, dtype=dtype, device=device) + + sm_scale = 1.0 / ((128 + 64) ** 0.5) + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace, backend="auto") + + q_indptr = torch.arange(0, n_req + 1, dtype=torch.int32, device=device) * qo_len + page_indptr = torch.zeros(n_req + 1, dtype=torch.int32, device=device) + page_indptr[1:] = torch.tensor(pages, dtype=torch.int32, device=device).cumsum(0) + kv_indices = torch.arange(0, total_pages, dtype=torch.int32, device=device) + kv_lens = torch.tensor(kv_lens_full, dtype=torch.int32, device=device) + + wrapper.plan( + q_indptr, + page_indptr, + kv_indices, + kv_lens, + num_heads, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + page_size, + causal, + sm_scale, + q_nope.dtype, + ckv.dtype, + ) + o = wrapper.run(q_nope, q_pe, ckv, kpe) + + # Per-request reference (each request has its own kv_len so we can't batch + # them in a single attention_ref call with padding zeros — that would shift + # the softmax denominator. Compute and concatenate.) + out_chunks = [] + page_offset = 0 + qo_offset = 0 + for kv_len in kv_lens_full: + k_p = ckv[page_offset : page_offset + kv_len].view(1, kv_len, HEAD_DIM_CKV) + kpe_p = kpe[page_offset : page_offset + kv_len].view(1, kv_len, HEAD_DIM_KPE) + page_offset += kv_len + q_n = q_nope[qo_offset : qo_offset + qo_len] + q_p = q_pe[qo_offset : qo_offset + qo_len] + qo_offset += qo_len + if kv_len == 0: + out_chunks.append(torch.zeros_like(q_n)) + continue + # k = [ckv ; kpe] repeated across heads; v = ckv repeated across heads. + k = torch.cat([k_p, kpe_p], dim=-1).expand( + 1, kv_len, HEAD_DIM_CKV + HEAD_DIM_KPE + ) + k = ( + k.unsqueeze(2) + .repeat(1, 1, num_heads, 1) + .view(kv_len, num_heads, HEAD_DIM_CKV + HEAD_DIM_KPE) + ) + v = k_p.expand(1, kv_len, HEAD_DIM_CKV) + v = ( + v.unsqueeze(2) + .repeat(1, 1, num_heads, 1) + .view(kv_len, num_heads, HEAD_DIM_CKV) + ) + q = torch.cat([q_n, q_p], dim=-1) + ref, _ = _attention_ref(1, q, k, v, causal, sm_scale) + out_chunks.append(ref) + o_ref = torch.cat(out_chunks, dim=0) + + torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) From 1cc73c165ef7bd0a557ca49febe15e5438759493 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 18 May 2026 18:46:34 +0000 Subject: [PATCH 10/19] fix(mypy): add partial_state to BatchPrefillWithPagedKVCacheWrapper.run overloads The HIP fused-cascade path in cascade.py passes partial_state=(out, lse) to wrapper.run(). The @overload stubs in prefill.py lacked this kwarg, causing mypy to report a "no matching overload" error on cascade.py:547. Add partial_state (Optional, default None) to both @overload stubs and the implementation signature. Co-Authored-By: Claude Sonnet 4.6 --- flashinfer/prefill.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 47d725c5d3..dbac6d4663 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1959,6 +1959,7 @@ def run( return_lse: Literal[False] = False, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, + partial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: ... @overload @@ -1974,6 +1975,7 @@ def run( return_lse: Literal[True] = True, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, + partial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... def run( @@ -1990,6 +1992,7 @@ def run( enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, sinks: Optional[torch.Tensor] = None, + partial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch prefill/append attention between query and paged kv-cache. From c4ebae779d64046a5f4322472b3522b6be916257 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 18 May 2026 19:12:25 +0000 Subject: [PATCH 11/19] refactor(hip): production cleanup of MLA kernel and tests - mla_hip.cuh: fix stale file header (remove "Phase 1", update Q-loading description); remove what-the-code-does comments (// 256, "Hoist the q_packed", "Q is loop-invariant", "Pack s_frag"); drop redundant section banners for normalize_d_hip, KernelHIP, LauncherHIP; simplify v_frag and p_frag packing to single uint2 reinterpret casts - test_mla_hip.py: remove inline "# num_heads" annotation; drop no-op double reshape (ckv/kpe passed directly to _mla_reference) - test_deepseek_mla_hip.py: remove "Phase-1/2" task reference from guard comment Co-Authored-By: Claude Sonnet 4.6 --- include/flashinfer/attention/mla_hip.cuh | 27 +++++------------------ tests/rocm_tests/test_deepseek_mla_hip.py | 9 ++++---- tests/rocm_tests/test_mla_hip.py | 10 +++------ 3 files changed, 12 insertions(+), 34 deletions(-) diff --git a/include/flashinfer/attention/mla_hip.cuh b/include/flashinfer/attention/mla_hip.cuh index e03d29135f..b8f73d805f 100644 --- a/include/flashinfer/attention/mla_hip.cuh +++ b/include/flashinfer/attention/mla_hip.cuh @@ -1,12 +1,12 @@ // SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. // SPDX-License-Identifier: Apache-2.0 // -// Phase 1 correctness-first HIP MLA kernel for CDNA3 (MI300X). +// HIP MLA kernel for CDNA3 (MI300X). // // CTA config: dim3(64,4,1) — 4 wavefronts, 256 threads. // All 4 wavefronts replicate QK independently; wave w owns // head_dim shard [w * (HEAD_DIM_CKV/4) .. (w+1) * (HEAD_DIM_CKV/4)). -// Q is re-read from global memory each mma_d step to avoid the 64 KB LDS limit. +// Q is loaded once into registers per work item (loop-invariant across KV tiles). // // MFMA layout note: CDNA3 mfma_f32_16x16x16f16 produces D in column-major // per-thread layout (each thread holds D[m=(t/16)*4+r][n=t%16] for r=0..3). @@ -35,7 +35,7 @@ namespace mla { constexpr uint32_t MLA_HIP_WAVE_SIZE = 64; constexpr uint32_t MLA_HIP_NUM_WAVES = 4; -constexpr uint32_t MLA_HIP_NUM_THREADS = MLA_HIP_WAVE_SIZE * MLA_HIP_NUM_WAVES; // 256 +constexpr uint32_t MLA_HIP_NUM_THREADS = MLA_HIP_WAVE_SIZE * MLA_HIP_NUM_WAVES; constexpr uint32_t MLA_HIP_CTA_TILE_Q = 16; constexpr uint32_t MLA_HIP_CTA_TILE_KV = 16; // Sentinel for the running max in online softmax. Matches the CUDA path @@ -302,13 +302,11 @@ __device__ __forceinline__ void compute_pv_hip(float (*o_frag)[4], const float* uint32_t col_group = lane_idx / 16; uint32_t d_start = wave_idx * NUM_MMA_D_PER_WAVE; - // Pack s_frag (4 floats) into 2 uint32_t in fp16 format for MFMA B input DTypeKV p_f16[4]; #pragma unroll for (uint32_t r = 0; r < 4; ++r) p_f16[r] = static_cast(s_frag[r]); uint32_t p_frag[2]; - p_frag[0] = *reinterpret_cast(&p_f16[0]); - p_frag[1] = *reinterpret_cast(&p_f16[2]); + *reinterpret_cast(p_frag) = *reinterpret_cast(p_f16); #pragma unroll for (uint32_t wave_mma_d = 0; wave_mma_d < NUM_MMA_D_PER_WAVE; ++wave_mma_d) { @@ -322,18 +320,12 @@ __device__ __forceinline__ void compute_pv_hip(float (*o_frag)[4], const float* v_vals[j] = ckv_smem[(col_group * 4 + j) * (HEAD_DIM_CKV + MLA_HIP_CKV_LDS_PAD) + d_col]; } uint32_t v_frag[2]; - v_frag[0] = (uint32_t)(*reinterpret_cast(&v_vals[0])) | - ((uint32_t)(*reinterpret_cast(&v_vals[1])) << 16); - v_frag[1] = (uint32_t)(*reinterpret_cast(&v_vals[2])) | - ((uint32_t)(*reinterpret_cast(&v_vals[3])) << 16); + *reinterpret_cast(v_frag) = *reinterpret_cast(v_vals); gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32(o_frag[wave_mma_d], v_frag, p_frag); } } -// --------------------------------------------------------------------------- -// normalize_d_hip: divide o_frag by the softmax denominator. -// --------------------------------------------------------------------------- template __device__ __forceinline__ void normalize_d_hip(float (*o_frag)[4], float m_val, float d_scalar) { float d_rcp = (m_val != MLA_HIP_NEG_INF) ? math::ptx_rcp(d_scalar) : 0.f; @@ -470,9 +462,6 @@ __device__ void DevicePersistentMergeStatesHIP( } } -// --------------------------------------------------------------------------- -// BatchMLAPagedAttentionKernelHIP: main GPU kernel -// --------------------------------------------------------------------------- template __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKernelHIP( Params params) { @@ -522,14 +511,11 @@ __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKer const uint32_t qo_packed_idx_base = packed_qo_start + blockIdx.x * MLA_HIP_CTA_TILE_Q; - // Hoist the q_packed → (batch_idx, head_idx) divmod out of the kv-tile loop; - // it's the same for every iteration of every helper that consumed it. const uint32_t q_packed = qo_packed_idx_base + q_row; uint32_t batch_idx, head_idx; num_heads.divmod(q_packed, batch_idx, head_idx); const bool q_valid = (batch_idx < q_len); - // Q is loop-invariant across kv-tiles; load all fragments once into registers. load_q_frags_hip( q_pe_frag, q_nope_frag, params.q_nope + q_indptr * params.q_nope_stride_n, params.q_pe + q_indptr * params.q_pe_stride_n, params.q_nope_stride_n, @@ -628,9 +614,6 @@ __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKer params.return_lse_base_on_e); } -// --------------------------------------------------------------------------- -// BatchMLAPagedAttentionHIP: host-side launcher -// --------------------------------------------------------------------------- template hipError_t BatchMLAPagedAttentionHIP(Params params, uint32_t num_blks_x, uint32_t num_blks_y, hipStream_t stream) { diff --git a/tests/rocm_tests/test_deepseek_mla_hip.py b/tests/rocm_tests/test_deepseek_mla_hip.py index c1b31d5ab3..12ea41c5a5 100644 --- a/tests/rocm_tests/test_deepseek_mla_hip.py +++ b/tests/rocm_tests/test_deepseek_mla_hip.py @@ -114,11 +114,10 @@ def warmup_jit(): def test_batch_mla_page_attention( batch_size, kv_len, qo_len, num_heads, causal, page_size, dtype ): - # The Phase-1/2 HIP MLA kernel uses CTA_TILE_Q = 16, while the shared - # MLA planner assumes CTA_TILE_Q = 64. As long as a single batch's - # packed_qo_len = qo_len * num_heads <= 16 the planner generates one - # work item per batch and our CTA covers it exactly. Prefill paths - # (qo_len > 1) need an inner Q-sub-tile loop in the kernel — TODO. + # The HIP MLA kernel uses CTA_TILE_Q = 16, while the MLA planner assumes + # CTA_TILE_Q = 64. For decode (packed_qo_len = qo_len * num_heads <= 16) + # the planner generates one work item per batch and the CTA covers it exactly. + # Prefill (qo_len > 1) requires an inner Q-sub-tile loop — not yet implemented. assert qo_len * num_heads <= 16, "Prefill MLA on HIP not yet supported" if causal and qo_len > kv_len: pytest.skip("qo_len > kv_len not supported for causal attention") diff --git a/tests/rocm_tests/test_mla_hip.py b/tests/rocm_tests/test_mla_hip.py index 2c24a8b185..5cccc27232 100644 --- a/tests/rocm_tests/test_mla_hip.py +++ b/tests/rocm_tests/test_mla_hip.py @@ -63,7 +63,7 @@ def _plan(wrapper, batch_size, kv_len, page_size, causal, dtype): kv_indptr, kv_indices, kv_lens, - 16, # num_heads + 16, HEAD_DIM_CKV, HEAD_DIM_KPE, page_size, @@ -155,12 +155,8 @@ def test_batch_mla_correctness(batch_size, kv_len, page_size, causal, dtype): ref = _mla_reference( q_nope, q_pe, - ckv.reshape(batch_size, pages_per_req, page_size, HEAD_DIM_CKV).reshape( - batch_size * pages_per_req, page_size, HEAD_DIM_CKV - ), - kpe.reshape(batch_size, pages_per_req, page_size, HEAD_DIM_KPE).reshape( - batch_size * pages_per_req, page_size, HEAD_DIM_KPE - ), + ckv, + kpe, kv_lens, page_size, sm_scale, From 1a0b5bcda4a51e23fcfc03db3069d1b55693a22f Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Mon, 18 May 2026 23:11:57 +0000 Subject: [PATCH 12/19] =?UTF-8?q?perf(hip):=20MLA=20Phase=202b=20=E2=80=94?= =?UTF-8?q?=20wave-specialized=20QK=20and=20Phase=203=20completion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wave-specialized QK (avoid 4× redundant MFMA work per CTA): - Only wave 0 computes QK (q_pe·kpe^T + q_nope·ckv^T) and online softmax - Wave 0 broadcasts p_frag + m/o_scale/d scalars via new WaveSpecStageMLAHIP in LDS; waves 1..3 read and skip QK entirely (~75% MFMA reduction) - New helpers: broadcast_softmax_hip (wave 0) / read_softmax_hip (waves 1..3) - Removes update_mdo_states_hip (logic inlined into broadcast_softmax_hip) - LDS p_frag row stride = CTA_TILE_KV + 1 = 17 to avoid worst bank conflicts - All 169 existing MLA tests pass unchanged Phase 3 completion items: - tests/rocm_tests/test_mla_page_hip.py: port of test_mla_page.py; tests the append_paged_mla_kv_cache write path on HIP (18 test cases) - benchmarks/rocm_benchmarks/bench_mla_hip.py: MLA decode benchmark for HIP with --batch/--seq/--heads/--dtype CLI flags, reports BW + TFLOPS - tests/rocm_tests/test_deepseek_mla_hip.py: add pytest.mark.slow to kv_len={4096,8192} and large kv_lens_list configs - flashinfer/__init__.py: export append_paged_mla_kv_cache in the HIP branch Co-Authored-By: Claude Sonnet 4.6 --- benchmarks/rocm_benchmarks/bench_mla_hip.py | 147 ++++++++++++++++++++ flashinfer/__init__.py | 1 + include/flashinfer/attention/mla_hip.cuh | 98 +++++++++---- tests/rocm_tests/test_deepseek_mla_hip.py | 17 ++- tests/rocm_tests/test_mla_page_hip.py | 101 ++++++++++++++ 5 files changed, 337 insertions(+), 27 deletions(-) create mode 100644 benchmarks/rocm_benchmarks/bench_mla_hip.py create mode 100644 tests/rocm_tests/test_mla_page_hip.py diff --git a/benchmarks/rocm_benchmarks/bench_mla_hip.py b/benchmarks/rocm_benchmarks/bench_mla_hip.py new file mode 100644 index 0000000000..48520eab51 --- /dev/null +++ b/benchmarks/rocm_benchmarks/bench_mla_hip.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# MLA decode benchmark for HIP/ROCm (DeepSeek shapes). +# Mirrors benchmarks/bench_deepseek_mla.py with backend="hip". +# +# Run: +# python benchmarks/rocm_benchmarks/bench_mla_hip.py +# python benchmarks/rocm_benchmarks/bench_mla_hip.py --batch 64 --seq 8192 --heads 128 + +import argparse +import math + +import numpy as np +import torch + +import flashinfer +import flashinfer.mla +from flashinfer.jit import build_jit_specs, gen_batch_mla_module +from flashinfer.testing.utils import bench_gpu_time + +HEAD_DIM_CKV = 512 +HEAD_DIM_KPE = 64 + + +def _warmup_jit(dtype: torch.dtype) -> None: + build_jit_specs( + [ + gen_batch_mla_module( + "hip", + dtype, + dtype, + dtype, + torch.int32, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + False, + ) + ], + verbose=False, + ) + + +@torch.inference_mode() +def bench(batch_size: int, seq_len: int, num_heads: int, dtype: torch.dtype) -> None: + page_size = 1 + sm_scale = 1.0 / math.sqrt(HEAD_DIM_CKV + HEAD_DIM_KPE) + + q_nope = torch.randn( + batch_size, num_heads, HEAD_DIM_CKV, dtype=dtype, device="cuda" + ) + q_pe = torch.randn(batch_size, num_heads, HEAD_DIM_KPE, dtype=dtype, device="cuda") + ckv = torch.randn( + batch_size * seq_len, page_size, HEAD_DIM_CKV, dtype=dtype, device="cuda" + ) + kpe = torch.randn( + batch_size * seq_len, page_size, HEAD_DIM_KPE, dtype=dtype, device="cuda" + ) + + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda") + wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace, backend="auto") + + q_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") + kv_indptr = ( + torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * seq_len + ) + kv_indices = torch.arange(0, batch_size * seq_len, dtype=torch.int32, device="cuda") + kv_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_lens, + num_heads, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + page_size, + False, + sm_scale, + dtype, + dtype, + ) + + # Correctness smoke-check before timing. + o = wrapper.run(q_nope, q_pe, ckv, kpe) + assert o.shape == (batch_size, num_heads, HEAD_DIM_CKV) + + measurements = bench_gpu_time( + lambda: wrapper.run(q_nope, q_pe, ckv, kpe), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ms = np.median(measurements) + + io_bytes = sum(t.numel() * t.element_size() for t in [q_nope, q_pe, ckv, kpe, o]) + # Two GEMMs: Q·KVᵀ (ckv + kpe dim) and P·V (ckv dim). + flops = 2 * batch_size * num_heads * (2 * HEAD_DIM_CKV + HEAD_DIM_KPE) * seq_len + + dtype_str = "fp16" if dtype == torch.float16 else "bf16" + print( + f"[{dtype_str}] batch={batch_size:>4d} seq={seq_len:>6d} heads={num_heads:>4d} | " + f"lat={ms * 1e3:.1f} µs | " + f"BW={io_bytes * 1e-6 / ms:.1f} GB/s | " + f"TFLOPS={flops * 1e-9 / ms:.3f}" + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="MLA HIP decode benchmark") + parser.add_argument( + "--batch", type=int, default=0, help="single batch size (0 = sweep)" + ) + parser.add_argument("--seq", type=int, default=0, help="single seq len (0 = sweep)") + parser.add_argument( + "--heads", type=int, default=16, help="number of attention heads" + ) + parser.add_argument( + "--dtype", + choices=["fp16", "bf16", "both"], + default="fp16", + help="data type (default: fp16)", + ) + args = parser.parse_args() + + dtypes = [] + if args.dtype in ("fp16", "both"): + dtypes.append(torch.float16) + if args.dtype in ("bf16", "both"): + dtypes.append(torch.bfloat16) + + batch_sizes = [args.batch] if args.batch > 0 else [64, 128, 768] + seq_lens = [args.seq] if args.seq > 0 else [1024, 2048, 8192] + + for dtype in dtypes: + print( + f"\n=== Warming up JIT ({('fp16' if dtype == torch.float16 else 'bf16')}) ===" + ) + _warmup_jit(dtype) + print("=== Benchmarking ===") + for seq_len in seq_lens: + for batch_size in batch_sizes: + bench(batch_size, seq_len, args.heads, dtype) + + +if __name__ == "__main__": + main() diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 0d7cca81ff..b800a4c9e6 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -199,6 +199,7 @@ from .norm import layernorm as layernorm from .norm import rmsnorm as rmsnorm from .page import append_paged_kv_cache as append_paged_kv_cache + from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache from .page import get_batch_indices_positions as get_batch_indices_positions from .page import get_seq_lens as get_seq_lens from .prefill_rocm import ( # type: ignore[assignment] diff --git a/include/flashinfer/attention/mla_hip.cuh b/include/flashinfer/attention/mla_hip.cuh index b8f73d805f..c801497d1f 100644 --- a/include/flashinfer/attention/mla_hip.cuh +++ b/include/flashinfer/attention/mla_hip.cuh @@ -4,8 +4,10 @@ // HIP MLA kernel for CDNA3 (MI300X). // // CTA config: dim3(64,4,1) — 4 wavefronts, 256 threads. -// All 4 wavefronts replicate QK independently; wave w owns -// head_dim shard [w * (HEAD_DIM_CKV/4) .. (w+1) * (HEAD_DIM_CKV/4)). +// Wave 0 computes QK (q_pe·kpe^T + q_nope·ckv^T) and softmax, then broadcasts +// the softmax weights (p_frag) and updated m/d scalars via LDS to waves 1..3. +// All 4 waves then compute their PV head_dim shard in parallel. +// Wave w owns head_dim shard [w*(HEAD_DIM_CKV/4) .. (w+1)*(HEAD_DIM_CKV/4)). // Q is loaded once into registers per work item (loop-invariant across KV tiles). // // MFMA layout note: CDNA3 mfma_f32_16x16x16f16 produces D in column-major @@ -50,10 +52,25 @@ constexpr float MLA_HIP_NEG_INF = -5e4f; // reads spread across banks {b, b+8, b+16, b+24}. constexpr uint32_t MLA_HIP_CKV_LDS_PAD = 8; constexpr uint32_t MLA_HIP_NUM_STAGES = 2; + +// Wave-specialization LDS stage: wave 0 writes softmax weights and online-softmax +// scalars here after QK; waves 1..3 read them to skip QK MFMAs entirely. +// p_frag row stride is 17 (= CTA_TILE_KV + 1) to reduce LDS bank conflicts: +// thread (q_row=a, col_group=b) maps to float offset a*17 + b*4, giving a more +// uniform bank distribution than stride=16 which produces 8-way conflicts. +constexpr uint32_t MLA_HIP_P_FRAG_STRIDE = MLA_HIP_CTA_TILE_KV + 1; +struct WaveSpecStageMLAHIP { + float p_frag[MLA_HIP_CTA_TILE_Q][MLA_HIP_P_FRAG_STRIDE]; + float m_stage[MLA_HIP_CTA_TILE_Q]; + float o_scale_stage[MLA_HIP_CTA_TILE_Q]; + float d_stage[MLA_HIP_CTA_TILE_Q]; +}; + template struct SharedStorageMLAHIP { DTypeKV ckv_smem[MLA_HIP_NUM_STAGES][MLA_HIP_CTA_TILE_KV][HEAD_DIM_CKV + MLA_HIP_CKV_LDS_PAD]; DTypeKV kpe_smem[MLA_HIP_NUM_STAGES][MLA_HIP_CTA_TILE_KV][HEAD_DIM_KPE]; + WaveSpecStageMLAHIP wave_spec; }; // --------------------------------------------------------------------------- @@ -244,32 +261,28 @@ __device__ __forceinline__ void logits_mask_hip(float* s_frag, uint32_t qo_packe } // --------------------------------------------------------------------------- -// update_mdo_states_hip: online softmax — rescale (m, d, o), then exp(s - m) -// in-place into s_frag and accumulate the row sum into d_scalar. -// -// Layout: each thread holds 4 kv values for one q_row (s_frag[r] = S[t%16][col_group*4+r]). -// Row max/sum reduce within the thread, then butterfly-XOR over bits 4 and 5 -// (the col_group dimension) to combine across all 4 threads with the same q_row. +// broadcast_softmax_hip: called by wave 0 after compute_qk + logits_mask. +// Inlines update_mdo_states, captures o_scale, and writes p_frag + scalars to +// the wave_spec LDS stage for waves 1..3 to consume via read_softmax_hip. // --------------------------------------------------------------------------- -template -__device__ __forceinline__ void update_mdo_states_hip(float* s_frag, float (*o_frag)[4], +template +__device__ __forceinline__ void broadcast_softmax_hip(float* s_frag, float (*o_frag)[4], float& m_val, float& d_scalar, - float sm_scale_log2) { + WaveSpecStageMLAHIP& stage, uint32_t q_row, + uint32_t col_group, float sm_scale_log2) { float m_local = fmaxf(fmaxf(s_frag[0], s_frag[1]), fmaxf(s_frag[2], s_frag[3])); m_local = fmaxf(m_local, math::shfl_xor_sync(m_local, 0x10)); m_local = fmaxf(m_local, math::shfl_xor_sync(m_local, 0x20)); float m_prev = m_val; m_val = fmaxf(m_prev, m_local); - float o_scale = math::ptx_exp2((m_prev - m_val) * sm_scale_log2); d_scalar *= o_scale; #pragma unroll - for (uint32_t d = 0; d < NUM_MMA_D_PER_WAVE; ++d) { + for (uint32_t d = 0; d < NUM_MMA_D_PER_WAVE; ++d) #pragma unroll for (uint32_t r = 0; r < 4; ++r) o_frag[d][r] *= o_scale; - } const float m_scaled = m_val * sm_scale_log2; float partial_d = 0.f; @@ -282,11 +295,40 @@ __device__ __forceinline__ void update_mdo_states_hip(float* s_frag, float (*o_f partial_d += math::shfl_xor_sync(partial_d, 0x10); partial_d += math::shfl_xor_sync(partial_d, 0x20); d_scalar += partial_d; + + // Broadcast p_frag and per-q_row scalars to LDS for waves 1..3. +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) stage.p_frag[q_row][col_group * 4 + r] = s_frag[r]; + if (col_group == 0) { + stage.m_stage[q_row] = m_val; + stage.o_scale_stage[q_row] = o_scale; + stage.d_stage[q_row] = d_scalar; + } +} + +// --------------------------------------------------------------------------- +// read_softmax_hip: called by waves 1..3 after the __syncthreads() following +// broadcast_softmax_hip. Reads p_frag and applies the o_scale / state update +// that wave 0 already applied to its own o_frag. +// --------------------------------------------------------------------------- +template +__device__ __forceinline__ void read_softmax_hip(float* s_frag, float (*o_frag)[4], float& m_val, + float& d_scalar, const WaveSpecStageMLAHIP& stage, + uint32_t q_row, uint32_t col_group) { +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) s_frag[r] = stage.p_frag[q_row][col_group * 4 + r]; + float o_scale = stage.o_scale_stage[q_row]; +#pragma unroll + for (uint32_t d = 0; d < NUM_MMA_D_PER_WAVE; ++d) +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) o_frag[d][r] *= o_scale; + m_val = stage.m_stage[q_row]; + d_scalar = stage.d_stage[q_row]; } // --------------------------------------------------------------------------- // compute_pv_hip: accumulate P*V into o_frag for this wave's head_dim shard. -// s_frag must already hold exp-scaled values (output of update_mdo_states_hip). +// s_frag must already hold exp-scaled softmax weights (output of broadcast/read_softmax_hip). // // As with compute_qk_hip, we swap A/B so the column-major D output gives a // row-major o_frag: thread t's r-th register = O[q=t%16][d=mma_d_abs*16+col_group*4+r]. @@ -573,15 +615,23 @@ __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKer __syncthreads(); - compute_qk_hip( - s_frag, q_pe_frag, q_nope_frag, smem_storage.ckv_smem[cur_stage][0], - smem_storage.kpe_smem[cur_stage][0], q_row, col_group); - - logits_mask_hip(s_frag, qo_packed_idx_base, kv_tile_to_abs_start(kv_tile_idx), q_len, - kv_len, kv_end, num_heads, lane_idx); - - update_mdo_states_hip(s_frag, o_frag, m_val, d_scalar, - sm_scale_log2); + // Wave 0 computes QK and softmax, then broadcasts via LDS. + // Waves 1..3 skip QK (saves 75% of MFMA work) and read p_frag from LDS. + if (wave_idx == 0) { + compute_qk_hip( + s_frag, q_pe_frag, q_nope_frag, smem_storage.ckv_smem[cur_stage][0], + smem_storage.kpe_smem[cur_stage][0], q_row, col_group); + logits_mask_hip(s_frag, qo_packed_idx_base, kv_tile_to_abs_start(kv_tile_idx), + q_len, kv_len, kv_end, num_heads, lane_idx); + broadcast_softmax_hip(s_frag, o_frag, m_val, d_scalar, + smem_storage.wave_spec, q_row, col_group, + sm_scale_log2); + } + __syncthreads(); // wave_spec visible to all waves + if (wave_idx > 0) { + read_softmax_hip(s_frag, o_frag, m_val, d_scalar, + smem_storage.wave_spec, q_row, col_group); + } compute_pv_hip( o_frag, s_frag, smem_storage.ckv_smem[cur_stage][0], wave_idx, lane_idx); diff --git a/tests/rocm_tests/test_deepseek_mla_hip.py b/tests/rocm_tests/test_deepseek_mla_hip.py index 12ea41c5a5..0a5287b0c3 100644 --- a/tests/rocm_tests/test_deepseek_mla_hip.py +++ b/tests/rocm_tests/test_deepseek_mla_hip.py @@ -105,7 +105,18 @@ def warmup_jit(): @pytest.mark.parametrize("batch_size", [1, 3, 7]) -@pytest.mark.parametrize("kv_len", [17, 33, 96, 514, 1024]) +@pytest.mark.parametrize( + "kv_len", + [ + 17, + 33, + 96, + 514, + 1024, + pytest.param(4096, marks=pytest.mark.slow), + pytest.param(8192, marks=pytest.mark.slow), + ], +) @pytest.mark.parametrize("qo_len", [1]) @pytest.mark.parametrize("num_heads", [16]) @pytest.mark.parametrize("causal", [False, True]) @@ -192,8 +203,8 @@ def test_batch_mla_page_attention( "kv_lens_list", [ [17, 33, 79], - [96, 514, 2048], - [128, 256, 512, 1024, 2048], + pytest.param([96, 514, 2048], marks=pytest.mark.slow), + pytest.param([128, 256, 512, 1024, 2048], marks=pytest.mark.slow), ], ids=lambda v: "x".join(str(x) for x in v), ) diff --git a/tests/rocm_tests/test_mla_page_hip.py b/tests/rocm_tests/test_mla_page_hip.py new file mode 100644 index 0000000000..8efcc506d5 --- /dev/null +++ b/tests/rocm_tests/test_mla_page_hip.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# HIP/ROCm test for append_paged_mla_kv_cache (MLA KV-cache write path). +# Mirrors tests/attention/test_mla_page.py. + +import math + +import pytest +import torch + +import flashinfer + +CKV_DIM = 512 +KPE_DIM = 64 + + +def _last_page_lens(kv_lens, page_size): + return [kl % page_size if kl % page_size != 0 else page_size for kl in kv_lens] + + +_KV_LEN_CONFIGS = [ + [45], + [4096], + [45, 8, 25], + [45, 8, 25, 22], + [45, 8, 25, 22, 400], + [45, 8, 25, 22, 100], +] + + +@pytest.mark.parametrize("kv_lens", _KV_LEN_CONFIGS) +@pytest.mark.parametrize("page_size", [1, 16, 64]) +def test_append_mla_paged_kv_cache(kv_lens, page_size): + device = torch.device("cuda") + nnz_kv = sum(kv_lens) + ckv_append = torch.randn(nnz_kv, CKV_DIM, dtype=torch.float16, device=device) + kpe_append = torch.randn(nnz_kv, KPE_DIM, dtype=torch.float16, device=device) + + num_pages_per_req = torch.tensor( + [math.ceil(kl / page_size) for kl in kv_lens], dtype=torch.int32, device=device + ) + kv_lens_t = torch.tensor(kv_lens, dtype=torch.int32, device=device) + kv_append_indptr = torch.cat( + [torch.zeros(1, dtype=torch.int32, device=device), kv_lens_t.cumsum(0).int()] + ) + + max_num_pages = int(num_pages_per_req.sum().item()) + ckv_cache = torch.zeros( + max_num_pages, page_size, CKV_DIM, dtype=torch.float16, device=device + ) + kpe_cache = torch.zeros( + max_num_pages, page_size, KPE_DIM, dtype=torch.float16, device=device + ) + + kv_page_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + num_pages_per_req.cumsum(0).int(), + ] + ) + kv_page_indices = torch.arange(max_num_pages, dtype=torch.int32, device=device) + kv_last_page_len = torch.tensor( + _last_page_lens(kv_lens, page_size), dtype=torch.int32, device=device + ) + + batch_indices, positions = flashinfer.get_batch_indices_positions( + kv_append_indptr, + flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size), + nnz_kv, + ) + flashinfer.append_paged_mla_kv_cache( + ckv_append, + kpe_append, + batch_indices, + positions, + ckv_cache, + kpe_cache, + kv_page_indices, + kv_page_indptr, + kv_last_page_len, + ) + + ckv_flat = ckv_cache.view(-1, CKV_DIM) + kpe_flat = kpe_cache.view(-1, KPE_DIM) + + acc_kv = 0 + acc_pad = 0 + for i, kl in enumerate(kv_lens): + pages_i = int(num_pages_per_req[i].item()) + torch.testing.assert_close( + ckv_append[acc_kv : acc_kv + kl], ckv_flat[acc_pad : acc_pad + kl] + ) + torch.testing.assert_close( + kpe_append[acc_kv : acc_kv + kl], kpe_flat[acc_pad : acc_pad + kl] + ) + # Padding slots must remain zero. + assert torch.all(ckv_flat[acc_pad + kl : acc_pad + pages_i * page_size] == 0) + assert torch.all(kpe_flat[acc_pad + kl : acc_pad + pages_i * page_size] == 0) + acc_kv += kl + acc_pad += pages_i * page_size From f41fee036e98263ad11d760c28b2cede915100b1 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 19 May 2026 01:28:54 +0000 Subject: [PATCH 13/19] fix(hip): MLA cleanup and double-buffer race fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - mla_hip.cuh: add missing __syncthreads() after compute_pv_hip to prevent ckv_smem[cur_stage] being overwritten by the next iteration's load_kv_hip while slower waves still read it (race for num_kv_tiles≥3) - gen_batch_mla_module: validate backend=='hip' instead of backend!='auto' - gen_batch_mla_module: simplify extra_cuda_cflags to single expression - test_mla_hip.py: drop unused _make_wrapper params; remove unused import - test_deepseek_mla_hip.py: remove no-op list slice; remove unused import - bench_mla_hip.py: remove WHAT comment before shape assertion Co-Authored-By: Claude Sonnet 4.6 --- benchmarks/rocm_benchmarks/bench_mla_hip.py | 1 - flashinfer/jit/attention/modules_hip.py | 9 +++------ include/flashinfer/attention/mla_hip.cuh | 7 +++++++ tests/rocm_tests/test_deepseek_mla_hip.py | 4 +--- tests/rocm_tests/test_mla_hip.py | 7 +++---- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/benchmarks/rocm_benchmarks/bench_mla_hip.py b/benchmarks/rocm_benchmarks/bench_mla_hip.py index 48520eab51..1a4b27f85d 100644 --- a/benchmarks/rocm_benchmarks/bench_mla_hip.py +++ b/benchmarks/rocm_benchmarks/bench_mla_hip.py @@ -82,7 +82,6 @@ def bench(batch_size: int, seq_len: int, num_heads: int, dtype: torch.dtype) -> dtype, ) - # Correctness smoke-check before timing. o = wrapper.run(q_nope, q_pe, ckv, kpe) assert o.shape == (batch_size, num_heads, HEAD_DIM_CKV) diff --git a/flashinfer/jit/attention/modules_hip.py b/flashinfer/jit/attention/modules_hip.py index c93ae7528b..ce740f1b14 100644 --- a/flashinfer/jit/attention/modules_hip.py +++ b/flashinfer/jit/attention/modules_hip.py @@ -976,8 +976,8 @@ def gen_batch_mla_module( head_dim_kpe: int, use_profiler: bool, ) -> JitSpec: - if backend == "auto": - raise ValueError("backend should not be auto when jit_args is provided") + if backend != "hip": + raise ValueError(f"MLA only supports backend='hip', got {backend!r}") uri = get_batch_mla_uri( backend, dtype_q, @@ -1018,8 +1018,5 @@ def gen_batch_mla_module( generated_config_path = gen_directory / "batch_mla_config.inc" write_if_different(generated_config_path, generated_inc_str) - extra_cuda_cflags = [] - if use_profiler: - extra_cuda_cflags += ["-DFLASHINFER_ENABLE_PROFILER"] - + extra_cuda_cflags = ["-DFLASHINFER_ENABLE_PROFILER"] if use_profiler else [] return gen_jit_spec(uri, source_paths, extra_cuda_cflags=extra_cuda_cflags) diff --git a/include/flashinfer/attention/mla_hip.cuh b/include/flashinfer/attention/mla_hip.cuh index c801497d1f..71dbd3a781 100644 --- a/include/flashinfer/attention/mla_hip.cuh +++ b/include/flashinfer/attention/mla_hip.cuh @@ -636,6 +636,13 @@ __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKer compute_pv_hip( o_frag, s_frag, smem_storage.ckv_smem[cur_stage][0], wave_idx, lane_idx); + // All threads must finish PV reads from ckv_smem[cur_stage] before any thread + // starts the next iteration's load_kv_hip into next_stage. On CDNA3 the four + // wavefronts are not lockstep across MFMA; without this barrier, faster waves + // can overwrite the buffer (now the upcoming next_stage) while slower waves + // still read it — wrong attention whenever num_kv_tiles >= 3. + __syncthreads(); + cur_stage = next_stage; } diff --git a/tests/rocm_tests/test_deepseek_mla_hip.py b/tests/rocm_tests/test_deepseek_mla_hip.py index 0a5287b0c3..e9ba0e4c48 100644 --- a/tests/rocm_tests/test_deepseek_mla_hip.py +++ b/tests/rocm_tests/test_deepseek_mla_hip.py @@ -11,7 +11,6 @@ import pytest import torch -import flashinfer import flashinfer.mla from flashinfer.jit import build_jit_specs, gen_batch_mla_module @@ -220,8 +219,7 @@ def test_batch_mla_varlen_kv(batch_size, kv_lens_list, qo_len, causal): num_heads = 16 page_size = 1 dtype = torch.float16 - # Repeat the kv_lens_list `batch_size` times along the request axis. - kv_lens_full = (kv_lens_list * batch_size)[: batch_size * len(kv_lens_list)] + kv_lens_full = kv_lens_list * batch_size n_req = len(kv_lens_full) pages = [math.ceil(kv / page_size) for kv in kv_lens_full] total_pages = sum(pages) diff --git a/tests/rocm_tests/test_mla_hip.py b/tests/rocm_tests/test_mla_hip.py index 5cccc27232..67eb868b64 100644 --- a/tests/rocm_tests/test_mla_hip.py +++ b/tests/rocm_tests/test_mla_hip.py @@ -9,7 +9,6 @@ import pytest import torch -import flashinfer import flashinfer.mla from flashinfer.jit import build_jit_specs, gen_batch_mla_module from flashinfer.utils import determine_mla_backend @@ -38,7 +37,7 @@ def warmup_jit(): yield -def _make_wrapper(batch_size, dtype): +def _make_wrapper(): workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda") return flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace, backend="auto") @@ -124,7 +123,7 @@ def test_determine_mla_backend(): def test_batch_mla_plan(batch_size, kv_len, page_size, causal, dtype): if causal and kv_len == 0: pytest.skip("causal with kv_len=0 unsupported") - wrapper = _make_wrapper(batch_size, dtype) + wrapper = _make_wrapper() _plan(wrapper, batch_size, kv_len, page_size, causal, dtype) @@ -135,7 +134,7 @@ def test_batch_mla_plan(batch_size, kv_len, page_size, causal, dtype): @pytest.mark.parametrize("dtype", [torch.float16]) def test_batch_mla_correctness(batch_size, kv_len, page_size, causal, dtype): torch.manual_seed(42) - wrapper = _make_wrapper(batch_size, dtype) + wrapper = _make_wrapper() sm_scale = 1.0 / ((HEAD_DIM_CKV + HEAD_DIM_KPE) ** 0.5) kv_lens, pages_per_req = _plan( wrapper, batch_size, kv_len, page_size, causal, dtype From 29747bc5941e79af90d002e1cfe3c86c32ab3612 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 19 May 2026 01:40:50 +0000 Subject: [PATCH 14/19] fix(hip): address PR review nits - bench_mla_hip.py: remove redundant `import flashinfer` (covered by `import flashinfer.mla` which binds the flashinfer namespace as a side effect) - get_batch_mla_uri: use backend in URI suffix instead of hardcoding '_hip', so the parameter is no longer dead and the signature stays consistent with the CUDA version (which appends '_sm90' for fa3) Co-Authored-By: Claude Sonnet 4.6 --- benchmarks/rocm_benchmarks/bench_mla_hip.py | 1 - flashinfer/jit/attention/modules_hip.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/benchmarks/rocm_benchmarks/bench_mla_hip.py b/benchmarks/rocm_benchmarks/bench_mla_hip.py index 1a4b27f85d..05891368d7 100644 --- a/benchmarks/rocm_benchmarks/bench_mla_hip.py +++ b/benchmarks/rocm_benchmarks/bench_mla_hip.py @@ -14,7 +14,6 @@ import numpy as np import torch -import flashinfer import flashinfer.mla from flashinfer.jit import build_jit_specs, gen_batch_mla_module from flashinfer.testing.utils import bench_gpu_time diff --git a/flashinfer/jit/attention/modules_hip.py b/flashinfer/jit/attention/modules_hip.py index ce740f1b14..b4556e4632 100644 --- a/flashinfer/jit/attention/modules_hip.py +++ b/flashinfer/jit/attention/modules_hip.py @@ -962,7 +962,7 @@ def get_batch_mla_uri( f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" f"head_dim_ckv_{head_dim_ckv}_" f"head_dim_kpe_{head_dim_kpe}_" - f"profiler_{use_profiler}_hip" + f"profiler_{use_profiler}_{backend}" ) From 5fb346f874408a06f928b466f3e35185cde844c8 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 19 May 2026 01:56:27 +0000 Subject: [PATCH 15/19] style(hip): trim what-commentary and inline single-use variable in mla_hip.cuh Remove separator lines and WHAT-opener sentences from all device-function block comments, keeping only non-obvious WHY notes (MFMA A/B swap derivation, LDS bank-conflict padding, wave-spec broadcast protocol, partial-indptr sentinel). Inline the single-use `has_next` boolean and drop the duplicate wave-specialisation comment inside the KV tile loop. Co-Authored-By: Claude Sonnet 4.6 --- include/flashinfer/attention/mla_hip.cuh | 51 ++---------------------- 1 file changed, 3 insertions(+), 48 deletions(-) diff --git a/include/flashinfer/attention/mla_hip.cuh b/include/flashinfer/attention/mla_hip.cuh index 71dbd3a781..ac3806cb0b 100644 --- a/include/flashinfer/attention/mla_hip.cuh +++ b/include/flashinfer/attention/mla_hip.cuh @@ -73,11 +73,6 @@ struct SharedStorageMLAHIP { WaveSpecStageMLAHIP wave_spec; }; -// --------------------------------------------------------------------------- -// load_kv_hip: cooperative KV tile load, all 256 threads participate. -// packed_kv_tile_base = kv_indptr * block_size + kv_tile_abs_start -// packed_kv_bound = kv_indptr * block_size + kv_len -// --------------------------------------------------------------------------- template __device__ __forceinline__ void load_kv_hip( DTypeKV* __restrict__ ckv_smem, DTypeKV* __restrict__ kpe_smem, @@ -150,15 +145,10 @@ __device__ __forceinline__ void load_kv_hip( } } -// --------------------------------------------------------------------------- -// load_q_frags_hip: read this thread's Q fragments once from global memory. -// Q is loop-invariant across the kv_tile loop, so caching saves repeated -// global loads (per work_idx, save (num_kv_tiles - 1) reloads of all Q). -// +// Q is loop-invariant across the kv_tile loop; caching saves (num_kv_tiles - 1) reloads. // q_pe_frag[mma_d][2] holds Q_pe[q=t%16][d=mma_d*16 + col_group*4 + 0..3] as 4 fp16. // q_nope_frag is the same for Q_nope across NUM_MMA_D_CKV head_dim tiles. // Out-of-range threads (batch_idx >= q_len) get zero-filled fragments. -// --------------------------------------------------------------------------- template __device__ __forceinline__ void load_q_frags_hip(uint32_t (*q_pe_frag)[2], uint32_t (*q_nope_frag)[2], const DTypeQ* q_nope, @@ -192,9 +182,6 @@ __device__ __forceinline__ void load_q_frags_hip(uint32_t (*q_pe_frag)[2], } } -// --------------------------------------------------------------------------- -// compute_qk_hip: accumulate s_frag = Q_pe*KPE^T + Q_nope*CKV^T -// // CDNA3 MFMA computes D = A*B where the D output is column-major per-thread: // thread t holds D[m=(t/16)*4+r][n=t%16]. To get the natural row-major s_frag // layout (thread t holds s_frag[r] = S[q=t%16][kv=col_group*4+r]) we swap the @@ -202,7 +189,6 @@ __device__ __forceinline__ void load_q_frags_hip(uint32_t (*q_pe_frag)[2], // D[m=kv][n=q] = sum_d K[kv=m][d=k] * Q[q=n][d=k] = (K@Q^T)[kv][q] = S[q][kv]. // Combined with the column-major D layout, thread t's r-th register = // S[q=t%16][kv=(t/16)*4+r]. -// --------------------------------------------------------------------------- template __device__ __forceinline__ void compute_qk_hip(float* s_frag, const uint32_t (*q_pe_frag)[2], const uint32_t (*q_nope_frag)[2], @@ -239,10 +225,7 @@ __device__ __forceinline__ void compute_qk_hip(float* s_frag, const uint32_t (*q } } -// --------------------------------------------------------------------------- -// logits_mask_hip: zero out scores for out-of-bounds or causally-masked positions. // s_frag[r] = S[q_row=lane%16][kv_col = kv_idx_base + (lane/16)*4 + r] -// --------------------------------------------------------------------------- template __device__ __forceinline__ void logits_mask_hip(float* s_frag, uint32_t qo_packed_idx_base, uint32_t kv_idx_base, uint32_t q_len, @@ -260,11 +243,6 @@ __device__ __forceinline__ void logits_mask_hip(float* s_frag, uint32_t qo_packe } } -// --------------------------------------------------------------------------- -// broadcast_softmax_hip: called by wave 0 after compute_qk + logits_mask. -// Inlines update_mdo_states, captures o_scale, and writes p_frag + scalars to -// the wave_spec LDS stage for waves 1..3 to consume via read_softmax_hip. -// --------------------------------------------------------------------------- template __device__ __forceinline__ void broadcast_softmax_hip(float* s_frag, float (*o_frag)[4], float& m_val, float& d_scalar, @@ -306,11 +284,6 @@ __device__ __forceinline__ void broadcast_softmax_hip(float* s_frag, float (*o_f } } -// --------------------------------------------------------------------------- -// read_softmax_hip: called by waves 1..3 after the __syncthreads() following -// broadcast_softmax_hip. Reads p_frag and applies the o_scale / state update -// that wave 0 already applied to its own o_frag. -// --------------------------------------------------------------------------- template __device__ __forceinline__ void read_softmax_hip(float* s_frag, float (*o_frag)[4], float& m_val, float& d_scalar, const WaveSpecStageMLAHIP& stage, @@ -326,16 +299,11 @@ __device__ __forceinline__ void read_softmax_hip(float* s_frag, float (*o_frag)[ d_scalar = stage.d_stage[q_row]; } -// --------------------------------------------------------------------------- -// compute_pv_hip: accumulate P*V into o_frag for this wave's head_dim shard. -// s_frag must already hold exp-scaled softmax weights (output of broadcast/read_softmax_hip). -// // As with compute_qk_hip, we swap A/B so the column-major D output gives a // row-major o_frag: thread t's r-th register = O[q=t%16][d=mma_d_abs*16+col_group*4+r]. // We pass A=V (loaded strided so f16x4[j]=V[kv=col_group*4+j][d=mma_d_abs*16+t%16]) // and B=P (s_frag, with f16x4[r]=P[q=t%16][kv=col_group*4+r]). // The math: D[m=d_local][n=q] = sum_kv V[kv][d] * P[q][kv] = O[q][d_global]. -// --------------------------------------------------------------------------- template __device__ __forceinline__ void compute_pv_hip(float (*o_frag)[4], const float* s_frag, const DTypeKV* ckv_smem, uint32_t wave_idx, @@ -378,15 +346,9 @@ __device__ __forceinline__ void normalize_d_hip(float (*o_frag)[4], float m_val, } } -// --------------------------------------------------------------------------- -// write_o_hip: store results to final_o or partial_o. -// // Each thread writes for q_row = lane_idx%16 and its wave's head_dim shard. // Only wave_idx==0 && col_group==0 (lane_idx<16) writes the LSE scalar. -// -// partial_o / partial_lse are the GLOBAL arrays; partial_indptr is the start -// row index into these arrays for this CTA's work item. Pass -1 if no partial. -// --------------------------------------------------------------------------- +// partial_o / partial_lse are the GLOBAL arrays; partial_indptr = -1 when single-tile. template __device__ __forceinline__ void write_o_hip( float (*o_frag)[4], float m_val, float d_scalar, DTypeO* final_o, float* final_lse, @@ -434,13 +396,9 @@ __device__ __forceinline__ void write_o_hip( } } -// --------------------------------------------------------------------------- -// DevicePersistentMergeStatesHIP: reduce partial outputs → final output. -// // partial_o[row * HEAD_DIM_CKV + col] stores DTypeO (normalized by each partial's d). // partial_lse[row] = log2(d) + m for each partial block. // Merge treats partial_lse as the combined log-sum-exp weight for each partial. -// --------------------------------------------------------------------------- template __device__ void DevicePersistentMergeStatesHIP( const IdType* merge_packed_offset_start, const IdType* merge_packed_offset_end, @@ -600,12 +558,11 @@ __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKer uint32_t cur_stage = 0; for (int32_t kv_tile_idx = num_kv_tiles - 1; kv_tile_idx >= 0; --kv_tile_idx) { const uint32_t next_stage = 1 - cur_stage; - const bool has_next = (kv_tile_idx > 0); // Issue load for tile (kv_tile_idx - 1) into next_stage. Without // cp.async, this still blocks the wave but lets the compiler schedule // the global VGPR loads ahead of compute. - if (has_next) { + if (kv_tile_idx > 0) { load_kv_hip( smem_storage.ckv_smem[next_stage][0], smem_storage.kpe_smem[next_stage][0], params.ckv, params.kpe, params.kv_indices, params.ckv_stride_page, params.ckv_stride_n, @@ -615,8 +572,6 @@ __global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKer __syncthreads(); - // Wave 0 computes QK and softmax, then broadcasts via LDS. - // Waves 1..3 skip QK (saves 75% of MFMA work) and read p_frag from LDS. if (wave_idx == 0) { compute_qk_hip( s_frag, q_pe_frag, q_nope_frag, smem_storage.ckv_smem[cur_stage][0], From 05ed67bf37585d4584c11fcaff2bee14b4c4fcec Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 19 May 2026 01:56:33 +0000 Subject: [PATCH 16/19] refactor(hip): inline stride locals in batch_mla.cu Replace 10 single-use stride variables extracted before the DISPATCH_context lambda with direct Tensor::stride() calls inside the lambda. The lambda already captures by reference, so no semantic change; removes the intermediate layer of temporaries. Co-Authored-By: Claude Sonnet 4.6 --- flashinfer/csrc_rocm/batch_mla.cu | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/flashinfer/csrc_rocm/batch_mla.cu b/flashinfer/csrc_rocm/batch_mla.cu index a83ea7c7fb..69bc8e8d0d 100644 --- a/flashinfer/csrc_rocm/batch_mla.cu +++ b/flashinfer/csrc_rocm/batch_mla.cu @@ -58,17 +58,6 @@ void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int const c10::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(q_nope.device()); const hipStream_t stream = c10::hip::getCurrentHIPStream(); - unsigned int q_nope_stride_n = q_nope.stride(0); - unsigned int q_nope_stride_h = q_nope.stride(1); - unsigned int q_pe_stride_n = q_pe.stride(0); - unsigned int q_pe_stride_h = q_pe.stride(1); - unsigned int ckv_stride_page = ckv_cache.stride(0); - unsigned int ckv_stride_n = ckv_cache.stride(1); - unsigned int kpe_stride_page = kpe_cache.stride(0); - unsigned int kpe_stride_n = kpe_cache.stride(1); - unsigned int o_stride_n = o.stride(0); - unsigned int o_stride_h = o.stride(1); - DISPATCH_context( DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] { Params params; @@ -112,16 +101,16 @@ void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int params.num_heads = uint_fastdiv(num_heads); params.block_size = uint_fastdiv(page_size); - params.q_nope_stride_n = q_nope_stride_n; - params.q_nope_stride_h = q_nope_stride_h; - params.q_pe_stride_n = q_pe_stride_n; - params.q_pe_stride_h = q_pe_stride_h; - params.ckv_stride_page = ckv_stride_page; - params.ckv_stride_n = ckv_stride_n; - params.kpe_stride_page = kpe_stride_page; - params.kpe_stride_n = kpe_stride_n; - params.o_stride_n = o_stride_n; - params.o_stride_h = o_stride_h; + params.q_nope_stride_n = q_nope.stride(0); + params.q_nope_stride_h = q_nope.stride(1); + params.q_pe_stride_n = q_pe.stride(0); + params.q_pe_stride_h = q_pe.stride(1); + params.ckv_stride_page = ckv_cache.stride(0); + params.ckv_stride_n = ckv_cache.stride(1); + params.kpe_stride_page = kpe_cache.stride(0); + params.kpe_stride_n = kpe_cache.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); params.sm_scale = static_cast(sm_scale); params.return_lse_base_on_e = return_lse_base_on_e; From 4642d3fe27fd5976aef61f264aa2638867fc0857 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 19 May 2026 01:56:39 +0000 Subject: [PATCH 17/19] style(hip): consolidate deferred imports in aot_hip.py and utils.py Move gen_batch_mla_module import to the top of gen_attention() so all deferred imports appear together rather than mid-function. Promote IS_HIP import in utils.py from a function-level deferred import inside determine_mla_backend() to module level, consistent with every other device_utils import in the file. Co-Authored-By: Claude Sonnet 4.6 --- flashinfer/aot_hip.py | 4 ++-- flashinfer/utils.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/flashinfer/aot_hip.py b/flashinfer/aot_hip.py index 6e8ba20183..58b801cede 100644 --- a/flashinfer/aot_hip.py +++ b/flashinfer/aot_hip.py @@ -95,6 +95,8 @@ def gen_attention( use_sliding_window_: List[bool], use_logits_soft_cap_: List[bool], ) -> Iterator: + from .jit.attention import gen_batch_mla_module + # FA2 MHA / MQA / GQA for ( (head_dim_qk, head_dim_vo), @@ -119,8 +121,6 @@ def gen_attention( ) # MLA (DeepSeek shapes: head_dim_ckv=512, head_dim_kpe=64). - from .jit.attention import gen_batch_mla_module - for dtype in f16_dtype_: yield gen_batch_mla_module( backend="hip", diff --git a/flashinfer/utils.py b/flashinfer/utils.py index d59707e5c2..7e038c9ee0 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -28,6 +28,7 @@ from torch.torch_version import __version__ as torch_version import inspect +from .device_utils import IS_HIP from .jit.spdlog import gen_spdlog_module @@ -640,8 +641,6 @@ def is_sm121a_supported(device: torch.device) -> bool: def determine_mla_backend(device: torch.device) -> str: - from .device_utils import IS_HIP - if IS_HIP: return "hip" return "fa3" if is_sm90a_supported(device) else "fa2" From 49e722ea4987a74aca4f798bda346409dfb5c4b5 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 19 May 2026 01:56:53 +0000 Subject: [PATCH 18/19] style(hip): clean up HIP MLA tests and benchmark test_mla_hip.py: - Add NUM_HEADS = 16 constant; replace magic literal in _plan and tensor allocs. - Rename causal -> _causal in _mla_reference (param is unused by design for decode-only reference). - Remove shape-annotation comments; keep the pages-are-sequential layout note. test_deepseek_mla_hip.py: - Remove @pytest.mark.parametrize("qo_len", [1]) from both test functions; inline qo_len = 1 in the function body. test_mla_page_hip.py: - Remove comment that restates the assert directly below it. bench_mla_hip.py: - Promote page_size = 1 to module-level _PAGE_SIZE constant. Co-Authored-By: Claude Sonnet 4.6 --- benchmarks/rocm_benchmarks/bench_mla_hip.py | 8 ++++---- tests/rocm_tests/test_deepseek_mla_hip.py | 9 ++++----- tests/rocm_tests/test_mla_hip.py | 15 +++++++-------- tests/rocm_tests/test_mla_page_hip.py | 1 - 4 files changed, 15 insertions(+), 18 deletions(-) diff --git a/benchmarks/rocm_benchmarks/bench_mla_hip.py b/benchmarks/rocm_benchmarks/bench_mla_hip.py index 05891368d7..9df82f73c5 100644 --- a/benchmarks/rocm_benchmarks/bench_mla_hip.py +++ b/benchmarks/rocm_benchmarks/bench_mla_hip.py @@ -20,6 +20,7 @@ HEAD_DIM_CKV = 512 HEAD_DIM_KPE = 64 +_PAGE_SIZE = 1 def _warmup_jit(dtype: torch.dtype) -> None: @@ -42,7 +43,6 @@ def _warmup_jit(dtype: torch.dtype) -> None: @torch.inference_mode() def bench(batch_size: int, seq_len: int, num_heads: int, dtype: torch.dtype) -> None: - page_size = 1 sm_scale = 1.0 / math.sqrt(HEAD_DIM_CKV + HEAD_DIM_KPE) q_nope = torch.randn( @@ -50,10 +50,10 @@ def bench(batch_size: int, seq_len: int, num_heads: int, dtype: torch.dtype) -> ) q_pe = torch.randn(batch_size, num_heads, HEAD_DIM_KPE, dtype=dtype, device="cuda") ckv = torch.randn( - batch_size * seq_len, page_size, HEAD_DIM_CKV, dtype=dtype, device="cuda" + batch_size * seq_len, _PAGE_SIZE, HEAD_DIM_CKV, dtype=dtype, device="cuda" ) kpe = torch.randn( - batch_size * seq_len, page_size, HEAD_DIM_KPE, dtype=dtype, device="cuda" + batch_size * seq_len, _PAGE_SIZE, HEAD_DIM_KPE, dtype=dtype, device="cuda" ) workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda") @@ -74,7 +74,7 @@ def bench(batch_size: int, seq_len: int, num_heads: int, dtype: torch.dtype) -> num_heads, HEAD_DIM_CKV, HEAD_DIM_KPE, - page_size, + _PAGE_SIZE, False, sm_scale, dtype, diff --git a/tests/rocm_tests/test_deepseek_mla_hip.py b/tests/rocm_tests/test_deepseek_mla_hip.py index e9ba0e4c48..a1b3f34be8 100644 --- a/tests/rocm_tests/test_deepseek_mla_hip.py +++ b/tests/rocm_tests/test_deepseek_mla_hip.py @@ -116,19 +116,18 @@ def warmup_jit(): pytest.param(8192, marks=pytest.mark.slow), ], ) -@pytest.mark.parametrize("qo_len", [1]) @pytest.mark.parametrize("num_heads", [16]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("page_size", [1, 16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_batch_mla_page_attention( - batch_size, kv_len, qo_len, num_heads, causal, page_size, dtype + batch_size, kv_len, num_heads, causal, page_size, dtype ): # The HIP MLA kernel uses CTA_TILE_Q = 16, while the MLA planner assumes # CTA_TILE_Q = 64. For decode (packed_qo_len = qo_len * num_heads <= 16) # the planner generates one work item per batch and the CTA covers it exactly. # Prefill (qo_len > 1) requires an inner Q-sub-tile loop — not yet implemented. - assert qo_len * num_heads <= 16, "Prefill MLA on HIP not yet supported" + qo_len = 1 if causal and qo_len > kv_len: pytest.skip("qo_len > kv_len not supported for causal attention") device = torch.device("cuda") @@ -207,10 +206,10 @@ def test_batch_mla_page_attention( ], ids=lambda v: "x".join(str(x) for x in v), ) -@pytest.mark.parametrize("qo_len", [1]) @pytest.mark.parametrize("causal", [False, True]) -def test_batch_mla_varlen_kv(batch_size, kv_lens_list, qo_len, causal): +def test_batch_mla_varlen_kv(batch_size, kv_lens_list, causal): """Each request in the batch has a different kv_len (still page_size=1).""" + qo_len = 1 if causal and qo_len > min(kv_lens_list): pytest.skip("qo_len > min(kv_len) not supported for causal attention") device = torch.device("cuda") diff --git a/tests/rocm_tests/test_mla_hip.py b/tests/rocm_tests/test_mla_hip.py index 67eb868b64..55a91674b2 100644 --- a/tests/rocm_tests/test_mla_hip.py +++ b/tests/rocm_tests/test_mla_hip.py @@ -15,6 +15,7 @@ HEAD_DIM_CKV = 512 HEAD_DIM_KPE = 64 +NUM_HEADS = 16 @pytest.fixture(autouse=True, scope="module") @@ -62,7 +63,7 @@ def _plan(wrapper, batch_size, kv_len, page_size, causal, dtype): kv_indptr, kv_indices, kv_lens, - 16, + NUM_HEADS, HEAD_DIM_CKV, HEAD_DIM_KPE, page_size, @@ -74,13 +75,11 @@ def _plan(wrapper, batch_size, kv_len, page_size, causal, dtype): return kv_lens, pages_per_req -def _mla_reference(q_nope, q_pe, ckv, kpe, kv_lens, page_size, sm_scale, causal): +def _mla_reference(q_nope, q_pe, ckv, kpe, kv_lens, page_size, sm_scale, _causal): """Pure-PyTorch MLA reference: S = q_pe @ kpe^T + q_nope @ ckv^T; O = softmax(S) @ ckv.""" batch_size, num_heads, _ = q_nope.shape dtype = q_nope.dtype - # ckv/kpe: [num_pages, page_size, head_dim] - # flatten to [batch_size, kv_len, head_dim] using kv_lens max_kv_len = int(kv_lens.max().item()) # pages are laid out sequentially per request pages_per_req = math.ceil(max_kv_len / page_size) @@ -92,9 +91,7 @@ def _mla_reference(q_nope, q_pe, ckv, kpe, kv_lens, page_size, sm_scale, causal) :, :max_kv_len, : ] - # q: [batch, heads, dim] K/V: [batch, kv_len, dim] scores: [batch, heads, kv_len] # decode is qo_len=1 so causal masking is a no-op; only kv-length padding matters. - del causal scores = torch.einsum("bhd,bsd->bhs", q_pe.float(), kpe_flat.float()) scores += torch.einsum("bhd,bsd->bhs", q_nope.float(), ckv_flat.float()) scores = scores * sm_scale @@ -140,8 +137,10 @@ def test_batch_mla_correctness(batch_size, kv_len, page_size, causal, dtype): wrapper, batch_size, kv_len, page_size, causal, dtype ) - q_nope = torch.randn(batch_size, 16, HEAD_DIM_CKV, dtype=dtype, device="cuda") - q_pe = torch.randn(batch_size, 16, HEAD_DIM_KPE, dtype=dtype, device="cuda") + q_nope = torch.randn( + batch_size, NUM_HEADS, HEAD_DIM_CKV, dtype=dtype, device="cuda" + ) + q_pe = torch.randn(batch_size, NUM_HEADS, HEAD_DIM_KPE, dtype=dtype, device="cuda") ckv = torch.randn( batch_size * pages_per_req, page_size, HEAD_DIM_CKV, dtype=dtype, device="cuda" ) diff --git a/tests/rocm_tests/test_mla_page_hip.py b/tests/rocm_tests/test_mla_page_hip.py index 8efcc506d5..44a132cf01 100644 --- a/tests/rocm_tests/test_mla_page_hip.py +++ b/tests/rocm_tests/test_mla_page_hip.py @@ -94,7 +94,6 @@ def test_append_mla_paged_kv_cache(kv_lens, page_size): torch.testing.assert_close( kpe_append[acc_kv : acc_kv + kl], kpe_flat[acc_pad : acc_pad + kl] ) - # Padding slots must remain zero. assert torch.all(ckv_flat[acc_pad + kl : acc_pad + pages_i * page_size] == 0) assert torch.all(kpe_flat[acc_pad + kl : acc_pad + pages_i * page_size] == 0) acc_kv += kl From 9e17034c4887eeba907f5c7449475c08af109461 Mon Sep 17 00:00:00 2001 From: Debasis Mandal Date: Tue, 19 May 2026 02:31:30 +0000 Subject: [PATCH 19/19] fix(hip): reject use_profiler=True in gen_batch_mla_module on HIP profiler.cuh contains NVIDIA PTX inline asm (%smid, globaltimer) that does not compile under HIP. Guard against accidental misuse by raising ValueError when use_profiler=True is passed on the hip backend, matching the existing backend validation pattern. After the guard, use_profiler is always False so the extra_cuda_cflags branch is dead; replace with an explicitly typed empty list to satisfy mypy. Co-Authored-By: Claude Sonnet 4.6 --- flashinfer/jit/attention/modules_hip.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flashinfer/jit/attention/modules_hip.py b/flashinfer/jit/attention/modules_hip.py index b4556e4632..9040c283ee 100644 --- a/flashinfer/jit/attention/modules_hip.py +++ b/flashinfer/jit/attention/modules_hip.py @@ -978,6 +978,10 @@ def gen_batch_mla_module( ) -> JitSpec: if backend != "hip": raise ValueError(f"MLA only supports backend='hip', got {backend!r}") + if use_profiler: + raise ValueError( + "use_profiler is not supported on HIP/ROCm — profiler.cuh uses NVIDIA PTX asm" + ) uri = get_batch_mla_uri( backend, dtype_q, @@ -1018,5 +1022,5 @@ def gen_batch_mla_module( generated_config_path = gen_directory / "batch_mla_config.inc" write_if_different(generated_config_path, generated_inc_str) - extra_cuda_cflags = ["-DFLASHINFER_ENABLE_PROFILER"] if use_profiler else [] + extra_cuda_cflags: list[str] = [] return gen_jit_spec(uri, source_paths, extra_cuda_cflags=extra_cuda_cflags)