diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index 7e0379e285..b06eb25dfa 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -1299,13 +1299,15 @@ def is_fmha_v3_fp8(): ret = ret and ( q_descale is not None and k_descale is not None and v_descale is not None ) - # support per tensor and per head quant scale - ret = ret and ( + pertensor_or_perhead = ( q_descale.shape == (1,) or q_descale.shape == (batch_size, nhead_k) + ) and q_descale.shape == k_descale.shape and q_descale.shape == v_descale.shape + qkptph_vph = ( + q_descale.shape == (batch_size, nhead_q, seqlen_q) + and k_descale.shape == (batch_size, nhead_k, seqlen_k) + and v_descale.shape in ((nhead_k,), (batch_size, nhead_k)) ) - ret = ret and ( - q_descale.shape == k_descale.shape and q_descale.shape == v_descale.shape - ) + ret = ret and (pertensor_or_perhead or qkptph_vph) return ret def can_impl_fmha_v3_fwd(): diff --git a/csrc/cpp_itfs/mha_fwd.cu b/csrc/cpp_itfs/mha_fwd.cu index 1673fa6a05..bb57851b76 100644 --- a/csrc/cpp_itfs/mha_fwd.cu +++ b/csrc/cpp_itfs/mha_fwd.cu @@ -29,6 +29,7 @@ std::string get_kernel_name_key(const std::string& arch_id, int hdim_v, int mask_type, int bf16_cvt, + int qscale_type, bool mode, const CFG* cfgs) { @@ -42,7 +43,7 @@ std::string get_kernel_name_key(const std::string& arch_id, } if(cfg.dtype == data_type && cfg.hdim_q == hdim_q && cfg.hdim_v == hdim_v && - cfg.mask == mask_type && cfg.mode == mode) + cfg.mask == mask_type && cfg.qscale == qscale_type && cfg.mode == mode) { if(arch_id == "gfx950") { @@ -232,6 +233,7 @@ float fmha_fwd_v3(mha_fwd_args a, const ck_tile::stream_config& s) a.hdim_v, cfg_mask_type, a.how_v3_bf16_cvt, + a.qscale_type, a.is_group_mode, fwd_cfgs); auto it = fwd_cfgs->find(kernel_name_key); @@ -374,7 +376,7 @@ float mha_fwd(mha_fwd_args args, const ck_tile::stream_config& s) #endif #if FAV2_ON - if(ret == -1 && !args.v3_api_check) + if(ret == -1 && !args.v3_api_check && args.qscale_type == 0) { ret = fmha_fwd_ck(args, s); } diff --git a/csrc/py_itfs_cu/asm_mha_fwd.cu b/csrc/py_itfs_cu/asm_mha_fwd.cu index d186279c5a..d8218059f4 100644 --- a/csrc/py_itfs_cu/asm_mha_fwd.cu +++ b/csrc/py_itfs_cu/asm_mha_fwd.cu @@ -80,6 +80,12 @@ mha_fwd_args get_asm_fmha_fwd_args(bool has_lse, ck_tile::index_t batch_stride_descale_k = 0; ck_tile::index_t batch_stride_descale_v = 0; + constexpr int asm_qscale_pertensor = 0; + constexpr int asm_qscale_qkptph_vph = 1; + const bool use_qkptph_vph = + q_descale_.has_value() && q_descale_.value().dim() == 3; + int asm_qscale_type = use_qkptph_vph ? asm_qscale_qkptph_vph : asm_qscale_pertensor; + void *q_descale_ptr = nullptr; void *k_descale_ptr = nullptr; void *v_descale_ptr = nullptr; @@ -104,30 +110,63 @@ mha_fwd_args get_asm_fmha_fwd_args(bool has_lse, if (q_descale_.has_value()) { auto q_descale = q_descale_.value(); CHECK_DEVICE(q_descale); - TORCH_CHECK(q_descale.sizes() == torch::IntArrayRef({1}) || q_descale.sizes() == torch::IntArrayRef({b, h_k})); - if (q_descale.dim() == 2) { + if (use_qkptph_vph) { + TORCH_CHECK(q_descale.sizes() == torch::IntArrayRef({b, h, seqlen_q}), + "q_descale for qkptph_vph must be [batch, q_heads, seqlen_q]"); + TORCH_CHECK(q_descale.stride(2) == 1, + "q_descale for qkptph_vph must be contiguous in token dimension"); batch_stride_descale_q = q_descale.stride(0); nhead_stride_descale_q = q_descale.stride(1); + } else { + TORCH_CHECK(q_descale.sizes() == torch::IntArrayRef({1}) || q_descale.sizes() == torch::IntArrayRef({b, h_k})); + if (q_descale.dim() == 2) { + batch_stride_descale_q = q_descale.stride(0); + nhead_stride_descale_q = q_descale.stride(1); + } } q_descale_ptr = q_descale.data_ptr(); } if (k_descale_.has_value()) { auto k_descale = k_descale_.value(); CHECK_DEVICE(k_descale); - TORCH_CHECK(k_descale.sizes() == torch::IntArrayRef({1}) || k_descale.sizes() == torch::IntArrayRef({b, h_k})); - if (k_descale.dim() == 2) { + if (use_qkptph_vph) { + TORCH_CHECK(k_descale.sizes() == torch::IntArrayRef({b, h_k, seqlen_k}), + "k_descale for qkptph_vph must be [batch, kv_heads, seqlen_k]"); + TORCH_CHECK(k_descale.stride(2) == 1, + "k_descale for qkptph_vph must be contiguous in token dimension"); batch_stride_descale_k = k_descale.stride(0); nhead_stride_descale_k = k_descale.stride(1); + } else { + TORCH_CHECK(k_descale.sizes() == torch::IntArrayRef({1}) || k_descale.sizes() == torch::IntArrayRef({b, h_k})); + if (k_descale.dim() == 2) { + batch_stride_descale_k = k_descale.stride(0); + nhead_stride_descale_k = k_descale.stride(1); + } } k_descale_ptr = k_descale.data_ptr(); } if (v_descale_.has_value()) { auto v_descale = v_descale_.value(); CHECK_DEVICE(v_descale); - TORCH_CHECK(v_descale.sizes() == torch::IntArrayRef({1}) || v_descale.sizes() == torch::IntArrayRef({b, h_k})); - if (v_descale.dim() == 2) { - batch_stride_descale_v = v_descale.stride(0); - nhead_stride_descale_v = v_descale.stride(1); + if (use_qkptph_vph) { + TORCH_CHECK(v_descale.sizes() == torch::IntArrayRef({h_k}) || + v_descale.sizes() == torch::IntArrayRef({b, h_k}), + "v_descale for qkptph_vph must be [kv_heads] or [batch, kv_heads]"); + TORCH_CHECK(v_descale.stride(-1) == 1, + "v_descale for qkptph_vph must be contiguous in head dimension"); + if (v_descale.dim() == 2) { + batch_stride_descale_v = v_descale.stride(0); + nhead_stride_descale_v = v_descale.stride(1); + } else { + batch_stride_descale_v = 0; + nhead_stride_descale_v = v_descale.stride(0); + } + } else { + TORCH_CHECK(v_descale.sizes() == torch::IntArrayRef({1}) || v_descale.sizes() == torch::IntArrayRef({b, h_k})); + if (v_descale.dim() == 2) { + batch_stride_descale_v = v_descale.stride(0); + nhead_stride_descale_v = v_descale.stride(1); + } } v_descale_ptr = v_descale.data_ptr(); } @@ -139,7 +178,7 @@ mha_fwd_args get_asm_fmha_fwd_args(bool has_lse, false, // is_group_mode static_cast(bias_type), has_lse, - 0, // qscale_type + asm_qscale_type, false, //has_sink q.data_ptr(), k.data_ptr(), @@ -312,7 +351,8 @@ std::vector fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d] // H/t Daniel Haziza const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_q % 8 == 0 && - !alibi_slopes_.has_value() && !bias_.has_value(); + !alibi_slopes_.has_value() && !bias_.has_value() && + !(is_qkv_fp8 && q_descale_.has_value() && q_descale_.value().dim() == 3); const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size_q}).transpose(1, 2); diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_fp8_causal_qkptph_vph.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_fp8_causal_qkptph_vph.co new file mode 100755 index 0000000000..624fce143b Binary files /dev/null and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_fp8_causal_qkptph_vph.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_fp8_causal_qkptph_vph_group.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_fp8_causal_qkptph_vph_group.co new file mode 100755 index 0000000000..a4c42325cd Binary files /dev/null and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_fp8_causal_qkptph_vph_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_fp8_qkptph_vph.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_fp8_qkptph_vph.co new file mode 100755 index 0000000000..58144a2423 Binary files /dev/null and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_fp8_qkptph_vph.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_fp8_qkptph_vph_group.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_fp8_qkptph_vph_group.co new file mode 100755 index 0000000000..3424a00491 Binary files /dev/null and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_fp8_qkptph_vph_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/fmha_fwd.csv b/hsa/gfx942/fmha_v3_fwd/fmha_fwd.csv index e64533c7fd..72c08c0d42 100644 --- a/hsa/gfx942/fmha_v3_fwd/fmha_fwd.csv +++ b/hsa/gfx942/fmha_v3_fwd/fmha_fwd.csv @@ -1,29 +1,33 @@ -dtype,hdim_q,hdim_v,mask,mode,bf16_cvt,ts_qo,ts_kv,knl_name,co_name -bf16,128,128,0,0,0,256,32,_ZN5aiter24fmha_fwd_hd128_bf16_rtneE,fwd_hd128_bf16_rtne.co -bf16,128,128,0,0,1,256,32,_ZN5aiter24fmha_fwd_hd128_bf16_rtnaE,fwd_hd128_bf16_rtna.co -bf16,128,128,0,0,2,256,32,_ZN5aiter23fmha_fwd_hd128_bf16_rtzE,fwd_hd128_bf16_rtz.co -bf16,128,128,2,0,0,256,32,_ZN5aiter31fmha_fwd_hd128_bf16_causal_rtneE,fwd_hd128_bf16_causal_rtne.co -bf16,128,128,2,0,1,256,32,_ZN5aiter31fmha_fwd_hd128_bf16_causal_rtnaE,fwd_hd128_bf16_causal_rtna.co -bf16,128,128,2,0,2,256,32,_ZN5aiter30fmha_fwd_hd128_bf16_causal_rtzE,fwd_hd128_bf16_causal_rtz.co -bf16,128,128,0,1,0,256,32,_ZN5aiter30fmha_fwd_hd128_bf16_rtne_groupE,fwd_hd128_bf16_rtne_group.co -bf16,128,128,0,1,1,256,32,_ZN5aiter30fmha_fwd_hd128_bf16_rtna_groupE,fwd_hd128_bf16_rtna_group.co -bf16,128,128,0,1,2,256,32,_ZN5aiter29fmha_fwd_hd128_bf16_rtz_groupE,fwd_hd128_bf16_rtz_group.co -bf16,128,128,2,1,0,256,32,_ZN5aiter37fmha_fwd_hd128_bf16_causal_rtne_groupE,fwd_hd128_bf16_causal_rtne_group.co -bf16,128,128,2,1,1,256,32,_ZN5aiter37fmha_fwd_hd128_bf16_causal_rtna_groupE,fwd_hd128_bf16_causal_rtna_group.co -bf16,128,128,2,1,2,256,32,_ZN5aiter36fmha_fwd_hd128_bf16_causal_rtz_groupE,fwd_hd128_bf16_causal_rtz_group.co -bf16,192,128,0,0,0,128,32,_ZN5aiter28fmha_fwd_hd192x128_bf16_rtneE,fwd_hd192x128_bf16_rtne.co -bf16,192,128,0,0,1,128,32,_ZN5aiter28fmha_fwd_hd192x128_bf16_rtnaE,fwd_hd192x128_bf16_rtna.co -bf16,192,128,0,0,2,128,32,_ZN5aiter27fmha_fwd_hd192x128_bf16_rtzE,fwd_hd192x128_bf16_rtz.co -bf16,192,128,2,0,0,128,32,_ZN5aiter35fmha_fwd_hd192x128_bf16_causal_rtneE,fwd_hd192x128_bf16_causal_rtne.co -bf16,192,128,2,0,1,128,32,_ZN5aiter35fmha_fwd_hd192x128_bf16_causal_rtnaE,fwd_hd192x128_bf16_causal_rtna.co -bf16,192,128,2,0,2,128,32,_ZN5aiter34fmha_fwd_hd192x128_bf16_causal_rtzE,fwd_hd192x128_bf16_causal_rtz.co -bf16,192,128,0,1,0,128,32,_ZN5aiter34fmha_fwd_hd192x128_bf16_rtne_groupE,fwd_hd192x128_bf16_rtne_group.co -bf16,192,128,0,1,1,128,32,_ZN5aiter34fmha_fwd_hd192x128_bf16_rtna_groupE,fwd_hd192x128_bf16_rtna_group.co -bf16,192,128,0,1,2,128,32,_ZN5aiter33fmha_fwd_hd192x128_bf16_rtz_groupE,fwd_hd192x128_bf16_rtz_group.co -bf16,192,128,2,1,0,128,32,_ZN5aiter41fmha_fwd_hd192x128_bf16_causal_rtne_groupE,fwd_hd192x128_bf16_causal_rtne_group.co -bf16,192,128,2,1,1,128,32,_ZN5aiter41fmha_fwd_hd192x128_bf16_causal_rtna_groupE,fwd_hd192x128_bf16_causal_rtna_group.co -bf16,192,128,2,1,2,128,32,_ZN5aiter40fmha_fwd_hd192x128_bf16_causal_rtz_groupE,fwd_hd192x128_bf16_causal_rtz_group.co -fp8bf16,128,128,0,0,1,256,64,_ZN5aiter18fmha_fwd_hd128_fp8E,fwd_hd128_fp8.co -fp8bf16,128,128,2,0,1,256,64,_ZN5aiter25fmha_fwd_hd128_fp8_causalE,fwd_hd128_fp8_causal.co -fp8bf16,128,128,0,1,1,256,64,_ZN5aiter24fmha_fwd_hd128_fp8_groupE,fwd_hd128_fp8_group.co -fp8bf16,128,128,2,1,1,256,64,_ZN5aiter31fmha_fwd_hd128_fp8_causal_groupE,fwd_hd128_fp8_causal_group.co \ No newline at end of file +dtype,hdim_q,hdim_v,mask,mode,bf16_cvt,qscale,ts_qo,ts_kv,knl_name,co_name +bf16,128,128,0,0,0,0,256,32,_ZN5aiter24fmha_fwd_hd128_bf16_rtneE,fwd_hd128_bf16_rtne.co +bf16,128,128,0,0,1,0,256,32,_ZN5aiter24fmha_fwd_hd128_bf16_rtnaE,fwd_hd128_bf16_rtna.co +bf16,128,128,0,0,2,0,256,32,_ZN5aiter23fmha_fwd_hd128_bf16_rtzE,fwd_hd128_bf16_rtz.co +bf16,128,128,2,0,0,0,256,32,_ZN5aiter31fmha_fwd_hd128_bf16_causal_rtneE,fwd_hd128_bf16_causal_rtne.co +bf16,128,128,2,0,1,0,256,32,_ZN5aiter31fmha_fwd_hd128_bf16_causal_rtnaE,fwd_hd128_bf16_causal_rtna.co +bf16,128,128,2,0,2,0,256,32,_ZN5aiter30fmha_fwd_hd128_bf16_causal_rtzE,fwd_hd128_bf16_causal_rtz.co +bf16,128,128,0,1,0,0,256,32,_ZN5aiter30fmha_fwd_hd128_bf16_rtne_groupE,fwd_hd128_bf16_rtne_group.co +bf16,128,128,0,1,1,0,256,32,_ZN5aiter30fmha_fwd_hd128_bf16_rtna_groupE,fwd_hd128_bf16_rtna_group.co +bf16,128,128,0,1,2,0,256,32,_ZN5aiter29fmha_fwd_hd128_bf16_rtz_groupE,fwd_hd128_bf16_rtz_group.co +bf16,128,128,2,1,0,0,256,32,_ZN5aiter37fmha_fwd_hd128_bf16_causal_rtne_groupE,fwd_hd128_bf16_causal_rtne_group.co +bf16,128,128,2,1,1,0,256,32,_ZN5aiter37fmha_fwd_hd128_bf16_causal_rtna_groupE,fwd_hd128_bf16_causal_rtna_group.co +bf16,128,128,2,1,2,0,256,32,_ZN5aiter36fmha_fwd_hd128_bf16_causal_rtz_groupE,fwd_hd128_bf16_causal_rtz_group.co +bf16,192,128,0,0,0,0,128,32,_ZN5aiter28fmha_fwd_hd192x128_bf16_rtneE,fwd_hd192x128_bf16_rtne.co +bf16,192,128,0,0,1,0,128,32,_ZN5aiter28fmha_fwd_hd192x128_bf16_rtnaE,fwd_hd192x128_bf16_rtna.co +bf16,192,128,0,0,2,0,128,32,_ZN5aiter27fmha_fwd_hd192x128_bf16_rtzE,fwd_hd192x128_bf16_rtz.co +bf16,192,128,2,0,0,0,128,32,_ZN5aiter35fmha_fwd_hd192x128_bf16_causal_rtneE,fwd_hd192x128_bf16_causal_rtne.co +bf16,192,128,2,0,1,0,128,32,_ZN5aiter35fmha_fwd_hd192x128_bf16_causal_rtnaE,fwd_hd192x128_bf16_causal_rtna.co +bf16,192,128,2,0,2,0,128,32,_ZN5aiter34fmha_fwd_hd192x128_bf16_causal_rtzE,fwd_hd192x128_bf16_causal_rtz.co +bf16,192,128,0,1,0,0,128,32,_ZN5aiter34fmha_fwd_hd192x128_bf16_rtne_groupE,fwd_hd192x128_bf16_rtne_group.co +bf16,192,128,0,1,1,0,128,32,_ZN5aiter34fmha_fwd_hd192x128_bf16_rtna_groupE,fwd_hd192x128_bf16_rtna_group.co +bf16,192,128,0,1,2,0,128,32,_ZN5aiter33fmha_fwd_hd192x128_bf16_rtz_groupE,fwd_hd192x128_bf16_rtz_group.co +bf16,192,128,2,1,0,0,128,32,_ZN5aiter41fmha_fwd_hd192x128_bf16_causal_rtne_groupE,fwd_hd192x128_bf16_causal_rtne_group.co +bf16,192,128,2,1,1,0,128,32,_ZN5aiter41fmha_fwd_hd192x128_bf16_causal_rtna_groupE,fwd_hd192x128_bf16_causal_rtna_group.co +bf16,192,128,2,1,2,0,128,32,_ZN5aiter40fmha_fwd_hd192x128_bf16_causal_rtz_groupE,fwd_hd192x128_bf16_causal_rtz_group.co +fp8bf16,128,128,0,0,1,0,256,64,_ZN5aiter18fmha_fwd_hd128_fp8E,fwd_hd128_fp8.co +fp8bf16,128,128,2,0,1,0,256,64,_ZN5aiter25fmha_fwd_hd128_fp8_causalE,fwd_hd128_fp8_causal.co +fp8bf16,128,128,0,1,1,0,256,64,_ZN5aiter24fmha_fwd_hd128_fp8_groupE,fwd_hd128_fp8_group.co +fp8bf16,128,128,2,1,1,0,256,64,_ZN5aiter31fmha_fwd_hd128_fp8_causal_groupE,fwd_hd128_fp8_causal_group.co +fp8bf16,128,128,0,0,1,1,256,64,_ZN5aiter29fmha_fwd_hd128_fp8_qkptph_vphE,fwd_hd128_fp8_qkptph_vph.co +fp8bf16,128,128,2,0,1,1,256,64,_ZN5aiter36fmha_fwd_hd128_fp8_causal_qkptph_vphE,fwd_hd128_fp8_causal_qkptph_vph.co +fp8bf16,128,128,0,1,1,1,256,64,_ZN5aiter35fmha_fwd_hd128_fp8_qkptph_vph_groupE,fwd_hd128_fp8_qkptph_vph_group.co +fp8bf16,128,128,2,1,1,1,256,64,_ZN5aiter42fmha_fwd_hd128_fp8_causal_qkptph_vph_groupE,fwd_hd128_fp8_causal_qkptph_vph_group.co \ No newline at end of file diff --git a/hsa/gfx942/fmha_v3_fwd/fwd_hd128_fp8_causal.co b/hsa/gfx942/fmha_v3_fwd/fwd_hd128_fp8_causal.co new file mode 100755 index 0000000000..6b564efe4e Binary files /dev/null and b/hsa/gfx942/fmha_v3_fwd/fwd_hd128_fp8_causal.co differ diff --git a/hsa/gfx950/fmha_v3_fwd/fmha_fwd.csv b/hsa/gfx950/fmha_v3_fwd/fmha_fwd.csv index f418e10cd0..e640af834f 100644 --- a/hsa/gfx950/fmha_v3_fwd/fmha_fwd.csv +++ b/hsa/gfx950/fmha_v3_fwd/fmha_fwd.csv @@ -1,13 +1,17 @@ -dtype,hdim_q,hdim_v,mask,mode,bf16_cvt,ts_qo,ts_kv,knl_name,co_name -bf16,128,128,0,0,0,256,64,_ZN5aiter19fmha_fwd_hd128_bf16E,fwd_hd128_bf16.co -bf16,128,128,2,0,0,256,64,_ZN5aiter26fmha_fwd_hd128_bf16_causalE,fwd_hd128_bf16_causal.co -bf16,128,128,0,1,0,256,64,_ZN5aiter25fmha_fwd_hd128_bf16_groupE,fwd_hd128_bf16_group.co -bf16,128,128,2,1,0,256,64,_ZN5aiter32fmha_fwd_hd128_bf16_causal_groupE,fwd_hd128_bf16_causal_group.co -bf16,192,128,0,0,0,128,128,_ZN5aiter25fmha_fwd_hd192_hd128_bf16E,fwd_hd192_hd128_bf16.co -bf16,192,128,2,0,0,128,128,_ZN5aiter32fmha_fwd_hd192_hd128_bf16_causalE,fwd_hd192_hd128_bf16_causal.co -bf16,192,128,0,1,0,128,128,_ZN5aiter31fmha_fwd_hd192_hd128_bf16_groupE,fwd_hd192_hd128_bf16_group.co -bf16,192,128,2,1,0,128,128,_ZN5aiter38fmha_fwd_hd192_hd128_bf16_causal_groupE,fwd_hd192_hd128_bf16_causal_group.co -fp8bf16,128,128,0,0,0,256,128,_ZN5aiter24fmha_fwd_hd128_fp8_gfx950E,fwd_hd128_fp8.co -fp8bf16,128,128,2,0,0,256,128,_ZN5aiter31fmha_fwd_hd128_fp8_causal_gfx950E,fwd_hd128_fp8_causal.co -fp8bf16,128,128,0,1,0,256,128,_ZN5aiter30fmha_fwd_hd128_fp8_group_gfx950E,fwd_hd128_fp8_group.co -fp8bf16,128,128,2,1,0,256,128,_ZN5aiter37fmha_fwd_hd128_fp8_causal_group_gfx950E,fwd_hd128_fp8_causal_group.co \ No newline at end of file +dtype,hdim_q,hdim_v,mask,mode,bf16_cvt,qscale,ts_qo,ts_kv,knl_name,co_name +bf16,128,128,0,0,0,0,256,64,_ZN5aiter19fmha_fwd_hd128_bf16E,fwd_hd128_bf16.co +bf16,128,128,2,0,0,0,256,64,_ZN5aiter26fmha_fwd_hd128_bf16_causalE,fwd_hd128_bf16_causal.co +bf16,128,128,0,1,0,0,256,64,_ZN5aiter25fmha_fwd_hd128_bf16_groupE,fwd_hd128_bf16_group.co +bf16,128,128,2,1,0,0,256,64,_ZN5aiter32fmha_fwd_hd128_bf16_causal_groupE,fwd_hd128_bf16_causal_group.co +bf16,192,128,0,0,0,0,128,128,_ZN5aiter25fmha_fwd_hd192_hd128_bf16E,fwd_hd192_hd128_bf16.co +bf16,192,128,2,0,0,0,128,128,_ZN5aiter32fmha_fwd_hd192_hd128_bf16_causalE,fwd_hd192_hd128_bf16_causal.co +bf16,192,128,0,1,0,0,128,128,_ZN5aiter31fmha_fwd_hd192_hd128_bf16_groupE,fwd_hd192_hd128_bf16_group.co +bf16,192,128,2,1,0,0,128,128,_ZN5aiter38fmha_fwd_hd192_hd128_bf16_causal_groupE,fwd_hd192_hd128_bf16_causal_group.co +fp8bf16,128,128,0,0,0,0,256,128,_ZN5aiter24fmha_fwd_hd128_fp8_gfx950E,fwd_hd128_fp8.co +fp8bf16,128,128,2,0,0,0,256,128,_ZN5aiter31fmha_fwd_hd128_fp8_causal_gfx950E,fwd_hd128_fp8_causal.co +fp8bf16,128,128,0,1,0,0,256,128,_ZN5aiter30fmha_fwd_hd128_fp8_group_gfx950E,fwd_hd128_fp8_group.co +fp8bf16,128,128,2,1,0,0,256,128,_ZN5aiter37fmha_fwd_hd128_fp8_causal_group_gfx950E,fwd_hd128_fp8_causal_group.co +fp8bf16,128,128,0,0,0,1,256,128,_ZN5aiter36fmha_fwd_hd128_fp8_qkptph_vph_gfx950E,fwd_hd128_fp8_qkptph_vph.co +fp8bf16,128,128,2,0,0,1,256,128,_ZN5aiter43fmha_fwd_hd128_fp8_causal_qkptph_vph_gfx950E,fwd_hd128_fp8_causal_qkptph_vph.co +fp8bf16,128,128,0,1,0,1,256,128,_ZN5aiter42fmha_fwd_hd128_fp8_qkptph_vph_group_gfx950E,fwd_hd128_fp8_qkptph_vph_group.co +fp8bf16,128,128,2,1,0,1,256,128,_ZN5aiter49fmha_fwd_hd128_fp8_causal_qkptph_vph_group_gfx950E,fwd_hd128_fp8_causal_qkptph_vph_group.co \ No newline at end of file diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd128_fp8_causal_qkptph_vph.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_fp8_causal_qkptph_vph.co new file mode 100755 index 0000000000..f29060b234 Binary files /dev/null and b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_fp8_causal_qkptph_vph.co differ diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd128_fp8_causal_qkptph_vph_group.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_fp8_causal_qkptph_vph_group.co new file mode 100755 index 0000000000..4d1ff4349c Binary files /dev/null and b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_fp8_causal_qkptph_vph_group.co differ diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd128_fp8_qkptph_vph.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_fp8_qkptph_vph.co new file mode 100755 index 0000000000..3a5cf4699a Binary files /dev/null and b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_fp8_qkptph_vph.co differ diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd128_fp8_qkptph_vph_group.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_fp8_qkptph_vph_group.co new file mode 100755 index 0000000000..a3d3403872 Binary files /dev/null and b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_fp8_qkptph_vph_group.co differ diff --git a/op_tests/test_mha_fp8.py b/op_tests/test_mha_fp8.py index e6ad0479e8..19009517d7 100644 --- a/op_tests/test_mha_fp8.py +++ b/op_tests/test_mha_fp8.py @@ -175,6 +175,78 @@ def test_flash_attn_output( benchmark["fwd_gb_per_sec"] = (fwd_num_bytes) / 1.0e3 / us_fwd +def _per_token_per_head_quant(x, quant_dtype): + fp8_max = torch.finfo(quant_dtype).max + scale = x.float().abs().amax(dim=-1).clamp(min=1.0e-6) / fp8_max + x_quant = (x.float() / scale.unsqueeze(-1)).to(quant_dtype) + return x_quant, scale.float() + + +def _v_per_head_quant(v, quant_dtype): + fp8_max = torch.finfo(quant_dtype).max + scale = v.float().abs().amax(dim=(0, 1, 3)).clamp(min=1.0e-6) / fp8_max + v_quant = (v.float() / scale.view(1, 1, -1, 1)).to(quant_dtype) + return v_quant, scale.float() + + +def _ref_fp8_ptph_attention(q, k, v, qscale, kscale, vscale, causal): + _, _, nheads, d = q.shape + nheads_k = k.shape[2] + num_group = nheads // nheads_k + + q_deq = q.float() * qscale.permute(0, 2, 1).unsqueeze(-1) + k_deq = k.float() * kscale.permute(0, 2, 1).unsqueeze(-1) + v_deq = v.float() * vscale.view(1, 1, nheads_k, 1) + k_deq = k_deq.repeat_interleave(num_group, dim=2) + v_deq = v_deq.repeat_interleave(num_group, dim=2) + + scores = torch.einsum("bqhd,bkhd->bhqk", q_deq, k_deq) / math.sqrt(d) + if causal: + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + mask = torch.tril( + torch.ones((seqlen_k, seqlen_k), device=q.device, dtype=torch.bool) + )[(seqlen_k - seqlen_q) :, :] + scores = scores.masked_fill(~mask.view(1, 1, seqlen_q, seqlen_k), float("-inf")) + + p = torch.exp(scores - scores.max(dim=-1, keepdim=True).values) + rowsum = p.sum(dim=-1, keepdim=True) + p = p.to(dtypes.fp8).float() + out = torch.einsum("bhqk,bkhd->bqhd", p, v_deq) / rowsum.permute(0, 2, 1, 3) + return out.to(torch.bfloat16) + + +@pytest.mark.parametrize("causal", [False, True]) +def test_flash_attn_fp8_qk_per_token_per_head_v_per_head(causal): + if torch.cuda.get_device_properties(0).gcnArchName.split(":")[0] not in ("gfx942", "gfx950"): + pytest.skip("ASM v3 FP8 PTPH path is only enabled on gfx942/gfx950") + + torch.random.manual_seed(0) + batch_size, seqlen_q, seqlen_k = 1, 256, 256 + nheads, nheads_k = 4, 2 + d = d_v = 128 + dtype = torch.bfloat16 + quant_dtype = dtypes.fp8 + + q = torch.randn(batch_size, seqlen_q, nheads, d, device="cuda", dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads_k, d, device="cuda", dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads_k, d_v, device="cuda", dtype=dtype) + + q_quant, qscale_bsqh = _per_token_per_head_quant(q, quant_dtype) + k_quant, kscale_bskh = _per_token_per_head_quant(k, quant_dtype) + v_quant, vscale = _v_per_head_quant(v, quant_dtype) + qscale = qscale_bsqh.permute(0, 2, 1).contiguous() + kscale = kscale_bskh.permute(0, 2, 1).contiguous() + + out = flash_attn_fp8_pertensor_func( + q_quant, k_quant, v_quant, qscale, kscale, vscale, causal=causal + ) + out_ref = _ref_fp8_ptph_attention(q_quant, k_quant, v_quant, qscale, kscale, vscale, causal) + + abs_diff = (out - out_ref).abs() + max_diff = abs_diff.max().item() + assert max_diff < 0.08 + + parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, description="config input of test",