Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
171dd62
fix(hip): guard cuda.h includes in mla_params.cuh and profiler.cuh
demandal25 May 5, 2026
cf542f1
feat(hip): route MLA backend to HIP path in utils and mla.py
demandal25 May 5, 2026
bc739da
feat(hip): add JIT module generator for batch MLA attention
demandal25 May 5, 2026
8df86cc
feat(hip): add batch MLA kernel source files (Phase 1: Plan only)
demandal25 May 5, 2026
d7a0e3f
test(hip): add test_mla_hip.py for Phase 1 MLA plan validation
demandal25 May 12, 2026
b40b435
feat(hip): implement MLA attention Run path with CDNA3 MFMA correctness
demandal25 May 12, 2026
33eeb2a
perf(hip): MLA Phase 2a — uint4 KV loads, LDS pad, hoisted Q & divmod
demandal25 May 12, 2026
bf6629c
perf(hip): MLA Phase 2b — double-buffered KV pipeline
demandal25 May 12, 2026
ba10f70
feat(hip): MLA Phase 3 — extended tests, AOT bake, README update
demandal25 May 12, 2026
1cc73c1
fix(mypy): add partial_state to BatchPrefillWithPagedKVCacheWrapper.r…
demandal25 May 18, 2026
c4ebae7
refactor(hip): production cleanup of MLA kernel and tests
demandal25 May 18, 2026
1a0b5bc
perf(hip): MLA Phase 2b — wave-specialized QK and Phase 3 completion
demandal25 May 18, 2026
f41fee0
fix(hip): MLA cleanup and double-buffer race fix
demandal25 May 19, 2026
29747bc
fix(hip): address PR review nits
demandal25 May 19, 2026
5fb346f
style(hip): trim what-commentary and inline single-use variable in ml…
demandal25 May 19, 2026
05ed67b
refactor(hip): inline stride locals in batch_mla.cu
demandal25 May 19, 2026
4642d3f
style(hip): consolidate deferred imports in aot_hip.py and utils.py
demandal25 May 19, 2026
49e722e
style(hip): clean up HIP MLA tests and benchmark
demandal25 May 19, 2026
9e17034
fix(hip): reject use_profiler=True in gen_batch_mla_module on HIP
demandal25 May 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ to its corresponding upstream tag (e.g., `0.2.5+amd.2` is second release of amd-
| **Decode Attention** | ✅ | ✅ | No | Supports MHA, GQA, and MQA |
| **Prefill Attention** | ✅ | WIP | ✅ | Supports MHA, GQA, and MQA |
| **Cascade Attention** | TBD | TBD | No | Not Yet Ported |
| **MLA** | TBD | TBD | No | Not Yet Ported |
| **MLA** | | TBD | No | Decode-path (qo_len=1) on CDNA3; prefill TODO |
| **POD** | TBD | TBD | No | Not Yet Ported |
| **Positional Encoding** | TBD | TBD | No | Not Yet Ported |
| **Sampling** | ✅ | TBD | No | Supports Top-K/Top-P Sampling/OnlineSoftmax/SamplingFromLogits |
Expand Down
145 changes: 145 additions & 0 deletions benchmarks/rocm_benchmarks/bench_mla_hip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# MLA decode benchmark for HIP/ROCm (DeepSeek shapes).
# Mirrors benchmarks/bench_deepseek_mla.py with backend="hip".
#
# Run:
# python benchmarks/rocm_benchmarks/bench_mla_hip.py
# python benchmarks/rocm_benchmarks/bench_mla_hip.py --batch 64 --seq 8192 --heads 128

import argparse
import math

import numpy as np
import torch

import flashinfer.mla
from flashinfer.jit import build_jit_specs, gen_batch_mla_module
from flashinfer.testing.utils import bench_gpu_time

HEAD_DIM_CKV = 512
HEAD_DIM_KPE = 64
_PAGE_SIZE = 1


def _warmup_jit(dtype: torch.dtype) -> None:
build_jit_specs(
[
gen_batch_mla_module(
"hip",
dtype,
dtype,
dtype,
torch.int32,
HEAD_DIM_CKV,
HEAD_DIM_KPE,
False,
)
],
verbose=False,
)


@torch.inference_mode()
def bench(batch_size: int, seq_len: int, num_heads: int, dtype: torch.dtype) -> None:
sm_scale = 1.0 / math.sqrt(HEAD_DIM_CKV + HEAD_DIM_KPE)

q_nope = torch.randn(
batch_size, num_heads, HEAD_DIM_CKV, dtype=dtype, device="cuda"
)
q_pe = torch.randn(batch_size, num_heads, HEAD_DIM_KPE, dtype=dtype, device="cuda")
ckv = torch.randn(
batch_size * seq_len, _PAGE_SIZE, HEAD_DIM_CKV, dtype=dtype, device="cuda"
)
kpe = torch.randn(
batch_size * seq_len, _PAGE_SIZE, HEAD_DIM_KPE, dtype=dtype, device="cuda"
)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace, backend="auto")

q_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda")
kv_indptr = (
torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * seq_len
)
kv_indices = torch.arange(0, batch_size * seq_len, dtype=torch.int32, device="cuda")
kv_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")

wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
num_heads,
HEAD_DIM_CKV,
HEAD_DIM_KPE,
_PAGE_SIZE,
False,
sm_scale,
dtype,
dtype,
)

o = wrapper.run(q_nope, q_pe, ckv, kpe)
assert o.shape == (batch_size, num_heads, HEAD_DIM_CKV)

measurements = bench_gpu_time(
lambda: wrapper.run(q_nope, q_pe, ckv, kpe),
dry_run_time_ms=100,
repeat_time_ms=1000,
)
ms = np.median(measurements)

io_bytes = sum(t.numel() * t.element_size() for t in [q_nope, q_pe, ckv, kpe, o])
# Two GEMMs: Q·KVᵀ (ckv + kpe dim) and P·V (ckv dim).
flops = 2 * batch_size * num_heads * (2 * HEAD_DIM_CKV + HEAD_DIM_KPE) * seq_len

dtype_str = "fp16" if dtype == torch.float16 else "bf16"
print(
f"[{dtype_str}] batch={batch_size:>4d} seq={seq_len:>6d} heads={num_heads:>4d} | "
f"lat={ms * 1e3:.1f} µs | "
f"BW={io_bytes * 1e-6 / ms:.1f} GB/s | "
f"TFLOPS={flops * 1e-9 / ms:.3f}"
)


def main() -> None:
parser = argparse.ArgumentParser(description="MLA HIP decode benchmark")
parser.add_argument(
"--batch", type=int, default=0, help="single batch size (0 = sweep)"
)
parser.add_argument("--seq", type=int, default=0, help="single seq len (0 = sweep)")
parser.add_argument(
"--heads", type=int, default=16, help="number of attention heads"
)
parser.add_argument(
"--dtype",
choices=["fp16", "bf16", "both"],
default="fp16",
help="data type (default: fp16)",
)
args = parser.parse_args()

dtypes = []
if args.dtype in ("fp16", "both"):
dtypes.append(torch.float16)
if args.dtype in ("bf16", "both"):
dtypes.append(torch.bfloat16)

batch_sizes = [args.batch] if args.batch > 0 else [64, 128, 768]
seq_lens = [args.seq] if args.seq > 0 else [1024, 2048, 8192]

for dtype in dtypes:
print(
f"\n=== Warming up JIT ({('fp16' if dtype == torch.float16 else 'bf16')}) ==="
)
_warmup_jit(dtype)
print("=== Benchmarking ===")
for seq_len in seq_lens:
for batch_size in batch_sizes:
bench(batch_size, seq_len, args.heads, dtype)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@
from .norm import layernorm as layernorm
from .norm import rmsnorm as rmsnorm
from .page import append_paged_kv_cache as append_paged_kv_cache
from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache
from .page import get_batch_indices_positions as get_batch_indices_positions
from .page import get_seq_lens as get_seq_lens
from .prefill_rocm import ( # type: ignore[assignment]
Expand Down
15 changes: 15 additions & 0 deletions flashinfer/aot_hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def gen_attention(
use_sliding_window_: List[bool],
use_logits_soft_cap_: List[bool],
) -> Iterator:
from .jit.attention import gen_batch_mla_module

# FA2 MHA / MQA / GQA
for (
(head_dim_qk, head_dim_vo),
Expand All @@ -118,6 +120,19 @@ def gen_attention(
use_logits_soft_cap=use_logits_soft_cap,
)

# MLA (DeepSeek shapes: head_dim_ckv=512, head_dim_kpe=64).
for dtype in f16_dtype_:
yield gen_batch_mla_module(
backend="hip",
dtype_q=dtype,
dtype_kv=dtype,
dtype_o=dtype,
dtype_idx=torch.int32,
head_dim_ckv=512,
head_dim_kpe=64,
use_profiler=False,
)


def gen_all_modules(
f16_dtype_: List[torch.dtype],
Expand Down
125 changes: 125 additions & 0 deletions flashinfer/csrc_rocm/batch_mla.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: Apache-2.0
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>

#include <flashinfer/attention/mla_hip.cuh>

#include "batch_mla_config.inc"
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor pin_memory_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len_arr, int64_t num_heads, int64_t head_dim_o,
bool causal) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

MLAPlanInfo plan_info;
int64_t batch_size = kv_len_arr.size(0);

const c10::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(float_workspace_buffer.device());
const hipStream_t stream = c10::hip::getCurrentHIPStream();

hipError_t status =
MLAPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
int_workspace_buffer.data_ptr(), pin_memory_int_workspace_buffer.data_ptr(),
int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(qo_indptr.data_ptr()),
static_cast<IdType*>(kv_indptr.data_ptr()),
static_cast<IdType*>(kv_len_arr.data_ptr()), static_cast<uint32_t>(batch_size),
static_cast<uint32_t>(num_heads), static_cast<uint32_t>(head_dim_o), causal, stream);

TORCH_CHECK(status == hipSuccess,
"BatchMLAPagedAttentionPlan failed with error: ", hipGetErrorString(status));

return vec_to_tensor(plan_info.ToVector());
}

void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices,
at::Tensor o, std::optional<at::Tensor> maybe_lse,
int64_t mask_mode_code, int64_t num_heads, int64_t page_size,
double sm_scale, bool return_lse_base_on_e ADDITIONAL_FUNC_PARAMS) {
MLAPlanInfo plan_info;
plan_info.FromVector(tensor_to_vec(plan_info_vec));

void* float_buffer_ptr = float_workspace_buffer.data_ptr();
void* int_buffer_ptr = int_workspace_buffer.data_ptr();

const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);

const c10::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(q_nope.device());
const hipStream_t stream = c10::hip::getCurrentHIPStream();

DISPATCH_context(
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] {
Params params;

params.q_nope = static_cast<DTypeQ*>(q_nope.data_ptr());
params.q_pe = static_cast<DTypeQ*>(q_pe.data_ptr());
params.ckv = static_cast<DTypeKV*>(ckv_cache.data_ptr());
params.kpe = static_cast<DTypeKV*>(kpe_cache.data_ptr());

params.q_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_indptr_offset);
params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset);
params.partial_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.partial_indptr_offset);
params.kv_indices = static_cast<IdType*>(kv_indices.data_ptr());
params.q_len = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_len_offset);
params.kv_len = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset);
params.q_start = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_start_offset);
params.kv_start = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_start_offset);
params.kv_end = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_end_offset);
params.work_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
params.merge_packed_offset_start = GetPtrFromBaseOffset<IdType>(
int_buffer_ptr, plan_info.merge_packed_offset_start_offset);
params.merge_packed_offset_end =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_packed_offset_end_offset);
params.merge_partial_packed_offset_start = GetPtrFromBaseOffset<IdType>(
int_buffer_ptr, plan_info.merge_partial_packed_offset_start_offset);
params.merge_partial_packed_offset_end = GetPtrFromBaseOffset<IdType>(
int_buffer_ptr, plan_info.merge_partial_packed_offset_end_offset);
params.merge_partial_stride =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_partial_stride_offset);

params.final_o = static_cast<DTypeO*>(o.data_ptr());
params.final_lse =
maybe_lse.has_value() ? static_cast<float*>(maybe_lse.value().data_ptr()) : nullptr;
params.partial_o =
GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
params.partial_lse =
GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);

params.num_heads = uint_fastdiv(num_heads);
params.block_size = uint_fastdiv(page_size);

params.q_nope_stride_n = q_nope.stride(0);
params.q_nope_stride_h = q_nope.stride(1);
params.q_pe_stride_n = q_pe.stride(0);
params.q_pe_stride_h = q_pe.stride(1);
params.ckv_stride_page = ckv_cache.stride(0);
params.ckv_stride_n = ckv_cache.stride(1);
params.kpe_stride_page = kpe_cache.stride(0);
params.kpe_stride_n = kpe_cache.stride(1);
params.o_stride_n = o.stride(0);
params.o_stride_h = o.stride(1);

params.sm_scale = static_cast<float>(sm_scale);
params.return_lse_base_on_e = return_lse_base_on_e;

ADDITIONAL_PARAMS_SETTER

hipError_t status = mla::BatchMLAPagedAttentionHIP<MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE>(
params, plan_info.num_blks_x, plan_info.num_blks_y, stream);
TORCH_CHECK(status == hipSuccess,
"BatchMLAPagedAttentionRun failed: ", hipGetErrorString(status));
});
}
31 changes: 31 additions & 0 deletions flashinfer/csrc_rocm/batch_mla_customize_config.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: Apache-2.0
//
// NOTE: This file is generated from batch_mla_customize_config.jinja.
// Do not edit manually.
#pragma once
#include <flashinfer/attention/generic/scheduler.cuh>

using namespace flashinfer;

#ifdef FLASHINFER_ENABLE_PROFILER
#define ADDITIONAL_FUNC_PARAMS , at::Tensor profiler_buffer
#define ADDITIONAL_PARAMS_SETTER \
params.profiler_buffer = static_cast<uint64_t*>(profiler_buffer.data_ptr());
#else
#define ADDITIONAL_FUNC_PARAMS
#define ADDITIONAL_PARAMS_SETTER
#endif

using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ dtype_idx }};
constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }};
constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }};

#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \
using Params = MLAParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
__VA_ARGS__(); \
})
23 changes: 23 additions & 0 deletions flashinfer/csrc_rocm/batch_mla_jit_pybind.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// SPDX-FileCopyrightText: 2026 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: Apache-2.0
#include "batch_mla_config.inc"
#include "pytorch_extension_utils.h"

at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor pin_memory_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len_arr, int64_t num_heads, int64_t head_dim_o,
bool causal);

void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices,
at::Tensor o, std::optional<at::Tensor> maybe_lse,
int64_t mask_mode_code, int64_t num_heads, int64_t page_size,
double sm_scale, bool return_lse_base_on_e ADDITIONAL_FUNC_PARAMS);

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
m.def("plan", BatchMLAPagedAttentionPlan);
m.def("run", BatchMLAPagedAttentionRun);
}
2 changes: 2 additions & 0 deletions flashinfer/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def get_cudnn_fmha_gen_module():
from .activation import gen_act_and_mul_module as gen_act_and_mul_module
from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str
from .attention import gen_batch_decode_module as gen_batch_decode_module
from .attention import gen_batch_mla_module as gen_batch_mla_module
from .attention import get_batch_mla_uri as get_batch_mla_uri
from .attention import gen_batch_prefill_module as gen_batch_prefill_module
from .attention import (
gen_customize_batch_decode_module as gen_customize_batch_decode_module,
Expand Down
2 changes: 2 additions & 0 deletions flashinfer/jit/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@
)
from .modules_hip import gen_single_decode_module as gen_single_decode_module
from .modules_hip import gen_single_prefill_module as gen_single_prefill_module
from .modules_hip import gen_batch_mla_module as gen_batch_mla_module
from .modules_hip import get_batch_mla_uri as get_batch_mla_uri
from .modules_hip import get_batch_decode_uri as get_batch_decode_uri
from .modules_hip import get_batch_prefill_uri as get_batch_prefill_uri
from .modules_hip import get_single_decode_uri as get_single_decode_uri
Expand Down
Loading
Loading