diff --git a/aiter/mla.py b/aiter/mla.py index 2b97c9bfac..926b4a517c 100644 --- a/aiter/mla.py +++ b/aiter/mla.py @@ -395,6 +395,13 @@ def mla_decode_fwd( and kv_buffer.dtype == dtypes.fp8 and max_seqlen_q == 1 ) + or ( + get_gfx() == "gfx950" + and nhead == 32 + and q.dtype == dtypes.fp8 + and kv_buffer.dtype == dtypes.fp8 + and max_seqlen_q == 1 + ) or ( get_gfx() == "gfx950" and q.dtype == dtypes.bf16 diff --git a/csrc/kernels/mla/metadata/v1_2_device.cuh b/csrc/kernels/mla/metadata/v1_2_device.cuh index 7f26f9c16c..6b3d3ed2d5 100644 --- a/csrc/kernels/mla/metadata/v1_2_device.cuh +++ b/csrc/kernels/mla/metadata/v1_2_device.cuh @@ -501,6 +501,8 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba const bool natively_supported = (num_heads == 16) || + ((arch_id == "gfx950") && (num_heads == 32) && q_is_fp8 && kv_is_fp8 && + (max_seqlen_qo == 1)) || ((arch_id == "gfx950") && (num_heads == 32) && q_is_fp8 && kv_is_fp8 && (max_seqlen_qo == 2)) || ((arch_id == "gfx950") && (num_heads == 32) && q_is_fp8 && kv_is_fp8 && diff --git a/csrc/py_itfs_cu/asm_mla.cu b/csrc/py_itfs_cu/asm_mla.cu index 8733f89713..ff2d9b8451 100644 --- a/csrc/py_itfs_cu/asm_mla.cu +++ b/csrc/py_itfs_cu/asm_mla.cu @@ -328,9 +328,12 @@ void mla_decode_stage1_asm_fwd( } else if((max_seqlen_q == 2) && persistent){ config_max_seqlen_q = 2; sub_Q = 128; + } else if((max_seqlen_q == 1) && persistent){ + config_max_seqlen_q = 1; + sub_Q = 32; } else { - AITER_CHECK(false, __func__, - ": fp8/fp8 with gqa_ratio=32 only supports decode_qlen=2,4 in persistent mode"); + AITER_CHECK(false, __func__, + ": fp8/fp8 with gqa_ratio=32 only supports decode_qlen=1,2,4 in persistent mode"); } } } else if (gqa_ratio == 64){ diff --git a/hsa/gfx950/mla/mla_a8w8_qh32_qseqlen1_gqaratio32_ps.co b/hsa/gfx950/mla/mla_a8w8_qh32_qseqlen1_gqaratio32_ps.co new file mode 100644 index 0000000000..3ff1ad2fbe Binary files /dev/null and b/hsa/gfx950/mla/mla_a8w8_qh32_qseqlen1_gqaratio32_ps.co differ diff --git a/hsa/gfx950/mla/mla_asm.csv b/hsa/gfx950/mla/mla_asm.csv index 6c27175923..99005f1cd4 100644 --- a/hsa/gfx950/mla/mla_asm.csv +++ b/hsa/gfx950/mla/mla_asm.csv @@ -16,6 +16,7 @@ fp8,fp8,16,0,1,0,0,0,_ZN5aiter33mla_a8w8_qh16_qseqlen1_gqaratio16E,mla_a8w8_qh16 fp8,fp8,16,0,2,0,0,0,_ZN5aiter33mla_a8w8_qh16_qseqlen2_gqaratio16E,mla_a8w8_qh16_qseqlen2_gqaratio16.co fp8,fp8,16,0,4,0,0,0,_ZN5aiter33mla_a8w8_qh64_qseqlen4_gqaratio16E,mla_a8w8_qh64_qseqlen4_gqaratio16.co fp8,fp8,32,1,4,0,0,0,_ZN5aiter36mla_a8w8_qh32_qseqlen4_gqaratio32_psE,mla_a8w8_qh32_qseqlen4_gqaratio32_ps.co +fp8,fp8,32,1,1,0,0,0,_ZN5aiter36mla_a8w8_qh32_qseqlen1_gqaratio32_psE,mla_a8w8_qh32_qseqlen1_gqaratio32_ps.co fp8,fp8,32,1,2,0,0,0,_ZN5aiter39mla_a8w8_qh32_qseqlen2_gqaratio32_v3_psE,mla_a8w8_qh32_qseqlen2_gqaratio32_v3_ps.co fp8,fp8,32,1,2,0,0,1,_ZN5aiter43mla_a8w8_qh32_qseqlen2_gqaratio32_lse_v3_psE,mla_a8w8_qh32_qseqlen2_gqaratio32_lse_v3_ps.co fp8,fp8,128,0,0,0,0,0,_ZN5aiter31mla_a8w8_qh128_m32x4_n16x2_msk1E,mla_a8w8_qh128_m32x4_n16x2_msk1.co diff --git a/op_tests/test_mla.py b/op_tests/test_mla.py index 8996e10ac3..50f15f31e9 100644 --- a/op_tests/test_mla.py +++ b/op_tests/test_mla.py @@ -515,7 +515,7 @@ def test_absorb_decode_gluon(): 128, ]: err, us_asm_decode = test_absorb_decode_bf16() - elif kvtype == dtypes.fp8 and nhead in [8, 16, 128]: + elif kvtype == dtypes.fp8 and nhead in [8, 16, 32, 128]: err, us_asm_decode = test_absorb_decode_fp8() ret["decode:err"] = err diff --git a/op_tests/test_mla_persistent.py b/op_tests/test_mla_persistent.py index c42449f3bb..25799dee6b 100644 --- a/op_tests/test_mla_persistent.py +++ b/op_tests/test_mla_persistent.py @@ -490,6 +490,13 @@ def torch_mla_extend_split_kv( and is_fp8_kvc and max_seqlen_q == 2 ) + or ( + get_gfx() == "gfx950" + and nheads == 32 + and is_fp8_q + and is_fp8_kvc + and max_seqlen_q == 1 + ) or ( get_gfx() == "gfx950" and nheads == 8