Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions aiter/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,13 @@ def mla_decode_fwd(
and kv_buffer.dtype == dtypes.fp8
and max_seqlen_q == 1
)
or (
get_gfx() == "gfx950"
and nhead == 32
and q.dtype == dtypes.fp8
and kv_buffer.dtype == dtypes.fp8
and max_seqlen_q == 1
)
or (
get_gfx() == "gfx950"
and q.dtype == dtypes.bf16
Expand Down
2 changes: 2 additions & 0 deletions csrc/kernels/mla/metadata/v1_2_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,8 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba

const bool natively_supported =
(num_heads == 16) ||
((arch_id == "gfx950") && (num_heads == 32) && q_is_fp8 && kv_is_fp8 &&
(max_seqlen_qo == 1)) ||
((arch_id == "gfx950") && (num_heads == 32) && q_is_fp8 && kv_is_fp8 &&
(max_seqlen_qo == 2)) ||
((arch_id == "gfx950") && (num_heads == 32) && q_is_fp8 && kv_is_fp8 &&
Expand Down
7 changes: 5 additions & 2 deletions csrc/py_itfs_cu/asm_mla.cu
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,12 @@ void mla_decode_stage1_asm_fwd(
} else if((max_seqlen_q == 2) && persistent){
config_max_seqlen_q = 2;
sub_Q = 128;
} else if((max_seqlen_q == 1) && persistent){
config_max_seqlen_q = 1;
sub_Q = 32;
} else {
AITER_CHECK(false, __func__,
": fp8/fp8 with gqa_ratio=32 only supports decode_qlen=2,4 in persistent mode");
AITER_CHECK(false, __func__,
": fp8/fp8 with gqa_ratio=32 only supports decode_qlen=1,2,4 in persistent mode");
}
}
} else if (gqa_ratio == 64){
Expand Down
Binary file not shown.
1 change: 1 addition & 0 deletions hsa/gfx950/mla/mla_asm.csv
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ fp8,fp8,16,0,1,0,0,0,_ZN5aiter33mla_a8w8_qh16_qseqlen1_gqaratio16E,mla_a8w8_qh16
fp8,fp8,16,0,2,0,0,0,_ZN5aiter33mla_a8w8_qh16_qseqlen2_gqaratio16E,mla_a8w8_qh16_qseqlen2_gqaratio16.co
fp8,fp8,16,0,4,0,0,0,_ZN5aiter33mla_a8w8_qh64_qseqlen4_gqaratio16E,mla_a8w8_qh64_qseqlen4_gqaratio16.co
fp8,fp8,32,1,4,0,0,0,_ZN5aiter36mla_a8w8_qh32_qseqlen4_gqaratio32_psE,mla_a8w8_qh32_qseqlen4_gqaratio32_ps.co
fp8,fp8,32,1,1,0,0,0,_ZN5aiter36mla_a8w8_qh32_qseqlen1_gqaratio32_psE,mla_a8w8_qh32_qseqlen1_gqaratio32_ps.co
fp8,fp8,32,1,2,0,0,0,_ZN5aiter39mla_a8w8_qh32_qseqlen2_gqaratio32_v3_psE,mla_a8w8_qh32_qseqlen2_gqaratio32_v3_ps.co
fp8,fp8,32,1,2,0,0,1,_ZN5aiter43mla_a8w8_qh32_qseqlen2_gqaratio32_lse_v3_psE,mla_a8w8_qh32_qseqlen2_gqaratio32_lse_v3_ps.co
fp8,fp8,128,0,0,0,0,0,_ZN5aiter31mla_a8w8_qh128_m32x4_n16x2_msk1E,mla_a8w8_qh128_m32x4_n16x2_msk1.co
Expand Down
2 changes: 1 addition & 1 deletion op_tests/test_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def test_absorb_decode_gluon():
128,
]:
err, us_asm_decode = test_absorb_decode_bf16()
elif kvtype == dtypes.fp8 and nhead in [8, 16, 128]:
elif kvtype == dtypes.fp8 and nhead in [8, 16, 32, 128]:
err, us_asm_decode = test_absorb_decode_fp8()

ret["decode:err"] = err
Expand Down
7 changes: 7 additions & 0 deletions op_tests/test_mla_persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,13 @@ def torch_mla_extend_split_kv(
and is_fp8_kvc
and max_seqlen_q == 2
)
or (
get_gfx() == "gfx950"
and nheads == 32
and is_fp8_q
and is_fp8_kvc
and max_seqlen_q == 1
)
or (
get_gfx() == "gfx950"
and nheads == 8
Expand Down
Loading