diff --git a/README.md b/README.md index 1446c2f4d9..1e70336cdd 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ to its corresponding upstream tag (e.g., `0.2.5+amd.2` is second release of amd- | **Decode Attention** | ✅ | ✅ | No | Supports MHA, GQA, and MQA | | **Prefill Attention** | ✅ | WIP | ✅ | Supports MHA, GQA, and MQA | | **Cascade Attention** | TBD | TBD | No | Not Yet Ported | -| **MLA** | TBD | TBD | No | Not Yet Ported | +| **MLA** | ✅ | TBD | No | Decode-path (qo_len=1) on CDNA3; prefill TODO | | **POD** | TBD | TBD | No | Not Yet Ported | | **Positional Encoding** | TBD | TBD | No | Not Yet Ported | | **Sampling** | ✅ | TBD | No | Supports Top-K/Top-P Sampling/OnlineSoftmax/SamplingFromLogits | diff --git a/benchmarks/rocm_benchmarks/bench_mla_hip.py b/benchmarks/rocm_benchmarks/bench_mla_hip.py new file mode 100644 index 0000000000..9df82f73c5 --- /dev/null +++ b/benchmarks/rocm_benchmarks/bench_mla_hip.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# MLA decode benchmark for HIP/ROCm (DeepSeek shapes). +# Mirrors benchmarks/bench_deepseek_mla.py with backend="hip". +# +# Run: +# python benchmarks/rocm_benchmarks/bench_mla_hip.py +# python benchmarks/rocm_benchmarks/bench_mla_hip.py --batch 64 --seq 8192 --heads 128 + +import argparse +import math + +import numpy as np +import torch + +import flashinfer.mla +from flashinfer.jit import build_jit_specs, gen_batch_mla_module +from flashinfer.testing.utils import bench_gpu_time + +HEAD_DIM_CKV = 512 +HEAD_DIM_KPE = 64 +_PAGE_SIZE = 1 + + +def _warmup_jit(dtype: torch.dtype) -> None: + build_jit_specs( + [ + gen_batch_mla_module( + "hip", + dtype, + dtype, + dtype, + torch.int32, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + False, + ) + ], + verbose=False, + ) + + +@torch.inference_mode() +def bench(batch_size: int, seq_len: int, num_heads: int, dtype: torch.dtype) -> None: + sm_scale = 1.0 / math.sqrt(HEAD_DIM_CKV + HEAD_DIM_KPE) + + q_nope = torch.randn( + batch_size, num_heads, HEAD_DIM_CKV, dtype=dtype, device="cuda" + ) + q_pe = torch.randn(batch_size, num_heads, HEAD_DIM_KPE, dtype=dtype, device="cuda") + ckv = torch.randn( + batch_size * seq_len, _PAGE_SIZE, HEAD_DIM_CKV, dtype=dtype, device="cuda" + ) + kpe = torch.randn( + batch_size * seq_len, _PAGE_SIZE, HEAD_DIM_KPE, dtype=dtype, device="cuda" + ) + + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda") + wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace, backend="auto") + + q_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") + kv_indptr = ( + torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * seq_len + ) + kv_indices = torch.arange(0, batch_size * seq_len, dtype=torch.int32, device="cuda") + kv_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_lens, + num_heads, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + _PAGE_SIZE, + False, + sm_scale, + dtype, + dtype, + ) + + o = wrapper.run(q_nope, q_pe, ckv, kpe) + assert o.shape == (batch_size, num_heads, HEAD_DIM_CKV) + + measurements = bench_gpu_time( + lambda: wrapper.run(q_nope, q_pe, ckv, kpe), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ms = np.median(measurements) + + io_bytes = sum(t.numel() * t.element_size() for t in [q_nope, q_pe, ckv, kpe, o]) + # Two GEMMs: Q·KVᵀ (ckv + kpe dim) and P·V (ckv dim). + flops = 2 * batch_size * num_heads * (2 * HEAD_DIM_CKV + HEAD_DIM_KPE) * seq_len + + dtype_str = "fp16" if dtype == torch.float16 else "bf16" + print( + f"[{dtype_str}] batch={batch_size:>4d} seq={seq_len:>6d} heads={num_heads:>4d} | " + f"lat={ms * 1e3:.1f} µs | " + f"BW={io_bytes * 1e-6 / ms:.1f} GB/s | " + f"TFLOPS={flops * 1e-9 / ms:.3f}" + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="MLA HIP decode benchmark") + parser.add_argument( + "--batch", type=int, default=0, help="single batch size (0 = sweep)" + ) + parser.add_argument("--seq", type=int, default=0, help="single seq len (0 = sweep)") + parser.add_argument( + "--heads", type=int, default=16, help="number of attention heads" + ) + parser.add_argument( + "--dtype", + choices=["fp16", "bf16", "both"], + default="fp16", + help="data type (default: fp16)", + ) + args = parser.parse_args() + + dtypes = [] + if args.dtype in ("fp16", "both"): + dtypes.append(torch.float16) + if args.dtype in ("bf16", "both"): + dtypes.append(torch.bfloat16) + + batch_sizes = [args.batch] if args.batch > 0 else [64, 128, 768] + seq_lens = [args.seq] if args.seq > 0 else [1024, 2048, 8192] + + for dtype in dtypes: + print( + f"\n=== Warming up JIT ({('fp16' if dtype == torch.float16 else 'bf16')}) ===" + ) + _warmup_jit(dtype) + print("=== Benchmarking ===") + for seq_len in seq_lens: + for batch_size in batch_sizes: + bench(batch_size, seq_len, args.heads, dtype) + + +if __name__ == "__main__": + main() diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 0d7cca81ff..b800a4c9e6 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -199,6 +199,7 @@ from .norm import layernorm as layernorm from .norm import rmsnorm as rmsnorm from .page import append_paged_kv_cache as append_paged_kv_cache + from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache from .page import get_batch_indices_positions as get_batch_indices_positions from .page import get_seq_lens as get_seq_lens from .prefill_rocm import ( # type: ignore[assignment] diff --git a/flashinfer/aot_hip.py b/flashinfer/aot_hip.py index 5912977f27..58b801cede 100644 --- a/flashinfer/aot_hip.py +++ b/flashinfer/aot_hip.py @@ -95,6 +95,8 @@ def gen_attention( use_sliding_window_: List[bool], use_logits_soft_cap_: List[bool], ) -> Iterator: + from .jit.attention import gen_batch_mla_module + # FA2 MHA / MQA / GQA for ( (head_dim_qk, head_dim_vo), @@ -118,6 +120,19 @@ def gen_attention( use_logits_soft_cap=use_logits_soft_cap, ) + # MLA (DeepSeek shapes: head_dim_ckv=512, head_dim_kpe=64). + for dtype in f16_dtype_: + yield gen_batch_mla_module( + backend="hip", + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + dtype_idx=torch.int32, + head_dim_ckv=512, + head_dim_kpe=64, + use_profiler=False, + ) + def gen_all_modules( f16_dtype_: List[torch.dtype], diff --git a/flashinfer/csrc_rocm/batch_mla.cu b/flashinfer/csrc_rocm/batch_mla.cu new file mode 100644 index 0000000000..69bc8e8d0d --- /dev/null +++ b/flashinfer/csrc_rocm/batch_mla.cu @@ -0,0 +1,125 @@ +// SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 +#include + +#include + +#include "batch_mla_config.inc" +#include "pytorch_conversion_utils.h" +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + at::Tensor pin_memory_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, + at::Tensor kv_len_arr, int64_t num_heads, int64_t head_dim_o, + bool causal) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + MLAPlanInfo plan_info; + int64_t batch_size = kv_len_arr.size(0); + + const c10::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(float_workspace_buffer.device()); + const hipStream_t stream = c10::hip::getCurrentHIPStream(); + + hipError_t status = + MLAPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), pin_memory_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, static_cast(qo_indptr.data_ptr()), + static_cast(kv_indptr.data_ptr()), + static_cast(kv_len_arr.data_ptr()), static_cast(batch_size), + static_cast(num_heads), static_cast(head_dim_o), causal, stream); + + TORCH_CHECK(status == hipSuccess, + "BatchMLAPagedAttentionPlan failed with error: ", hipGetErrorString(status)); + + return vec_to_tensor(plan_info.ToVector()); +} + +void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe, + at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices, + at::Tensor o, std::optional maybe_lse, + int64_t mask_mode_code, int64_t num_heads, int64_t page_size, + double sm_scale, bool return_lse_base_on_e ADDITIONAL_FUNC_PARAMS) { + MLAPlanInfo plan_info; + plan_info.FromVector(tensor_to_vec(plan_info_vec)); + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + const MaskMode mask_mode = static_cast(mask_mode_code); + + const c10::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(q_nope.device()); + const hipStream_t stream = c10::hip::getCurrentHIPStream(); + + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] { + Params params; + + params.q_nope = static_cast(q_nope.data_ptr()); + params.q_pe = static_cast(q_pe.data_ptr()); + params.ckv = static_cast(ckv_cache.data_ptr()); + params.kpe = static_cast(kpe_cache.data_ptr()); + + params.q_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.partial_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.partial_indptr_offset); + params.kv_indices = static_cast(kv_indices.data_ptr()); + params.q_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_len_offset); + params.kv_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.q_start = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_start_offset); + params.kv_start = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_start_offset); + params.kv_end = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_end_offset); + params.work_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + params.merge_packed_offset_start = GetPtrFromBaseOffset( + int_buffer_ptr, plan_info.merge_packed_offset_start_offset); + params.merge_packed_offset_end = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_packed_offset_end_offset); + params.merge_partial_packed_offset_start = GetPtrFromBaseOffset( + int_buffer_ptr, plan_info.merge_partial_packed_offset_start_offset); + params.merge_partial_packed_offset_end = GetPtrFromBaseOffset( + int_buffer_ptr, plan_info.merge_partial_packed_offset_end_offset); + params.merge_partial_stride = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_partial_stride_offset); + + params.final_o = static_cast(o.data_ptr()); + params.final_lse = + maybe_lse.has_value() ? static_cast(maybe_lse.value().data_ptr()) : nullptr; + params.partial_o = + GetPtrFromBaseOffset(float_buffer_ptr, plan_info.partial_o_offset); + params.partial_lse = + GetPtrFromBaseOffset(float_buffer_ptr, plan_info.partial_lse_offset); + + params.num_heads = uint_fastdiv(num_heads); + params.block_size = uint_fastdiv(page_size); + + params.q_nope_stride_n = q_nope.stride(0); + params.q_nope_stride_h = q_nope.stride(1); + params.q_pe_stride_n = q_pe.stride(0); + params.q_pe_stride_h = q_pe.stride(1); + params.ckv_stride_page = ckv_cache.stride(0); + params.ckv_stride_n = ckv_cache.stride(1); + params.kpe_stride_page = kpe_cache.stride(0); + params.kpe_stride_n = kpe_cache.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + + params.sm_scale = static_cast(sm_scale); + params.return_lse_base_on_e = return_lse_base_on_e; + + ADDITIONAL_PARAMS_SETTER + + hipError_t status = mla::BatchMLAPagedAttentionHIP( + params, plan_info.num_blks_x, plan_info.num_blks_y, stream); + TORCH_CHECK(status == hipSuccess, + "BatchMLAPagedAttentionRun failed: ", hipGetErrorString(status)); + }); +} diff --git a/flashinfer/csrc_rocm/batch_mla_customize_config.jinja b/flashinfer/csrc_rocm/batch_mla_customize_config.jinja new file mode 100644 index 0000000000..c654a2f1a7 --- /dev/null +++ b/flashinfer/csrc_rocm/batch_mla_customize_config.jinja @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// NOTE: This file is generated from batch_mla_customize_config.jinja. +// Do not edit manually. +#pragma once +#include + +using namespace flashinfer; + +#ifdef FLASHINFER_ENABLE_PROFILER +#define ADDITIONAL_FUNC_PARAMS , at::Tensor profiler_buffer +#define ADDITIONAL_PARAMS_SETTER \ + params.profiler_buffer = static_cast(profiler_buffer.data_ptr()); +#else +#define ADDITIONAL_FUNC_PARAMS +#define ADDITIONAL_PARAMS_SETTER +#endif + +using DTypeQ = {{ dtype_q }}; +using DTypeKV = {{ dtype_kv }}; +using DTypeO = {{ dtype_o }}; +using IdType = {{ dtype_idx }}; +constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }}; +constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }}; + +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \ + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ + using Params = MLAParams; \ + __VA_ARGS__(); \ + }) diff --git a/flashinfer/csrc_rocm/batch_mla_jit_pybind.cu b/flashinfer/csrc_rocm/batch_mla_jit_pybind.cu new file mode 100644 index 0000000000..dc8c1d1302 --- /dev/null +++ b/flashinfer/csrc_rocm/batch_mla_jit_pybind.cu @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 +#include "batch_mla_config.inc" +#include "pytorch_extension_utils.h" + +at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + at::Tensor pin_memory_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, + at::Tensor kv_len_arr, int64_t num_heads, int64_t head_dim_o, + bool causal); + +void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe, + at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices, + at::Tensor o, std::optional maybe_lse, + int64_t mask_mode_code, int64_t num_heads, int64_t page_size, + double sm_scale, bool return_lse_base_on_e ADDITIONAL_FUNC_PARAMS); + +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + m.def("plan", BatchMLAPagedAttentionPlan); + m.def("run", BatchMLAPagedAttentionRun); +} diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 3be99e7bc0..97db277e0a 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -108,6 +108,8 @@ def get_cudnn_fmha_gen_module(): from .activation import gen_act_and_mul_module as gen_act_and_mul_module from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str from .attention import gen_batch_decode_module as gen_batch_decode_module + from .attention import gen_batch_mla_module as gen_batch_mla_module + from .attention import get_batch_mla_uri as get_batch_mla_uri from .attention import gen_batch_prefill_module as gen_batch_prefill_module from .attention import ( gen_customize_batch_decode_module as gen_customize_batch_decode_module, diff --git a/flashinfer/jit/attention/__init__.py b/flashinfer/jit/attention/__init__.py index b749d5d012..9ffcf219a9 100644 --- a/flashinfer/jit/attention/__init__.py +++ b/flashinfer/jit/attention/__init__.py @@ -72,6 +72,8 @@ ) from .modules_hip import gen_single_decode_module as gen_single_decode_module from .modules_hip import gen_single_prefill_module as gen_single_prefill_module + from .modules_hip import gen_batch_mla_module as gen_batch_mla_module + from .modules_hip import get_batch_mla_uri as get_batch_mla_uri from .modules_hip import get_batch_decode_uri as get_batch_decode_uri from .modules_hip import get_batch_prefill_uri as get_batch_prefill_uri from .modules_hip import get_single_decode_uri as get_single_decode_uri diff --git a/flashinfer/jit/attention/modules_hip.py b/flashinfer/jit/attention/modules_hip.py index 28df48ef01..9040c283ee 100644 --- a/flashinfer/jit/attention/modules_hip.py +++ b/flashinfer/jit/attention/modules_hip.py @@ -943,3 +943,84 @@ def gen_customize_batch_prefill_module( raise ValueError("FA3 backend not currently supported for ROCm") else: raise ValueError(f"Invalid backend: {backend}") + + +def get_batch_mla_uri( + backend: str, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim_ckv: int, + head_dim_kpe: int, + use_profiler: bool, +) -> str: + return ( + f"batch_mla_attention_dtype_q_{filename_safe_dtype_map[dtype_q]}_" + f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" + f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" + f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" + f"head_dim_ckv_{head_dim_ckv}_" + f"head_dim_kpe_{head_dim_kpe}_" + f"profiler_{use_profiler}_{backend}" + ) + + +def gen_batch_mla_module( + backend: str, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim_ckv: int, + head_dim_kpe: int, + use_profiler: bool, +) -> JitSpec: + if backend != "hip": + raise ValueError(f"MLA only supports backend='hip', got {backend!r}") + if use_profiler: + raise ValueError( + "use_profiler is not supported on HIP/ROCm — profiler.cuh uses NVIDIA PTX asm" + ) + uri = get_batch_mla_uri( + backend, + dtype_q, + dtype_kv, + dtype_o, + dtype_idx, + head_dim_ckv, + head_dim_kpe, + use_profiler, + ) + gen_directory = FLASHINFER_GEN_SRC_DIR / uri + os.makedirs(gen_directory, exist_ok=True) + + with open(FLASHINFER_CSRC_DIR / "batch_mla_customize_config.jinja") as f: + config_templ = jinja2.Template(f.read()) + + generated_inc_str = config_templ.render( + dtype_q=dtype_map_hip[dtype_q], + dtype_kv=dtype_map_hip[dtype_kv], + dtype_o=dtype_map_hip[dtype_o], + dtype_idx=dtype_map_hip[dtype_idx], + head_dim_ckv=head_dim_ckv, + head_dim_kpe=head_dim_kpe, + ) + + source_paths = [] + for filename in [ + "batch_mla.cu", + "batch_mla_jit_pybind.cu", + ]: + src_path = FLASHINFER_CSRC_DIR / filename + dest_path = gen_directory / filename + source_paths.append(dest_path) + with open(src_path, "r") as f: + source = f.read() + write_if_different(dest_path, source) + + generated_config_path = gen_directory / "batch_mla_config.inc" + write_if_different(generated_config_path, generated_inc_str) + + extra_cuda_cflags: list[str] = [] + return gen_jit_spec(uri, source_paths, extra_cuda_cflags=extra_cuda_cflags) diff --git a/flashinfer/mla.py b/flashinfer/mla.py index da57d94e6b..0669d66d7e 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -19,10 +19,13 @@ import torch +from .device_utils import IS_HIP from .jit import gen_batch_mla_module -from .jit.mla import gen_mla_module from .utils import MaskMode, check_shape_dtype_device, determine_mla_backend +if not IS_HIP: + from .jit.mla import gen_mla_module + def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table): if q_nope_pe.ndim != 3: @@ -55,6 +58,8 @@ def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table): @functools.cache def get_mla_module(): + if IS_HIP: + raise RuntimeError("Cutlass MLA backend is not supported on HIP/ROCm") return gen_mla_module().build_and_load() diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 47d725c5d3..dbac6d4663 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1959,6 +1959,7 @@ def run( return_lse: Literal[False] = False, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, + partial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: ... @overload @@ -1974,6 +1975,7 @@ def run( return_lse: Literal[True] = True, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, + partial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... def run( @@ -1990,6 +1992,7 @@ def run( enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, sinks: Optional[torch.Tensor] = None, + partial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch prefill/append attention between query and paged kv-cache. diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 6f9fe7e8b9..7e038c9ee0 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -28,6 +28,7 @@ from torch.torch_version import __version__ as torch_version import inspect +from .device_utils import IS_HIP from .jit.spdlog import gen_spdlog_module @@ -640,6 +641,8 @@ def is_sm121a_supported(device: torch.device) -> bool: def determine_mla_backend(device: torch.device) -> str: + if IS_HIP: + return "hip" return "fa3" if is_sm90a_supported(device) else "fa2" diff --git a/include/flashinfer/attention/mla_hip.cuh b/include/flashinfer/attention/mla_hip.cuh new file mode 100644 index 0000000000..ac3806cb0b --- /dev/null +++ b/include/flashinfer/attention/mla_hip.cuh @@ -0,0 +1,655 @@ +// SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 +// +// HIP MLA kernel for CDNA3 (MI300X). +// +// CTA config: dim3(64,4,1) — 4 wavefronts, 256 threads. +// Wave 0 computes QK (q_pe·kpe^T + q_nope·ckv^T) and softmax, then broadcasts +// the softmax weights (p_frag) and updated m/d scalars via LDS to waves 1..3. +// All 4 waves then compute their PV head_dim shard in parallel. +// Wave w owns head_dim shard [w*(HEAD_DIM_CKV/4) .. (w+1)*(HEAD_DIM_CKV/4)). +// Q is loaded once into registers per work item (loop-invariant across KV tiles). +// +// MFMA layout note: CDNA3 mfma_f32_16x16x16f16 produces D in column-major +// per-thread layout (each thread holds D[m=(t/16)*4+r][n=t%16] for r=0..3). +// To get a more natural row-major s_frag/o_frag (each thread holds 1 row, 4 +// cols), we swap the A and B operands in both MFMAs: +// QK: A=K, B=Q → D[m=kv][n=q] = S[q][kv]; thread t's r-th reg = S[q=t%16][kv=col_group*4+r] +// PV: A=V, B=P → D[m=d][n=q] = O[q][d]; thread t's r-th reg = +// O[q=t%16][d=mma_d_abs*16+col_group*4+r] +// This avoids per-iteration in-register transposes of the s_frag tile. +#ifndef FLASHINFER_MLA_HIP_CUH_ +#define FLASHINFER_MLA_HIP_CUH_ + +#include + +#include "../fastdiv.cuh" +#include "gpu_iface/cooperative_groups.h" +#include "gpu_iface/enums.hpp" +#include "gpu_iface/math_ops.hpp" +#include "gpu_iface/mma_ops.hpp" +#include "gpu_iface/mma_types.hpp" +#include "gpu_iface/utils.cuh" +#include "mla_params.cuh" + +namespace flashinfer { +namespace mla { + +constexpr uint32_t MLA_HIP_WAVE_SIZE = 64; +constexpr uint32_t MLA_HIP_NUM_WAVES = 4; +constexpr uint32_t MLA_HIP_NUM_THREADS = MLA_HIP_WAVE_SIZE * MLA_HIP_NUM_WAVES; +constexpr uint32_t MLA_HIP_CTA_TILE_Q = 16; +constexpr uint32_t MLA_HIP_CTA_TILE_KV = 16; +// Sentinel for the running max in online softmax. Matches the CUDA path +// (state_t::init uses -math::inf, which has the same numeric value 5e4f). +constexpr float MLA_HIP_NEG_INF = -5e4f; + +// Two KV tiles in LDS to support double-buffered load/compute pipelining. +// CKV rows are padded by 8 fp16 (one uint4) so that the stride-row reads in +// compute_pv_hip don't collide on the same LDS bank — without padding all 4 +// strided reads from one thread land on bank 0 (HEAD_DIM_CKV*2/4 % 32 == 0), +// giving a 4-way bank conflict per thread. With +8 fp16 padding the per-thread +// reads spread across banks {b, b+8, b+16, b+24}. +constexpr uint32_t MLA_HIP_CKV_LDS_PAD = 8; +constexpr uint32_t MLA_HIP_NUM_STAGES = 2; + +// Wave-specialization LDS stage: wave 0 writes softmax weights and online-softmax +// scalars here after QK; waves 1..3 read them to skip QK MFMAs entirely. +// p_frag row stride is 17 (= CTA_TILE_KV + 1) to reduce LDS bank conflicts: +// thread (q_row=a, col_group=b) maps to float offset a*17 + b*4, giving a more +// uniform bank distribution than stride=16 which produces 8-way conflicts. +constexpr uint32_t MLA_HIP_P_FRAG_STRIDE = MLA_HIP_CTA_TILE_KV + 1; +struct WaveSpecStageMLAHIP { + float p_frag[MLA_HIP_CTA_TILE_Q][MLA_HIP_P_FRAG_STRIDE]; + float m_stage[MLA_HIP_CTA_TILE_Q]; + float o_scale_stage[MLA_HIP_CTA_TILE_Q]; + float d_stage[MLA_HIP_CTA_TILE_Q]; +}; + +template +struct SharedStorageMLAHIP { + DTypeKV ckv_smem[MLA_HIP_NUM_STAGES][MLA_HIP_CTA_TILE_KV][HEAD_DIM_CKV + MLA_HIP_CKV_LDS_PAD]; + DTypeKV kpe_smem[MLA_HIP_NUM_STAGES][MLA_HIP_CTA_TILE_KV][HEAD_DIM_KPE]; + WaveSpecStageMLAHIP wave_spec; +}; + +template +__device__ __forceinline__ void load_kv_hip( + DTypeKV* __restrict__ ckv_smem, DTypeKV* __restrict__ kpe_smem, + const DTypeKV* __restrict__ ckv_global, const DTypeKV* __restrict__ kpe_global, + const IdType* __restrict__ kv_indices, uint32_t ckv_stride_page, uint32_t ckv_stride_n, + uint32_t kpe_stride_page, uint32_t kpe_stride_n, uint32_t packed_kv_tile_base, + uint32_t packed_kv_bound, const uint_fastdiv& block_size, uint32_t tid) { + // 128-bit (uint4 = 8 fp16) loads to maximize per-instruction DRAM throughput. + // ckv_smem rows are padded by MLA_HIP_CKV_LDS_PAD fp16 (= 1 uint4); load by + // (kv_row, col_u4) to handle the LDS stride correctly. + constexpr uint32_t CKV_U4_PER_ROW = HEAD_DIM_CKV / 8; + constexpr uint32_t CKV_LDS_U4_STRIDE = (HEAD_DIM_CKV + MLA_HIP_CKV_LDS_PAD) / 8; + constexpr uint32_t CKV_TOTAL_U4 = MLA_HIP_CTA_TILE_KV * CKV_U4_PER_ROW; + constexpr uint32_t CKV_U4_PER_THREAD = CKV_TOTAL_U4 / MLA_HIP_NUM_THREADS; + static_assert(CKV_TOTAL_U4 % MLA_HIP_NUM_THREADS == 0, "CKV load not evenly divisible"); + + uint4* smem_ckv = reinterpret_cast(ckv_smem); +#pragma unroll + for (uint32_t i = 0; i < CKV_U4_PER_THREAD; ++i) { + uint32_t u4_idx = tid * CKV_U4_PER_THREAD + i; + uint32_t kv_row = u4_idx / CKV_U4_PER_ROW; + uint32_t col_u4 = u4_idx % CKV_U4_PER_ROW; + uint32_t packed = packed_kv_tile_base + kv_row; + uint32_t page_idx, row_in_page; + block_size.divmod(packed, page_idx, row_in_page); + bool valid = (packed < packed_kv_bound); + const DTypeKV* gptr = ckv_global + + (valid ? kv_indices[page_idx] : IdType(0)) * ckv_stride_page + + row_in_page * ckv_stride_n + col_u4 * 8; + smem_ckv[kv_row * CKV_LDS_U4_STRIDE + col_u4] = + valid ? *reinterpret_cast(gptr) : uint4{0u, 0u, 0u, 0u}; + } + + constexpr uint32_t KPE_U4_PER_ROW = HEAD_DIM_KPE / 8; + constexpr uint32_t KPE_TOTAL_U4 = MLA_HIP_CTA_TILE_KV * KPE_U4_PER_ROW; + constexpr uint32_t KPE_U4_PER_THREAD = KPE_TOTAL_U4 / MLA_HIP_NUM_THREADS; + // KPE may be too small for one uint4 per thread; fall back to a per-thread loop only if work + // exists. For HEAD_DIM_KPE=64, CTA_TILE_KV=16: total=128 uint4, per-thread=0 (only 128 of 256 + // threads load). Handle as "first 128 threads do 1 load each". + uint4* smem_kpe = reinterpret_cast(kpe_smem); + if constexpr (KPE_U4_PER_THREAD >= 1) { +#pragma unroll + for (uint32_t i = 0; i < KPE_U4_PER_THREAD; ++i) { + uint32_t u4_idx = tid * KPE_U4_PER_THREAD + i; + uint32_t kv_row = u4_idx / KPE_U4_PER_ROW; + uint32_t col_u4 = u4_idx % KPE_U4_PER_ROW; + uint32_t packed = packed_kv_tile_base + kv_row; + uint32_t page_idx, row_in_page; + block_size.divmod(packed, page_idx, row_in_page); + bool valid = (packed < packed_kv_bound); + const DTypeKV* gptr = kpe_global + + (valid ? kv_indices[page_idx] : IdType(0)) * kpe_stride_page + + row_in_page * kpe_stride_n + col_u4 * 8; + smem_kpe[u4_idx] = valid ? *reinterpret_cast(gptr) : uint4{0u, 0u, 0u, 0u}; + } + } else { + // Less than 1 uint4 per thread: only the first KPE_TOTAL_U4 threads do a load. + if (tid < KPE_TOTAL_U4) { + uint32_t kv_row = tid / KPE_U4_PER_ROW; + uint32_t col_u4 = tid % KPE_U4_PER_ROW; + uint32_t packed = packed_kv_tile_base + kv_row; + uint32_t page_idx, row_in_page; + block_size.divmod(packed, page_idx, row_in_page); + bool valid = (packed < packed_kv_bound); + const DTypeKV* gptr = kpe_global + + (valid ? kv_indices[page_idx] : IdType(0)) * kpe_stride_page + + row_in_page * kpe_stride_n + col_u4 * 8; + smem_kpe[tid] = valid ? *reinterpret_cast(gptr) : uint4{0u, 0u, 0u, 0u}; + } + } +} + +// Q is loop-invariant across the kv_tile loop; caching saves (num_kv_tiles - 1) reloads. +// q_pe_frag[mma_d][2] holds Q_pe[q=t%16][d=mma_d*16 + col_group*4 + 0..3] as 4 fp16. +// q_nope_frag is the same for Q_nope across NUM_MMA_D_CKV head_dim tiles. +// Out-of-range threads (batch_idx >= q_len) get zero-filled fragments. +template +__device__ __forceinline__ void load_q_frags_hip(uint32_t (*q_pe_frag)[2], + uint32_t (*q_nope_frag)[2], const DTypeQ* q_nope, + const DTypeQ* q_pe, uint32_t q_nope_stride_n, + uint32_t q_nope_stride_h, uint32_t q_pe_stride_n, + uint32_t q_pe_stride_h, uint32_t batch_idx, + uint32_t head_idx, bool q_valid, + uint32_t col_group) { + constexpr uint32_t NUM_MMA_D_CKV = HEAD_DIM_CKV / 16; + constexpr uint32_t NUM_MMA_D_KPE = HEAD_DIM_KPE / 16; + +#pragma unroll + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_KPE; ++mma_d) { + q_pe_frag[mma_d][0] = 0u; + q_pe_frag[mma_d][1] = 0u; + if (q_valid) { + const DTypeQ* ptr = + q_pe + batch_idx * q_pe_stride_n + head_idx * q_pe_stride_h + mma_d * 16 + col_group * 4; + *reinterpret_cast(q_pe_frag[mma_d]) = *reinterpret_cast(ptr); + } + } +#pragma unroll + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV; ++mma_d) { + q_nope_frag[mma_d][0] = 0u; + q_nope_frag[mma_d][1] = 0u; + if (q_valid) { + const DTypeQ* ptr = q_nope + batch_idx * q_nope_stride_n + head_idx * q_nope_stride_h + + mma_d * 16 + col_group * 4; + *reinterpret_cast(q_nope_frag[mma_d]) = *reinterpret_cast(ptr); + } + } +} + +// CDNA3 MFMA computes D = A*B where the D output is column-major per-thread: +// thread t holds D[m=(t/16)*4+r][n=t%16]. To get the natural row-major s_frag +// layout (thread t holds s_frag[r] = S[q=t%16][kv=col_group*4+r]) we swap the +// A and B operands: A=K, B=Q. The math: +// D[m=kv][n=q] = sum_d K[kv=m][d=k] * Q[q=n][d=k] = (K@Q^T)[kv][q] = S[q][kv]. +// Combined with the column-major D layout, thread t's r-th register = +// S[q=t%16][kv=(t/16)*4+r]. +template +__device__ __forceinline__ void compute_qk_hip(float* s_frag, const uint32_t (*q_pe_frag)[2], + const uint32_t (*q_nope_frag)[2], + const DTypeKV* ckv_smem, const DTypeKV* kpe_smem, + uint32_t q_row, uint32_t col_group) { + constexpr uint32_t NUM_MMA_D_CKV = HEAD_DIM_CKV / 16; + constexpr uint32_t NUM_MMA_D_KPE = HEAD_DIM_KPE / 16; + constexpr uint32_t CKV_LDS_U2_STRIDE = (HEAD_DIM_CKV + MLA_HIP_CKV_LDS_PAD) / 4; + constexpr uint32_t KPE_U2_PER_ROW = HEAD_DIM_KPE / 4; + + const uint2* smem_ckv = reinterpret_cast(ckv_smem); + const uint2* smem_kpe = reinterpret_cast(kpe_smem); + + // KPE tiles run first — kInit on tile 0 zeros the accumulator before CKV accumulates. +#pragma unroll + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_KPE; ++mma_d) { + uint32_t k_frag[2]; + *reinterpret_cast(k_frag) = smem_kpe[q_row * KPE_U2_PER_ROW + mma_d * 4 + col_group]; + if (mma_d == 0) { + gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32( + s_frag, k_frag, const_cast(q_pe_frag[mma_d])); + } else { + gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32( + s_frag, k_frag, const_cast(q_pe_frag[mma_d])); + } + } + +#pragma unroll + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV; ++mma_d) { + uint32_t k_frag[2]; + *reinterpret_cast(k_frag) = smem_ckv[q_row * CKV_LDS_U2_STRIDE + mma_d * 4 + col_group]; + gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32( + s_frag, k_frag, const_cast(q_nope_frag[mma_d])); + } +} + +// s_frag[r] = S[q_row=lane%16][kv_col = kv_idx_base + (lane/16)*4 + r] +template +__device__ __forceinline__ void logits_mask_hip(float* s_frag, uint32_t qo_packed_idx_base, + uint32_t kv_idx_base, uint32_t q_len, + uint32_t kv_len, uint32_t kv_end, + const uint_fastdiv& num_heads, uint32_t lane_idx) { + uint32_t q_idx = (qo_packed_idx_base + lane_idx % 16) / num_heads; + uint32_t col_group = lane_idx / 16; + +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) { + uint32_t kv_idx = kv_idx_base + col_group * 4 + r; + bool out_of_bound = (kv_idx >= kv_end); + bool causal_masked = CAUSAL && (kv_idx + q_len > kv_len + q_idx); + if (out_of_bound || causal_masked) s_frag[r] = MLA_HIP_NEG_INF; + } +} + +template +__device__ __forceinline__ void broadcast_softmax_hip(float* s_frag, float (*o_frag)[4], + float& m_val, float& d_scalar, + WaveSpecStageMLAHIP& stage, uint32_t q_row, + uint32_t col_group, float sm_scale_log2) { + float m_local = fmaxf(fmaxf(s_frag[0], s_frag[1]), fmaxf(s_frag[2], s_frag[3])); + m_local = fmaxf(m_local, math::shfl_xor_sync(m_local, 0x10)); + m_local = fmaxf(m_local, math::shfl_xor_sync(m_local, 0x20)); + + float m_prev = m_val; + m_val = fmaxf(m_prev, m_local); + float o_scale = math::ptx_exp2((m_prev - m_val) * sm_scale_log2); + d_scalar *= o_scale; + +#pragma unroll + for (uint32_t d = 0; d < NUM_MMA_D_PER_WAVE; ++d) +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) o_frag[d][r] *= o_scale; + + const float m_scaled = m_val * sm_scale_log2; + float partial_d = 0.f; +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) { + float p = math::ptx_exp2(s_frag[r] * sm_scale_log2 - m_scaled); + s_frag[r] = p; + partial_d += p; + } + partial_d += math::shfl_xor_sync(partial_d, 0x10); + partial_d += math::shfl_xor_sync(partial_d, 0x20); + d_scalar += partial_d; + + // Broadcast p_frag and per-q_row scalars to LDS for waves 1..3. +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) stage.p_frag[q_row][col_group * 4 + r] = s_frag[r]; + if (col_group == 0) { + stage.m_stage[q_row] = m_val; + stage.o_scale_stage[q_row] = o_scale; + stage.d_stage[q_row] = d_scalar; + } +} + +template +__device__ __forceinline__ void read_softmax_hip(float* s_frag, float (*o_frag)[4], float& m_val, + float& d_scalar, const WaveSpecStageMLAHIP& stage, + uint32_t q_row, uint32_t col_group) { +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) s_frag[r] = stage.p_frag[q_row][col_group * 4 + r]; + float o_scale = stage.o_scale_stage[q_row]; +#pragma unroll + for (uint32_t d = 0; d < NUM_MMA_D_PER_WAVE; ++d) +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) o_frag[d][r] *= o_scale; + m_val = stage.m_stage[q_row]; + d_scalar = stage.d_stage[q_row]; +} + +// As with compute_qk_hip, we swap A/B so the column-major D output gives a +// row-major o_frag: thread t's r-th register = O[q=t%16][d=mma_d_abs*16+col_group*4+r]. +// We pass A=V (loaded strided so f16x4[j]=V[kv=col_group*4+j][d=mma_d_abs*16+t%16]) +// and B=P (s_frag, with f16x4[r]=P[q=t%16][kv=col_group*4+r]). +// The math: D[m=d_local][n=q] = sum_kv V[kv][d] * P[q][kv] = O[q][d_global]. +template +__device__ __forceinline__ void compute_pv_hip(float (*o_frag)[4], const float* s_frag, + const DTypeKV* ckv_smem, uint32_t wave_idx, + uint32_t lane_idx) { + uint32_t q_row = lane_idx % 16; + uint32_t col_group = lane_idx / 16; + uint32_t d_start = wave_idx * NUM_MMA_D_PER_WAVE; + + DTypeKV p_f16[4]; +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) p_f16[r] = static_cast(s_frag[r]); + uint32_t p_frag[2]; + *reinterpret_cast(p_frag) = *reinterpret_cast(p_f16); + +#pragma unroll + for (uint32_t wave_mma_d = 0; wave_mma_d < NUM_MMA_D_PER_WAVE; ++wave_mma_d) { + uint32_t mma_d_abs = d_start + wave_mma_d; + // Load V via 4 strided scalar reads: V[kv=col_group*4+j][d=mma_d_abs*16+q_row]. + // Used as A operand: A[m=t%16][k=col_group*4+j] = V[kv=k][d=mma_d_abs*16+m]. + uint32_t d_col = mma_d_abs * 16 + q_row; + DTypeKV v_vals[4]; +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { + v_vals[j] = ckv_smem[(col_group * 4 + j) * (HEAD_DIM_CKV + MLA_HIP_CKV_LDS_PAD) + d_col]; + } + uint32_t v_frag[2]; + *reinterpret_cast(v_frag) = *reinterpret_cast(v_vals); + gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32(o_frag[wave_mma_d], v_frag, + p_frag); + } +} + +template +__device__ __forceinline__ void normalize_d_hip(float (*o_frag)[4], float m_val, float d_scalar) { + float d_rcp = (m_val != MLA_HIP_NEG_INF) ? math::ptx_rcp(d_scalar) : 0.f; +#pragma unroll + for (uint32_t d = 0; d < NUM_MMA_D_PER_WAVE; ++d) { +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) o_frag[d][r] *= d_rcp; + } +} + +// Each thread writes for q_row = lane_idx%16 and its wave's head_dim shard. +// Only wave_idx==0 && col_group==0 (lane_idx<16) writes the LSE scalar. +// partial_o / partial_lse are the GLOBAL arrays; partial_indptr = -1 when single-tile. +template +__device__ __forceinline__ void write_o_hip( + float (*o_frag)[4], float m_val, float d_scalar, DTypeO* final_o, float* final_lse, + DTypeO* partial_o, float* partial_lse, uint32_t o_stride_n, uint32_t o_stride_h, uint32_t q_len, + uint32_t qo_packed_idx_base, int32_t partial_indptr, const uint_fastdiv& num_heads, + bool return_lse_base_on_e, uint32_t lane_idx, uint32_t wave_idx) { + uint32_t q_row = lane_idx % 16; + uint32_t col_group = lane_idx / 16; + uint32_t q_packed = qo_packed_idx_base + q_row; + uint32_t batch_idx, head_idx; + num_heads.divmod(q_packed, batch_idx, head_idx); + bool q_valid = (batch_idx < q_len); + uint32_t d_start = wave_idx * NUM_MMA_D_PER_WAVE; + + if (partial_indptr >= 0) { + uint32_t partial_row = static_cast(partial_indptr) + q_row; + if (q_valid) { + // One thread per q_row writes the LSE (lane_idx < 16 means col_group == 0) + if (wave_idx == 0 && col_group == 0) { + partial_lse[partial_row] = math::ptx_log2(d_scalar) + m_val; + } +#pragma unroll + for (uint32_t wave_mma_d = 0; wave_mma_d < NUM_MMA_D_PER_WAVE; ++wave_mma_d) { + uint32_t head_dim_col = (d_start + wave_mma_d) * 16 + col_group * 4; + DTypeO* ptr = partial_o + partial_row * HEAD_DIM_CKV + head_dim_col; +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) ptr[r] = static_cast(o_frag[wave_mma_d][r]); + } + } + } else { + if (q_valid) { + if (final_lse && wave_idx == 0 && col_group == 0) { + float lse = math::ptx_log2(d_scalar) + m_val; + if (return_lse_base_on_e) lse *= math::loge2; + final_lse[batch_idx * (uint32_t)num_heads + head_idx] = lse; + } +#pragma unroll + for (uint32_t wave_mma_d = 0; wave_mma_d < NUM_MMA_D_PER_WAVE; ++wave_mma_d) { + uint32_t head_dim_col = (d_start + wave_mma_d) * 16 + col_group * 4; + DTypeO* ptr = final_o + batch_idx * o_stride_n + head_idx * o_stride_h + head_dim_col; +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) ptr[r] = static_cast(o_frag[wave_mma_d][r]); + } + } + } +} + +// partial_o[row * HEAD_DIM_CKV + col] stores DTypeO (normalized by each partial's d). +// partial_lse[row] = log2(d) + m for each partial block. +// Merge treats partial_lse as the combined log-sum-exp weight for each partial. +template +__device__ void DevicePersistentMergeStatesHIP( + const IdType* merge_packed_offset_start, const IdType* merge_packed_offset_end, + const IdType* merge_partial_packed_offset_start, const IdType* merge_partial_packed_offset_end, + const IdType* merge_partial_stride, const DTypeO* partial_o, const float* partial_lse, + DTypeO* final_o, float* final_lse, uint32_t o_stride_n, uint32_t o_stride_h, + const uint_fastdiv& num_heads, bool return_lse_base_on_e) { + constexpr uint32_t VEC_SIZE = 8; + constexpr uint32_t NUM_THRS_PER_ROW = HEAD_DIM_CKV / VEC_SIZE; + constexpr uint32_t ROWS_PER_ITER = MLA_HIP_NUM_THREADS / NUM_THRS_PER_ROW; + + uint32_t cta_idx = gridDim.x * blockIdx.y + blockIdx.x; + uint32_t thread_id = threadIdx.y * MLA_HIP_WAVE_SIZE + threadIdx.x; + + uint32_t offset_start = merge_packed_offset_start[cta_idx]; + uint32_t len = merge_packed_offset_end[cta_idx] - offset_start; + uint32_t partial_start = merge_partial_packed_offset_start[cta_idx]; + uint32_t partial_end = merge_partial_packed_offset_end[cta_idx]; + uint32_t stride = merge_partial_stride[cta_idx]; + + for (uint32_t local_row = thread_id / NUM_THRS_PER_ROW; local_row < len; + local_row += ROWS_PER_ITER) { + uint32_t global_packed = offset_start + local_row; + uint32_t q, r; + num_heads.divmod(global_packed, q, r); + uint32_t thr_col = thread_id % NUM_THRS_PER_ROW; + + // Accumulate state over partial blocks + // State representation: (st_m, st_d, st_o) where st_m tracks running max of partial_lse, + // st_d accumulates exp-weighted partial sums, st_o is the weighted output sum. + float st_o[VEC_SIZE]; +#pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) st_o[i] = 0.f; + float st_m = MLA_HIP_NEG_INF; + float st_d = 1.f; + + for (uint32_t pp = partial_start + local_row; pp < partial_end; pp += stride) { + float other_lse = partial_lse[pp]; // = log2(d_partial) + m_partial + const DTypeO* src = partial_o + (uint64_t)pp * HEAD_DIM_CKV + thr_col * VEC_SIZE; + + float new_m = fmaxf(st_m, other_lse); + float scale_st = math::ptx_exp2(st_m - new_m); + float scale_other = math::ptx_exp2(other_lse - new_m); + st_d = st_d * scale_st + scale_other; +#pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) + st_o[i] = st_o[i] * scale_st + static_cast(src[i]) * scale_other; + st_m = new_m; + } + + float d_rcp = (st_m != MLA_HIP_NEG_INF) ? math::ptx_rcp(st_d) : 0.f; + DTypeO* dst = final_o + q * o_stride_n + r * o_stride_h + thr_col * VEC_SIZE; +#pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) dst[i] = static_cast(st_o[i] * d_rcp); + + if (final_lse && thr_col == 0) { + float lse = st_m + math::ptx_log2(st_d); + if (return_lse_base_on_e) lse *= math::loge2; + final_lse[q * (uint32_t)num_heads + r] = lse; + } + } +} + +template +__global__ __launch_bounds__(MLA_HIP_NUM_THREADS) void BatchMLAPagedAttentionKernelHIP( + Params params) { + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + + constexpr uint32_t NUM_MMA_D_CKV = HEAD_DIM_CKV / 16; + constexpr uint32_t NUM_MMA_D_KPE = HEAD_DIM_KPE / 16; + constexpr uint32_t NUM_MMA_D_PER_WAVE = NUM_MMA_D_CKV / MLA_HIP_NUM_WAVES; + static_assert(NUM_MMA_D_CKV % MLA_HIP_NUM_WAVES == 0, + "HEAD_DIM_CKV must be divisible by 4 * 16 = 64"); + + uint32_t lane_idx = threadIdx.x; // 0..63 + uint32_t wave_idx = threadIdx.y; // 0..3 + uint32_t tid = wave_idx * MLA_HIP_WAVE_SIZE + lane_idx; // 0..255 + + extern __shared__ uint8_t smem[]; + using SharedStorage = SharedStorageMLAHIP; + SharedStorage& smem_storage = *reinterpret_cast(smem); + + const float sm_scale_log2 = params.sm_scale * math::log2e; + const uint_fastdiv& num_heads = params.num_heads; + const uint_fastdiv& block_size = params.block_size; + const uint32_t block_size_val = static_cast(block_size); + + const uint32_t q_row = lane_idx % 16; + const uint32_t col_group = lane_idx / 16; + + float s_frag[4]; + float o_frag[NUM_MMA_D_PER_WAVE][4]; + float m_val, d_scalar; + uint32_t q_pe_frag[NUM_MMA_D_KPE][2]; + uint32_t q_nope_frag[NUM_MMA_D_CKV][2]; + + for (IdType work_idx = params.work_indptr[blockIdx.y]; + work_idx < params.work_indptr[blockIdx.y + 1]; ++work_idx) { + const uint32_t q_indptr = params.q_indptr[work_idx]; + const uint32_t kv_indptr = params.kv_indptr[work_idx]; + const int32_t partial_indptr = params.partial_indptr[work_idx]; + const uint32_t q_len = params.q_len[work_idx]; + const uint32_t kv_len = params.kv_len[work_idx]; + const uint32_t kv_start = params.kv_start[work_idx]; + const uint32_t kv_end = params.kv_end[work_idx]; + const uint32_t packed_qo_start = params.q_start[work_idx]; + + const uint32_t qo_packed_idx_base = packed_qo_start + blockIdx.x * MLA_HIP_CTA_TILE_Q; + + const uint32_t q_packed = qo_packed_idx_base + q_row; + uint32_t batch_idx, head_idx; + num_heads.divmod(q_packed, batch_idx, head_idx); + const bool q_valid = (batch_idx < q_len); + + load_q_frags_hip( + q_pe_frag, q_nope_frag, params.q_nope + q_indptr * params.q_nope_stride_n, + params.q_pe + q_indptr * params.q_pe_stride_n, params.q_nope_stride_n, + params.q_nope_stride_h, params.q_pe_stride_n, params.q_pe_stride_h, batch_idx, head_idx, + q_valid, col_group); + + m_val = MLA_HIP_NEG_INF; + d_scalar = 1.f; +#pragma unroll + for (uint32_t d = 0; d < NUM_MMA_D_PER_WAVE; ++d) +#pragma unroll + for (uint32_t r = 0; r < 4; ++r) o_frag[d][r] = 0.f; + + const uint32_t packed_kv_bound = kv_indptr * block_size_val + kv_len; + const int32_t num_kv_tiles = + static_cast(ceil_div(kv_end - kv_start, MLA_HIP_CTA_TILE_KV)); + + // Double-buffered pipeline: stage 0 loaded outside the loop, then each + // iteration N loads stage[(N+1)%2] while computing on stage[N%2]. The hope + // is the AMD compiler interleaves the global→VGPR→LDS load chain with the + // MFMAs on the previous tile's data, hiding global-load latency. + auto kv_tile_to_packed_base = [&](int32_t kv_tile_idx) { + return kv_indptr * block_size_val + kv_start + + static_cast(kv_tile_idx) * MLA_HIP_CTA_TILE_KV; + }; + auto kv_tile_to_abs_start = [&](int32_t kv_tile_idx) { + return kv_start + static_cast(kv_tile_idx) * MLA_HIP_CTA_TILE_KV; + }; + + // Prologue: load tile (num_kv_tiles - 1) into stage 0. + { + const int32_t first_tile = num_kv_tiles - 1; + load_kv_hip( + smem_storage.ckv_smem[0][0], smem_storage.kpe_smem[0][0], params.ckv, params.kpe, + params.kv_indices, params.ckv_stride_page, params.ckv_stride_n, params.kpe_stride_page, + params.kpe_stride_n, kv_tile_to_packed_base(first_tile), packed_kv_bound, block_size, + tid); + } + + uint32_t cur_stage = 0; + for (int32_t kv_tile_idx = num_kv_tiles - 1; kv_tile_idx >= 0; --kv_tile_idx) { + const uint32_t next_stage = 1 - cur_stage; + + // Issue load for tile (kv_tile_idx - 1) into next_stage. Without + // cp.async, this still blocks the wave but lets the compiler schedule + // the global VGPR loads ahead of compute. + if (kv_tile_idx > 0) { + load_kv_hip( + smem_storage.ckv_smem[next_stage][0], smem_storage.kpe_smem[next_stage][0], params.ckv, + params.kpe, params.kv_indices, params.ckv_stride_page, params.ckv_stride_n, + params.kpe_stride_page, params.kpe_stride_n, kv_tile_to_packed_base(kv_tile_idx - 1), + packed_kv_bound, block_size, tid); + } + + __syncthreads(); + + if (wave_idx == 0) { + compute_qk_hip( + s_frag, q_pe_frag, q_nope_frag, smem_storage.ckv_smem[cur_stage][0], + smem_storage.kpe_smem[cur_stage][0], q_row, col_group); + logits_mask_hip(s_frag, qo_packed_idx_base, kv_tile_to_abs_start(kv_tile_idx), + q_len, kv_len, kv_end, num_heads, lane_idx); + broadcast_softmax_hip(s_frag, o_frag, m_val, d_scalar, + smem_storage.wave_spec, q_row, col_group, + sm_scale_log2); + } + __syncthreads(); // wave_spec visible to all waves + if (wave_idx > 0) { + read_softmax_hip(s_frag, o_frag, m_val, d_scalar, + smem_storage.wave_spec, q_row, col_group); + } + + compute_pv_hip( + o_frag, s_frag, smem_storage.ckv_smem[cur_stage][0], wave_idx, lane_idx); + + // All threads must finish PV reads from ckv_smem[cur_stage] before any thread + // starts the next iteration's load_kv_hip into next_stage. On CDNA3 the four + // wavefronts are not lockstep across MFMA; without this barrier, faster waves + // can overwrite the buffer (now the upcoming next_stage) while slower waves + // still read it — wrong attention whenever num_kv_tiles >= 3. + __syncthreads(); + + cur_stage = next_stage; + } + + normalize_d_hip(o_frag, m_val, d_scalar); + + // finalize_m: scale the running max into log2 space so that + // lse = log2(d_scaled) + m_scaled = log2(sum exp(s * sm_scale)) + // is a true log-sum-exp in log2 base. The downstream merge relies on this. + float m_scaled = (m_val != MLA_HIP_NEG_INF) ? m_val * sm_scale_log2 : m_val; + + write_o_hip( + o_frag, m_scaled, d_scalar, params.final_o + q_indptr * params.o_stride_n, + params.final_lse ? params.final_lse + q_indptr * (uint32_t)num_heads : nullptr, + params.partial_o, params.partial_lse, params.o_stride_n, params.o_stride_h, q_len, + qo_packed_idx_base, partial_indptr, num_heads, params.return_lse_base_on_e, lane_idx, + wave_idx); + } + + gpu_iface::cg::this_grid().sync(); + + DevicePersistentMergeStatesHIP( + params.merge_packed_offset_start, params.merge_packed_offset_end, + params.merge_partial_packed_offset_start, params.merge_partial_packed_offset_end, + params.merge_partial_stride, params.partial_o, params.partial_lse, params.final_o, + params.final_lse, params.o_stride_n, params.o_stride_h, num_heads, + params.return_lse_base_on_e); +} + +template +hipError_t BatchMLAPagedAttentionHIP(Params params, uint32_t num_blks_x, uint32_t num_blks_y, + hipStream_t stream) { + if (MASK_MODE == MaskMode::kCustom) return hipErrorNotSupported; + constexpr bool CAUSAL = (MASK_MODE == MaskMode::kCausal); + + using SharedStorage = SharedStorageMLAHIP; + size_t smem_size = sizeof(SharedStorage); + + dim3 nblks(num_blks_x, num_blks_y); + dim3 nthrs(MLA_HIP_WAVE_SIZE, MLA_HIP_NUM_WAVES, 1); + + auto kernel = BatchMLAPagedAttentionKernelHIP; + hipError_t err = + hipFuncSetAttribute(reinterpret_cast(kernel), + hipFuncAttributeMaxDynamicSharedMemorySize, static_cast(smem_size)); + if (err != hipSuccess) return err; + + void* args[] = {reinterpret_cast(¶ms)}; + return hipLaunchCooperativeKernel(reinterpret_cast(kernel), nblks, nthrs, args, + smem_size, stream); +} + +} // namespace mla +} // namespace flashinfer + +#endif // FLASHINFER_MLA_HIP_CUH_ diff --git a/include/flashinfer/attention/mla_params.cuh b/include/flashinfer/attention/mla_params.cuh index 6da1ed7f53..9187cfc687 100644 --- a/include/flashinfer/attention/mla_params.cuh +++ b/include/flashinfer/attention/mla_params.cuh @@ -15,7 +15,9 @@ */ #ifndef FLASHINFER_MLA_PARAMS_CUH_ #define FLASHINFER_MLA_PARAMS_CUH_ +#if !defined(__HIPCC__) && !defined(__HIP__) #include +#endif #include "../fastdiv.cuh" #include "../profiler.cuh" diff --git a/include/flashinfer/profiler.cuh b/include/flashinfer/profiler.cuh index 7a4be6ee41..008c2e4989 100644 --- a/include/flashinfer/profiler.cuh +++ b/include/flashinfer/profiler.cuh @@ -16,7 +16,9 @@ #ifndef FLASHINFER_PROFILER_CUH_ #define FLASHINFER_PROFILER_CUH_ +#if !defined(__HIPCC__) && !defined(__HIP__) #include +#endif namespace flashinfer { @@ -55,7 +57,11 @@ __device__ __forceinline__ uint32_t encode_tag(uint32_t sm_id, uint32_t block_id __device__ __forceinline__ uint32_t get_timestamp() { volatile uint32_t ret; +#if !defined(__HIPCC__) && !defined(__HIP__) asm volatile("mov.u32 %0, %globaltimer_lo;" : "=r"(ret)); +#else + ret = 0; +#endif return ret; } diff --git a/tests/rocm_tests/test_deepseek_mla_hip.py b/tests/rocm_tests/test_deepseek_mla_hip.py new file mode 100644 index 0000000000..a1b3f34be8 --- /dev/null +++ b/tests/rocm_tests/test_deepseek_mla_hip.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# DeepSeek-style HIP/ROCm tests for MLA: prefill (qo_len > 1) + decode (qo_len = 1), +# LSE return, pre-allocated output buffers, varying kv_lens per batch, +# fp16 + bf16. Mirrors the core checks in tests/attention/test_deepseek_mla.py +# but uses the HIP backend and the head dimensions DeepSeek actually ships. + +import math + +import pytest +import torch + +import flashinfer.mla +from flashinfer.jit import build_jit_specs, gen_batch_mla_module + +HEAD_DIM_CKV = 512 +HEAD_DIM_KPE = 64 + + +def _attention_ref(batch_size, q, k, v, causal, sm_scale): + qo_len = q.shape[0] // batch_size + kv_len = k.shape[0] // batch_size + num_heads = q.shape[1] + head_dim_qk = q.shape[2] + head_dim_vo = v.shape[2] + logits = ( + torch.einsum( + "bmhd,bnhd->bhmn", + q.view(batch_size, qo_len, num_heads, head_dim_qk).float(), + k.view(batch_size, kv_len, num_heads, head_dim_qk).float(), + ) + * sm_scale + ) + if causal: + mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze( + 1 + ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0) + logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) + lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2) + p = torch.softmax(logits, dim=-1) + o_ref = ( + torch.einsum( + "bhmn,bnhd->bmhd", + p, + v.view(batch_size, kv_len, num_heads, head_dim_vo).float(), + ) + .contiguous() + .view(batch_size * qo_len, num_heads, head_dim_vo) + .to(q) + ) + # convert lse from natural log to log2 to match the kernel's return_lse_base_on_e=False default + return o_ref, lse_ref * math.log2(math.e) + + +def _kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads): + bs_page_num, page_size, ckv_dim = ckv.shape + page_num = bs_page_num // batch_size + _, _, kpe_dim = kpe.shape + ckv = ckv.view(batch_size, page_num * page_size, ckv_dim)[:, :kv_len, :] + kpe = kpe.view(batch_size, page_num * page_size, kpe_dim)[:, :kv_len, :] + k = ( + torch.cat([ckv, kpe], dim=-1) + .unsqueeze(2) + .repeat(1, 1, num_heads, 1) + .view(batch_size * kv_len, num_heads, ckv_dim + kpe_dim) + ) + v = ( + ckv.unsqueeze(2) + .repeat(1, 1, num_heads, 1) + .view(batch_size * kv_len, num_heads, ckv_dim) + ) + return k, v + + +@pytest.fixture(autouse=True, scope="module") +def warmup_jit(): + build_jit_specs( + [ + gen_batch_mla_module( + "hip", + torch.float16, + torch.float16, + torch.float16, + torch.int32, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + False, + ), + gen_batch_mla_module( + "hip", + torch.bfloat16, + torch.bfloat16, + torch.bfloat16, + torch.int32, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + False, + ), + ], + verbose=False, + ) + yield + + +@pytest.mark.parametrize("batch_size", [1, 3, 7]) +@pytest.mark.parametrize( + "kv_len", + [ + 17, + 33, + 96, + 514, + 1024, + pytest.param(4096, marks=pytest.mark.slow), + pytest.param(8192, marks=pytest.mark.slow), + ], +) +@pytest.mark.parametrize("num_heads", [16]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("page_size", [1, 16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_batch_mla_page_attention( + batch_size, kv_len, num_heads, causal, page_size, dtype +): + # The HIP MLA kernel uses CTA_TILE_Q = 16, while the MLA planner assumes + # CTA_TILE_Q = 64. For decode (packed_qo_len = qo_len * num_heads <= 16) + # the planner generates one work item per batch and the CTA covers it exactly. + # Prefill (qo_len > 1) requires an inner Q-sub-tile loop — not yet implemented. + qo_len = 1 + if causal and qo_len > kv_len: + pytest.skip("qo_len > kv_len not supported for causal attention") + device = torch.device("cuda") + torch.manual_seed(42) + + q_nope = torch.randn( + batch_size * qo_len, num_heads, HEAD_DIM_CKV, dtype=dtype, device=device + ) + q_pe = torch.randn( + batch_size * qo_len, num_heads, HEAD_DIM_KPE, dtype=dtype, device=device + ) + pages_per_req = math.ceil(kv_len / page_size) + ckv = torch.randn( + batch_size * pages_per_req, page_size, HEAD_DIM_CKV, dtype=dtype, device=device + ) + kpe = torch.randn( + batch_size * pages_per_req, page_size, HEAD_DIM_KPE, dtype=dtype, device=device + ) + + sm_scale = 1.0 / ((128 + 64) ** 0.5) # head dim before matrix absorption + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace, backend="auto") + + q_indptr = ( + torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) * qo_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) + * pages_per_req + ) + kv_indices = torch.arange( + 0, batch_size * pages_per_req, dtype=torch.int32, device=device + ) + kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device=device) + + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_lens, + num_heads, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + page_size, + causal, + sm_scale, + q_nope.dtype, + ckv.dtype, + ) + o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) + + k, v = _kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads) + q = torch.cat([q_nope, q_pe], dim=-1) + o_ref, lse_ref = _attention_ref(batch_size, q, k, v, causal, sm_scale) + lse_ref = lse_ref.flatten(0, 1) + + rtol, atol = (1e-3, 1e-3) if dtype == torch.float16 else (1e-2, 1e-2) + torch.testing.assert_close(o, o_ref, rtol=rtol, atol=atol) + torch.testing.assert_close(lse, lse_ref, rtol=rtol, atol=atol) + + # Pre-allocated output buffers must produce identical results. + o_buf = torch.empty_like(o) + lse_buf = torch.empty_like(lse) + wrapper.run(q_nope, q_pe, ckv, kpe, out=o_buf, lse=lse_buf) + torch.testing.assert_close(o, o_buf, rtol=rtol, atol=atol) + torch.testing.assert_close(lse, lse_buf, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("batch_size", [3, 5]) +@pytest.mark.parametrize( + "kv_lens_list", + [ + [17, 33, 79], + pytest.param([96, 514, 2048], marks=pytest.mark.slow), + pytest.param([128, 256, 512, 1024, 2048], marks=pytest.mark.slow), + ], + ids=lambda v: "x".join(str(x) for x in v), +) +@pytest.mark.parametrize("causal", [False, True]) +def test_batch_mla_varlen_kv(batch_size, kv_lens_list, causal): + """Each request in the batch has a different kv_len (still page_size=1).""" + qo_len = 1 + if causal and qo_len > min(kv_lens_list): + pytest.skip("qo_len > min(kv_len) not supported for causal attention") + device = torch.device("cuda") + torch.manual_seed(0) + + num_heads = 16 + page_size = 1 + dtype = torch.float16 + kv_lens_full = kv_lens_list * batch_size + n_req = len(kv_lens_full) + pages = [math.ceil(kv / page_size) for kv in kv_lens_full] + total_pages = sum(pages) + + q_nope = torch.randn( + n_req * qo_len, num_heads, HEAD_DIM_CKV, dtype=dtype, device=device + ) + q_pe = torch.randn( + n_req * qo_len, num_heads, HEAD_DIM_KPE, dtype=dtype, device=device + ) + ckv = torch.randn(total_pages, page_size, HEAD_DIM_CKV, dtype=dtype, device=device) + kpe = torch.randn(total_pages, page_size, HEAD_DIM_KPE, dtype=dtype, device=device) + + sm_scale = 1.0 / ((128 + 64) ** 0.5) + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace, backend="auto") + + q_indptr = torch.arange(0, n_req + 1, dtype=torch.int32, device=device) * qo_len + page_indptr = torch.zeros(n_req + 1, dtype=torch.int32, device=device) + page_indptr[1:] = torch.tensor(pages, dtype=torch.int32, device=device).cumsum(0) + kv_indices = torch.arange(0, total_pages, dtype=torch.int32, device=device) + kv_lens = torch.tensor(kv_lens_full, dtype=torch.int32, device=device) + + wrapper.plan( + q_indptr, + page_indptr, + kv_indices, + kv_lens, + num_heads, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + page_size, + causal, + sm_scale, + q_nope.dtype, + ckv.dtype, + ) + o = wrapper.run(q_nope, q_pe, ckv, kpe) + + # Per-request reference (each request has its own kv_len so we can't batch + # them in a single attention_ref call with padding zeros — that would shift + # the softmax denominator. Compute and concatenate.) + out_chunks = [] + page_offset = 0 + qo_offset = 0 + for kv_len in kv_lens_full: + k_p = ckv[page_offset : page_offset + kv_len].view(1, kv_len, HEAD_DIM_CKV) + kpe_p = kpe[page_offset : page_offset + kv_len].view(1, kv_len, HEAD_DIM_KPE) + page_offset += kv_len + q_n = q_nope[qo_offset : qo_offset + qo_len] + q_p = q_pe[qo_offset : qo_offset + qo_len] + qo_offset += qo_len + if kv_len == 0: + out_chunks.append(torch.zeros_like(q_n)) + continue + # k = [ckv ; kpe] repeated across heads; v = ckv repeated across heads. + k = torch.cat([k_p, kpe_p], dim=-1).expand( + 1, kv_len, HEAD_DIM_CKV + HEAD_DIM_KPE + ) + k = ( + k.unsqueeze(2) + .repeat(1, 1, num_heads, 1) + .view(kv_len, num_heads, HEAD_DIM_CKV + HEAD_DIM_KPE) + ) + v = k_p.expand(1, kv_len, HEAD_DIM_CKV) + v = ( + v.unsqueeze(2) + .repeat(1, 1, num_heads, 1) + .view(kv_len, num_heads, HEAD_DIM_CKV) + ) + q = torch.cat([q_n, q_p], dim=-1) + ref, _ = _attention_ref(1, q, k, v, causal, sm_scale) + out_chunks.append(ref) + o_ref = torch.cat(out_chunks, dim=0) + + torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) diff --git a/tests/rocm_tests/test_mla_hip.py b/tests/rocm_tests/test_mla_hip.py new file mode 100644 index 0000000000..55a91674b2 --- /dev/null +++ b/tests/rocm_tests/test_mla_hip.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +# HIP/ROCm tests for MLA (Multi-head Latent Attention) batch paged attention. + +import math + +import pytest +import torch + +import flashinfer.mla +from flashinfer.jit import build_jit_specs, gen_batch_mla_module +from flashinfer.utils import determine_mla_backend + +HEAD_DIM_CKV = 512 +HEAD_DIM_KPE = 64 +NUM_HEADS = 16 + + +@pytest.fixture(autouse=True, scope="module") +def warmup_jit(): + build_jit_specs( + [ + gen_batch_mla_module( + "hip", + torch.float16, + torch.float16, + torch.float16, + torch.int32, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + False, + ) + ], + verbose=False, + ) + yield + + +def _make_wrapper(): + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda") + return flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace, backend="auto") + + +def _plan(wrapper, batch_size, kv_len, page_size, causal, dtype): + qo_len = 1 # decode: one query token per request + pages_per_req = math.ceil(kv_len / page_size) + q_indptr = ( + torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * qo_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") + * pages_per_req + ) + kv_indices = torch.arange( + 0, batch_size * pages_per_req, dtype=torch.int32, device="cuda" + ) + kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda") + sm_scale = 1.0 / ((HEAD_DIM_CKV + HEAD_DIM_KPE) ** 0.5) + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_lens, + NUM_HEADS, + HEAD_DIM_CKV, + HEAD_DIM_KPE, + page_size, + causal, + sm_scale, + dtype, + dtype, + ) + return kv_lens, pages_per_req + + +def _mla_reference(q_nope, q_pe, ckv, kpe, kv_lens, page_size, sm_scale, _causal): + """Pure-PyTorch MLA reference: S = q_pe @ kpe^T + q_nope @ ckv^T; O = softmax(S) @ ckv.""" + batch_size, num_heads, _ = q_nope.shape + dtype = q_nope.dtype + + max_kv_len = int(kv_lens.max().item()) + # pages are laid out sequentially per request + pages_per_req = math.ceil(max_kv_len / page_size) + + ckv_flat = ckv.reshape(batch_size, pages_per_req * page_size, HEAD_DIM_CKV)[ + :, :max_kv_len, : + ] + kpe_flat = kpe.reshape(batch_size, pages_per_req * page_size, HEAD_DIM_KPE)[ + :, :max_kv_len, : + ] + + # decode is qo_len=1 so causal masking is a no-op; only kv-length padding matters. + scores = torch.einsum("bhd,bsd->bhs", q_pe.float(), kpe_flat.float()) + scores += torch.einsum("bhd,bsd->bhs", q_nope.float(), ckv_flat.float()) + scores = scores * sm_scale + + for b in range(batch_size): + kl = int(kv_lens[b].item()) + if kl < max_kv_len: + scores[b, :, kl:] = float("-inf") + + weights = torch.softmax(scores, dim=-1) + out = torch.einsum("bhs,bsd->bhd", weights, ckv_flat.float()) + return out.to(dtype) + + +def test_determine_mla_backend(): + device = torch.device("cuda") + backend = determine_mla_backend(device) + assert backend == "hip", f"expected 'hip', got {backend!r}" + + +@pytest.mark.parametrize("batch_size", [1, 4, 16]) +@pytest.mark.parametrize("kv_len", [1, 64, 512]) +@pytest.mark.parametrize("page_size", [1, 16]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_batch_mla_plan(batch_size, kv_len, page_size, causal, dtype): + if causal and kv_len == 0: + pytest.skip("causal with kv_len=0 unsupported") + wrapper = _make_wrapper() + _plan(wrapper, batch_size, kv_len, page_size, causal, dtype) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("kv_len", [64, 256]) +@pytest.mark.parametrize("page_size", [16]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_batch_mla_correctness(batch_size, kv_len, page_size, causal, dtype): + torch.manual_seed(42) + wrapper = _make_wrapper() + sm_scale = 1.0 / ((HEAD_DIM_CKV + HEAD_DIM_KPE) ** 0.5) + kv_lens, pages_per_req = _plan( + wrapper, batch_size, kv_len, page_size, causal, dtype + ) + + q_nope = torch.randn( + batch_size, NUM_HEADS, HEAD_DIM_CKV, dtype=dtype, device="cuda" + ) + q_pe = torch.randn(batch_size, NUM_HEADS, HEAD_DIM_KPE, dtype=dtype, device="cuda") + ckv = torch.randn( + batch_size * pages_per_req, page_size, HEAD_DIM_CKV, dtype=dtype, device="cuda" + ) + kpe = torch.randn( + batch_size * pages_per_req, page_size, HEAD_DIM_KPE, dtype=dtype, device="cuda" + ) + + out = wrapper.run(q_nope, q_pe, ckv, kpe) + + ref = _mla_reference( + q_nope, + q_pe, + ckv, + kpe, + kv_lens, + page_size, + sm_scale, + causal, + ) + + torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2) diff --git a/tests/rocm_tests/test_mla_page_hip.py b/tests/rocm_tests/test_mla_page_hip.py new file mode 100644 index 0000000000..44a132cf01 --- /dev/null +++ b/tests/rocm_tests/test_mla_page_hip.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# HIP/ROCm test for append_paged_mla_kv_cache (MLA KV-cache write path). +# Mirrors tests/attention/test_mla_page.py. + +import math + +import pytest +import torch + +import flashinfer + +CKV_DIM = 512 +KPE_DIM = 64 + + +def _last_page_lens(kv_lens, page_size): + return [kl % page_size if kl % page_size != 0 else page_size for kl in kv_lens] + + +_KV_LEN_CONFIGS = [ + [45], + [4096], + [45, 8, 25], + [45, 8, 25, 22], + [45, 8, 25, 22, 400], + [45, 8, 25, 22, 100], +] + + +@pytest.mark.parametrize("kv_lens", _KV_LEN_CONFIGS) +@pytest.mark.parametrize("page_size", [1, 16, 64]) +def test_append_mla_paged_kv_cache(kv_lens, page_size): + device = torch.device("cuda") + nnz_kv = sum(kv_lens) + ckv_append = torch.randn(nnz_kv, CKV_DIM, dtype=torch.float16, device=device) + kpe_append = torch.randn(nnz_kv, KPE_DIM, dtype=torch.float16, device=device) + + num_pages_per_req = torch.tensor( + [math.ceil(kl / page_size) for kl in kv_lens], dtype=torch.int32, device=device + ) + kv_lens_t = torch.tensor(kv_lens, dtype=torch.int32, device=device) + kv_append_indptr = torch.cat( + [torch.zeros(1, dtype=torch.int32, device=device), kv_lens_t.cumsum(0).int()] + ) + + max_num_pages = int(num_pages_per_req.sum().item()) + ckv_cache = torch.zeros( + max_num_pages, page_size, CKV_DIM, dtype=torch.float16, device=device + ) + kpe_cache = torch.zeros( + max_num_pages, page_size, KPE_DIM, dtype=torch.float16, device=device + ) + + kv_page_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + num_pages_per_req.cumsum(0).int(), + ] + ) + kv_page_indices = torch.arange(max_num_pages, dtype=torch.int32, device=device) + kv_last_page_len = torch.tensor( + _last_page_lens(kv_lens, page_size), dtype=torch.int32, device=device + ) + + batch_indices, positions = flashinfer.get_batch_indices_positions( + kv_append_indptr, + flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size), + nnz_kv, + ) + flashinfer.append_paged_mla_kv_cache( + ckv_append, + kpe_append, + batch_indices, + positions, + ckv_cache, + kpe_cache, + kv_page_indices, + kv_page_indptr, + kv_last_page_len, + ) + + ckv_flat = ckv_cache.view(-1, CKV_DIM) + kpe_flat = kpe_cache.view(-1, KPE_DIM) + + acc_kv = 0 + acc_pad = 0 + for i, kl in enumerate(kv_lens): + pages_i = int(num_pages_per_req[i].item()) + torch.testing.assert_close( + ckv_append[acc_kv : acc_kv + kl], ckv_flat[acc_pad : acc_pad + kl] + ) + torch.testing.assert_close( + kpe_append[acc_kv : acc_kv + kl], kpe_flat[acc_pad : acc_pad + kl] + ) + assert torch.all(ckv_flat[acc_pad + kl : acc_pad + pages_i * page_size] == 0) + assert torch.all(kpe_flat[acc_pad + kl : acc_pad + pages_i * page_size] == 0) + acc_kv += kl + acc_pad += pages_i * page_size