Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 126 additions & 29 deletions csrc/cpp_itfs/mha_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,120 @@
#include "mha_bwd.h"
#include "aiter_hip_common.h"
#include "asm_fmha_v3_bwd_configs.hpp"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

namespace aiter {
namespace {

struct KernelArgBufferWriter
{
template <typename T>
void append_raw(const T& value)
{
const auto* bytes = reinterpret_cast<const std::byte*>(&value);
storage.insert(storage.end(), bytes, bytes + sizeof(T));
}

template <typename PtrT>
void append_ptr(PtrT ptr, bool compact)
{
append_raw(ptr);
if(!compact)
{
append_raw(uint32_t{0});
append_raw(uint32_t{0});
}
}

void append_u32(uint32_t value, bool compact)
{
append_raw(value);
if(!compact)
{
append_raw(uint32_t{0});
append_raw(uint32_t{0});
append_raw(uint32_t{0});
}
}

std::vector<std::byte> storage;
};

struct fmha_bwd_odo_logical_args
{
const void* ptr_o;
const void* ptr_do;
void* ptr_d;
uint32_t Hs_o;
uint32_t BAs_o;
uint32_t Seqs_o;
uint32_t Hs_do;
uint32_t BAs_do;
uint32_t Seqs_do;
uint32_t Hs_d;
uint32_t BAs_d;
uint32_t Seqs_d;
uint32_t seqlen_q;
uint32_t head_dim;
const void* ptr_qseq;
const void* ptr_qseq_padded;
};

bool use_compact_fmha_bwd_kernel_args(const std::string& arch_id)
{
return arch_id == "gfx1250";
}

fmha_bwd_odo_logical_args make_fmha_bwd_odo_logical_args(const mha_bwd_args& a)
{
return {
a.o_ptr,
a.do_ptr,
a.d_ptr,
static_cast<uint32_t>(a.nhead_stride_o * 2),
static_cast<uint32_t>(a.batch_stride_o * 2),
static_cast<uint32_t>(a.stride_o * 2),
static_cast<uint32_t>(a.nhead_stride_do * 2),
static_cast<uint32_t>(a.batch_stride_do * 2),
static_cast<uint32_t>(a.stride_do * 2),
static_cast<uint32_t>(a.nhead_stride_lsed * 4),
static_cast<uint32_t>(a.batch_stride_lsed * 4),
1u * 4u,
static_cast<uint32_t>(a.seqlen_q),
static_cast<uint32_t>(a.hdim_q),
(a.cu_seqlen_q_ptr && a.seqstart_q_ptr) ? a.cu_seqlen_q_ptr : a.seqstart_q_ptr,
a.seqstart_q_ptr,
};
}

std::vector<std::byte> pack_fmha_bwd_odo_args(const fmha_bwd_odo_logical_args& args, bool compact)
{
KernelArgBufferWriter writer;
writer.append_ptr(args.ptr_o, compact);
writer.append_ptr(args.ptr_do, compact);
writer.append_ptr(args.ptr_d, compact);
writer.append_u32(args.Hs_o, compact);
writer.append_u32(args.BAs_o, compact);
writer.append_u32(args.Seqs_o, compact);
writer.append_u32(args.Hs_do, compact);
writer.append_u32(args.BAs_do, compact);
writer.append_u32(args.Seqs_do, compact);
writer.append_u32(args.Hs_d, compact);
writer.append_u32(args.BAs_d, compact);
writer.append_u32(args.Seqs_d, compact);
writer.append_u32(args.seqlen_q, compact);
writer.append_u32(args.head_dim, compact);
writer.append_ptr(args.ptr_qseq, compact);
writer.append_ptr(args.ptr_qseq_padded, compact);
return writer.storage;
}

} // namespace

std::tuple<int, int> get_padded_hdim(int hdim_q, int hdim_v, std::string arch_id)
{
if(hdim_q == 192 && hdim_v == 128 && arch_id == "gfx950")
Expand Down Expand Up @@ -257,7 +367,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
std::string arch_id = get_gpu_arch();
if((!a.use_asm_v3) || (a.hdim_q % 8 != 0) || (a.hdim_v % 8 != 0) || (a.has_dbias) ||
(a.bias_type != 0) || (a.has_dropout) || (a.is_deterministic) ||
((arch_id != "gfx942") && (arch_id != "gfx950")))
((arch_id != "gfx942") && (arch_id != "gfx950") && (arch_id != "gfx1250")))
{
return -1;
}
Expand Down Expand Up @@ -311,7 +421,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
return &cfg_fmha_bwd_dq_shuffle;
}
}
else
else if (arch_id == "gfx942")
{
if(a.v3_atomic_fp32)
{
Expand All @@ -321,11 +431,13 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
{
return static_cast<CFG*>(nullptr);
}
} else {
return &cfg_fmha_bwd_dq_convert; // gfx1250 only support atomic32=1
}
Comment on lines 424 to 436
}();

bool need_post_processing =
((arch_id == "gfx950") && (a.hdim_q != 64)) || (a.v3_atomic_fp32 == 1);
((arch_id == "gfx950") && (a.hdim_q != 64)) || (a.v3_atomic_fp32 == 1) || (arch_id == "gfx1250");

int mt = asm_mask_type();

Expand Down Expand Up @@ -367,6 +479,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
int ts_kv;
int ts_dq;
size_t arg_size;
const bool compact_odo_args = use_compact_fmha_bwd_kernel_args(arch_id);

AiterAsmKernel* impl_ptr_pre = nullptr;
AiterAsmKernel* impl_ptr_dqdkdv = nullptr;
Expand Down Expand Up @@ -427,34 +540,18 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
if(a.v3_api_check)
return 1;

fmha_bwd_odo_args odo_args;

odo_args.ptr_o = a.o_ptr;
odo_args.ptr_do = a.do_ptr;
odo_args.ptr_d = a.d_ptr;
odo_args.Hs_o = a.nhead_stride_o * 2;
odo_args.BAs_o = a.batch_stride_o * 2;
odo_args.Seqs_o = a.stride_o * 2;
odo_args.Hs_do = a.nhead_stride_do * 2;
odo_args.BAs_do = a.batch_stride_do * 2;
odo_args.Seqs_do = a.stride_do * 2;
odo_args.Hs_d = a.nhead_stride_lsed * 4;
odo_args.BAs_d = a.batch_stride_lsed * 4;
odo_args.Seqs_d = 1 * 4;
odo_args.seqlen_q = a.seqlen_q;
odo_args.head_dim = a.hdim_q;
odo_args.ptr_qseq_padded = a.seqstart_q_ptr;
odo_args.ptr_qseq =
(a.cu_seqlen_q_ptr && a.seqstart_q_ptr) ? a.cu_seqlen_q_ptr : a.seqstart_q_ptr;
const auto odo_args = make_fmha_bwd_odo_logical_args(a);
auto odo_arg_storage = pack_fmha_bwd_odo_args(odo_args, compact_odo_args);

auto pre_kernel_launch = [&]() {
arg_size = sizeof(odo_args);
int bdx = 256;
arg_size = odo_arg_storage.size();
int bdx = (arch_id == "gfx1250") ? 128 : 256;
int gdx = (a.max_seqlen_q + ts_odo - 1) / ts_odo;
int gdy = a.nhead_q;
int gdz = a.batch;

impl_ptr_pre->launch_kernel({&odo_args, &arg_size, gdx, gdy, gdz, bdx, 1, 1, s.stream_id_});
impl_ptr_pre->launch_kernel(
{odo_arg_storage.data(), &arg_size, gdx, gdy, gdz, bdx, 1, 1, s.stream_id_});
};

fmha_bwd_dqdkdv_args dqdkdv_args;
Expand Down Expand Up @@ -546,13 +643,13 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)

auto dqdkdv_kernel_launch = [&]() {
arg_size = sizeof(dqdkdv_args);
int bdx = 256;
int bdx = (arch_id == "gfx1250") ? 128 : 256;
int gdx = (a.max_seqlen_k + ts_kv - 1) / ts_kv;
int gdy = a.nhead_q;
int gdz = a.batch;

if((mt == 1) || (mt == 2))
{ // causal
{ // mask kb
gdx = (gdx + 1) / 2;
}

Expand Down Expand Up @@ -586,8 +683,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
(a.cu_seqlen_q_ptr && a.seqstart_q_ptr) ? a.cu_seqlen_q_ptr : a.seqstart_q_ptr;

auto post_kernel_launch = [&]() {
arg_size = sizeof(post_args);
int bdx = 256;
arg_size = sizeof(post_args);
int bdx = (arch_id == "gfx1250") ? 128 : 256;
int gdx = (a.max_seqlen_q + ts_dq - 1) / ts_dq;
int gdy = a.nhead_q;
int gdz = a.batch;
Expand Down
36 changes: 0 additions & 36 deletions csrc/include/mha_bwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,42 +247,6 @@ struct __attribute__((packed)) fmha_bwd_dqdkdv_args
p3 _p43;
};

struct __attribute__((packed)) fmha_bwd_odo_args
{
const void* ptr_o;
p2 _p0;
const void* ptr_do;
p2 _p1;
void* ptr_d;
p2 _p2;
unsigned int Hs_o;
p3 _p3;
unsigned int BAs_o;
p3 _p4;
unsigned int Seqs_o;
p3 _p5;
unsigned int Hs_do;
p3 _p6;
unsigned int BAs_do;
p3 _p7;
unsigned int Seqs_do;
p3 _p8;
unsigned int Hs_d;
p3 _p9;
unsigned int BAs_d;
p3 _p10;
unsigned int Seqs_d;
p3 _p11;
unsigned int seqlen_q;
p3 _p12;
unsigned int head_dim;
p3 _p13;
const void* ptr_qseq;
p2 _p14;
const void* ptr_qseq_padded;
p2 _p15;
};

// dq_shuffle & dq_convert post process kernel args
struct __attribute__((packed)) fmha_bwd_post_kernel_args
{
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added hsa/gfx1250/fmha_v3_bwd/bwd_hd128_odo_bf16.co
Binary file not shown.
2 changes: 2 additions & 0 deletions hsa/gfx1250/fmha_v3_bwd/fmha_bwd_dq_convert.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name
bf16,128,128,0,0,0,0,0,3,0,64,_ZN5aiter30fmha_bwd_hd128_dq_convert_bf16E,bwd_hd128_dq_convert_bf16.co
3 changes: 3 additions & 0 deletions hsa/gfx1250/fmha_v3_bwd/fmha_bwd_dqdkdv.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name
bf16,128,128,0,1,0,0,0,3,32,128,_ZN5aiter28fmha_bwd_hd128_bf16_a32_psskE,bwd_hd128_bf16_a32_pssk.co
bf16,128,128,2,1,0,0,0,3,32,128,_ZN5aiter38fmha_bwd_hd128_bf16_causal_br_a32_psskE,bwd_hd128_bf16_causal_br_a32_pssk.co
2 changes: 2 additions & 0 deletions hsa/gfx1250/fmha_v3_bwd/fmha_bwd_odo.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
dtype,hdim_q,hdim_v,mask,atomic32,pssk,pddv,mode,bf16_cvt,ts_qo,ts,knl_name,co_name
bf16,128,128,0,0,0,0,0,3,0,128,_ZN5aiter23fmha_bwd_hd128_odo_bf16E,bwd_hd128_odo_bf16.co