diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index 69f55db0e4..df42f73c99 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -1,10 +1,120 @@ #include "mha_bwd.h" #include "aiter_hip_common.h" #include "asm_fmha_v3_bwd_configs.hpp" +#include +#include #include #include +#include namespace aiter { +namespace { + +struct KernelArgBufferWriter +{ + template + void append_raw(const T& value) + { + const auto* bytes = reinterpret_cast(&value); + storage.insert(storage.end(), bytes, bytes + sizeof(T)); + } + + template + void append_ptr(PtrT ptr, bool compact) + { + append_raw(ptr); + if(!compact) + { + append_raw(uint32_t{0}); + append_raw(uint32_t{0}); + } + } + + void append_u32(uint32_t value, bool compact) + { + append_raw(value); + if(!compact) + { + append_raw(uint32_t{0}); + append_raw(uint32_t{0}); + append_raw(uint32_t{0}); + } + } + + std::vector storage; +}; + +struct fmha_bwd_odo_logical_args +{ + const void* ptr_o; + const void* ptr_do; + void* ptr_d; + uint32_t Hs_o; + uint32_t BAs_o; + uint32_t Seqs_o; + uint32_t Hs_do; + uint32_t BAs_do; + uint32_t Seqs_do; + uint32_t Hs_d; + uint32_t BAs_d; + uint32_t Seqs_d; + uint32_t seqlen_q; + uint32_t head_dim; + const void* ptr_qseq; + const void* ptr_qseq_padded; +}; + +bool use_compact_fmha_bwd_kernel_args(const std::string& arch_id) +{ + return arch_id == "gfx1250"; +} + +fmha_bwd_odo_logical_args make_fmha_bwd_odo_logical_args(const mha_bwd_args& a) +{ + return { + a.o_ptr, + a.do_ptr, + a.d_ptr, + static_cast(a.nhead_stride_o * 2), + static_cast(a.batch_stride_o * 2), + static_cast(a.stride_o * 2), + static_cast(a.nhead_stride_do * 2), + static_cast(a.batch_stride_do * 2), + static_cast(a.stride_do * 2), + static_cast(a.nhead_stride_lsed * 4), + static_cast(a.batch_stride_lsed * 4), + 1u * 4u, + static_cast(a.seqlen_q), + static_cast(a.hdim_q), + (a.cu_seqlen_q_ptr && a.seqstart_q_ptr) ? a.cu_seqlen_q_ptr : a.seqstart_q_ptr, + a.seqstart_q_ptr, + }; +} + +std::vector pack_fmha_bwd_odo_args(const fmha_bwd_odo_logical_args& args, bool compact) +{ + KernelArgBufferWriter writer; + writer.append_ptr(args.ptr_o, compact); + writer.append_ptr(args.ptr_do, compact); + writer.append_ptr(args.ptr_d, compact); + writer.append_u32(args.Hs_o, compact); + writer.append_u32(args.BAs_o, compact); + writer.append_u32(args.Seqs_o, compact); + writer.append_u32(args.Hs_do, compact); + writer.append_u32(args.BAs_do, compact); + writer.append_u32(args.Seqs_do, compact); + writer.append_u32(args.Hs_d, compact); + writer.append_u32(args.BAs_d, compact); + writer.append_u32(args.Seqs_d, compact); + writer.append_u32(args.seqlen_q, compact); + writer.append_u32(args.head_dim, compact); + writer.append_ptr(args.ptr_qseq, compact); + writer.append_ptr(args.ptr_qseq_padded, compact); + return writer.storage; +} + +} // namespace + std::tuple get_padded_hdim(int hdim_q, int hdim_v, std::string arch_id) { if(hdim_q == 192 && hdim_v == 128 && arch_id == "gfx950") @@ -257,7 +367,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) std::string arch_id = get_gpu_arch(); if((!a.use_asm_v3) || (a.hdim_q % 8 != 0) || (a.hdim_v % 8 != 0) || (a.has_dbias) || (a.bias_type != 0) || (a.has_dropout) || (a.is_deterministic) || - ((arch_id != "gfx942") && (arch_id != "gfx950"))) + ((arch_id != "gfx942") && (arch_id != "gfx950") && (arch_id != "gfx1250"))) { return -1; } @@ -311,7 +421,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) return &cfg_fmha_bwd_dq_shuffle; } } - else + else if (arch_id == "gfx942") { if(a.v3_atomic_fp32) { @@ -321,11 +431,13 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) { return static_cast(nullptr); } + } else { + return &cfg_fmha_bwd_dq_convert; // gfx1250 only support atomic32=1 } }(); bool need_post_processing = - ((arch_id == "gfx950") && (a.hdim_q != 64)) || (a.v3_atomic_fp32 == 1); + ((arch_id == "gfx950") && (a.hdim_q != 64)) || (a.v3_atomic_fp32 == 1) || (arch_id == "gfx1250"); int mt = asm_mask_type(); @@ -367,6 +479,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) int ts_kv; int ts_dq; size_t arg_size; + const bool compact_odo_args = use_compact_fmha_bwd_kernel_args(arch_id); AiterAsmKernel* impl_ptr_pre = nullptr; AiterAsmKernel* impl_ptr_dqdkdv = nullptr; @@ -427,34 +540,18 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) if(a.v3_api_check) return 1; - fmha_bwd_odo_args odo_args; - - odo_args.ptr_o = a.o_ptr; - odo_args.ptr_do = a.do_ptr; - odo_args.ptr_d = a.d_ptr; - odo_args.Hs_o = a.nhead_stride_o * 2; - odo_args.BAs_o = a.batch_stride_o * 2; - odo_args.Seqs_o = a.stride_o * 2; - odo_args.Hs_do = a.nhead_stride_do * 2; - odo_args.BAs_do = a.batch_stride_do * 2; - odo_args.Seqs_do = a.stride_do * 2; - odo_args.Hs_d = a.nhead_stride_lsed * 4; - odo_args.BAs_d = a.batch_stride_lsed * 4; - odo_args.Seqs_d = 1 * 4; - odo_args.seqlen_q = a.seqlen_q; - odo_args.head_dim = a.hdim_q; - odo_args.ptr_qseq_padded = a.seqstart_q_ptr; - odo_args.ptr_qseq = - (a.cu_seqlen_q_ptr && a.seqstart_q_ptr) ? a.cu_seqlen_q_ptr : a.seqstart_q_ptr; + const auto odo_args = make_fmha_bwd_odo_logical_args(a); + auto odo_arg_storage = pack_fmha_bwd_odo_args(odo_args, compact_odo_args); auto pre_kernel_launch = [&]() { - arg_size = sizeof(odo_args); - int bdx = 256; + arg_size = odo_arg_storage.size(); + int bdx = (arch_id == "gfx1250") ? 128 : 256; int gdx = (a.max_seqlen_q + ts_odo - 1) / ts_odo; int gdy = a.nhead_q; int gdz = a.batch; - impl_ptr_pre->launch_kernel({&odo_args, &arg_size, gdx, gdy, gdz, bdx, 1, 1, s.stream_id_}); + impl_ptr_pre->launch_kernel( + {odo_arg_storage.data(), &arg_size, gdx, gdy, gdz, bdx, 1, 1, s.stream_id_}); }; fmha_bwd_dqdkdv_args dqdkdv_args; @@ -546,13 +643,13 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) auto dqdkdv_kernel_launch = [&]() { arg_size = sizeof(dqdkdv_args); - int bdx = 256; + int bdx = (arch_id == "gfx1250") ? 128 : 256; int gdx = (a.max_seqlen_k + ts_kv - 1) / ts_kv; int gdy = a.nhead_q; int gdz = a.batch; if((mt == 1) || (mt == 2)) - { // causal + { // mask kb gdx = (gdx + 1) / 2; } @@ -586,8 +683,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) (a.cu_seqlen_q_ptr && a.seqstart_q_ptr) ? a.cu_seqlen_q_ptr : a.seqstart_q_ptr; auto post_kernel_launch = [&]() { - arg_size = sizeof(post_args); - int bdx = 256; + arg_size = sizeof(post_args); + int bdx = (arch_id == "gfx1250") ? 128 : 256; int gdx = (a.max_seqlen_q + ts_dq - 1) / ts_dq; int gdy = a.nhead_q; int gdz = a.batch; diff --git a/csrc/include/mha_bwd.h b/csrc/include/mha_bwd.h index 0afaeadd22..9652db34eb 100644 --- a/csrc/include/mha_bwd.h +++ b/csrc/include/mha_bwd.h @@ -247,42 +247,6 @@ struct __attribute__((packed)) fmha_bwd_dqdkdv_args p3 _p43; }; -struct __attribute__((packed)) fmha_bwd_odo_args -{ - const void* ptr_o; - p2 _p0; - const void* ptr_do; - p2 _p1; - void* ptr_d; - p2 _p2; - unsigned int Hs_o; - p3 _p3; - unsigned int BAs_o; - p3 _p4; - unsigned int Seqs_o; - p3 _p5; - unsigned int Hs_do; - p3 _p6; - unsigned int BAs_do; - p3 _p7; - unsigned int Seqs_do; - p3 _p8; - unsigned int Hs_d; - p3 _p9; - unsigned int BAs_d; - p3 _p10; - unsigned int Seqs_d; - p3 _p11; - unsigned int seqlen_q; - p3 _p12; - unsigned int head_dim; - p3 _p13; - const void* ptr_qseq; - p2 _p14; - const void* ptr_qseq_padded; - p2 _p15; -}; - // dq_shuffle & dq_convert post process kernel args struct __attribute__((packed)) fmha_bwd_post_kernel_args { diff --git a/hsa/gfx1250/fmha_v3_bwd/bwd_hd128_bf16_a32_pssk.co b/hsa/gfx1250/fmha_v3_bwd/bwd_hd128_bf16_a32_pssk.co new file mode 100755 index 0000000000..beda8958a9 Binary files /dev/null and b/hsa/gfx1250/fmha_v3_bwd/bwd_hd128_bf16_a32_pssk.co differ diff --git a/hsa/gfx1250/fmha_v3_bwd/bwd_hd128_bf16_causal_br_a32_pssk.co b/hsa/gfx1250/fmha_v3_bwd/bwd_hd128_bf16_causal_br_a32_pssk.co new file mode 100755 index 0000000000..bc5e5e9711 Binary files /dev/null and b/hsa/gfx1250/fmha_v3_bwd/bwd_hd128_bf16_causal_br_a32_pssk.co differ diff --git a/hsa/gfx1250/fmha_v3_bwd/bwd_hd128_dq_convert_bf16.co b/hsa/gfx1250/fmha_v3_bwd/bwd_hd128_dq_convert_bf16.co new file mode 100755 index 0000000000..9d9fa68b5b Binary files /dev/null and b/hsa/gfx1250/fmha_v3_bwd/bwd_hd128_dq_convert_bf16.co differ diff --git a/hsa/gfx1250/fmha_v3_bwd/bwd_hd128_odo_bf16.co b/hsa/gfx1250/fmha_v3_bwd/bwd_hd128_odo_bf16.co new file mode 100755 index 0000000000..7d6b3d8fe1 Binary files /dev/null and b/hsa/gfx1250/fmha_v3_bwd/bwd_hd128_odo_bf16.co differ diff --git a/hsa/gfx1250/fmha_v3_bwd/fmha_bwd_dq_convert.csv b/hsa/gfx1250/fmha_v3_bwd/fmha_bwd_dq_convert.csv new file mode 100644 index 0000000000..d6368def22 --- /dev/null +++ b/hsa/gfx1250/fmha_v3_bwd/fmha_bwd_dq_convert.csv @@ -0,0 +1,2 @@ +dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name +bf16,128,128,0,0,0,0,0,3,0,64,_ZN5aiter30fmha_bwd_hd128_dq_convert_bf16E,bwd_hd128_dq_convert_bf16.co diff --git a/hsa/gfx1250/fmha_v3_bwd/fmha_bwd_dqdkdv.csv b/hsa/gfx1250/fmha_v3_bwd/fmha_bwd_dqdkdv.csv new file mode 100644 index 0000000000..2ff9e1eafd --- /dev/null +++ b/hsa/gfx1250/fmha_v3_bwd/fmha_bwd_dqdkdv.csv @@ -0,0 +1,3 @@ +dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name +bf16,128,128,0,1,0,0,0,3,32,128,_ZN5aiter28fmha_bwd_hd128_bf16_a32_psskE,bwd_hd128_bf16_a32_pssk.co +bf16,128,128,2,1,0,0,0,3,32,128,_ZN5aiter38fmha_bwd_hd128_bf16_causal_br_a32_psskE,bwd_hd128_bf16_causal_br_a32_pssk.co \ No newline at end of file diff --git a/hsa/gfx1250/fmha_v3_bwd/fmha_bwd_odo.csv b/hsa/gfx1250/fmha_v3_bwd/fmha_bwd_odo.csv new file mode 100644 index 0000000000..b36a5f6235 --- /dev/null +++ b/hsa/gfx1250/fmha_v3_bwd/fmha_bwd_odo.csv @@ -0,0 +1,2 @@ +dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name +bf16,128,128,0,0,0,0,0,3,0,128,_ZN5aiter23fmha_bwd_hd128_odo_bf16E,bwd_hd128_odo_bf16.co