Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 99 additions & 56 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,12 @@ struct vk_fa_pipeline_state {
bool f32acc;
uint32_t flags;
uint32_t limit_occupancy_shmem;
ggml_type k_type;
ggml_type v_type;

bool operator<(const vk_fa_pipeline_state &b) const {
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem, k_type, v_type) <
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem, b.k_type, b.v_type);
}
};

Expand Down Expand Up @@ -2967,7 +2969,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device
return result;
}

static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
GGML_UNUSED(n_kv);
GGML_UNUSED(f32acc);

Expand All @@ -2981,7 +2983,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
if (small_rows) {
result.block_rows = 32;
result.block_cols = 32;
} else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
} else if (ggml_is_quantized(k_type) || ggml_is_quantized(v_type) || hsk >= 256 || hsv >= 256) {
result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
result.block_cols = 32;
} else {
Expand All @@ -2995,7 +2997,13 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
return result;
}

static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
// Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1.
if (k_type != v_type) {
GGML_ASSERT(device->coopmat2);
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
}

FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;

Expand All @@ -3007,7 +3015,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
if (path == FA_COOPMAT1) {
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
(!f32acc && device->coopmat_support_16x16x16_f16acc);
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);

if (!shape_ok || !shmem_ok) {
Expand All @@ -3020,20 +3028,25 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
path = FA_SCALAR;
}

// Q1_0 K/V is only implemented on coopmat2 (flash_attn_cm2); there is no scalar FA shader for it.
if ((k_type == GGML_TYPE_Q1_0 || v_type == GGML_TYPE_Q1_0) && device->coopmat2) {
path = FA_COOPMAT2;
}

switch (path) {
case FA_SCALAR:
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
case FA_COOPMAT1:
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
case FA_COOPMAT2:
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
default:
throw std::runtime_error("unsupported FaCodePath");
}
}

static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type) {
const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
(device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);

Expand All @@ -3044,12 +3057,32 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const

const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;

return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem, k_type, v_type};
}

static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
const auto fa_block_bytes = [](ggml_type t) -> uint32_t {
// decodeBufF32 uses a block of vec4s for a better memory access pattern.
return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t);
};
return {
/* 0 WorkGroupSize */ state.workgroup_size,
/* 1 Br */ state.Br,
/* 2 Bc */ state.Bc,
/* 3 HSK */ state.HSK,
/* 4 HSV */ state.HSV,
/* 5 Clamp */ static_cast<uint32_t>(!state.aligned),
/* 6 D_split */ state.D_split,
/* 7 row_split */ state.row_split,
/* 8 SubGroupSize */ state.subgroup_size,
/* 9 SHMEM_STAGING */ state.shmem_staging ? 1u : 0u,
/*10 Flags */ state.flags,
/*11 LIMIT_OCCUPANCY_SHMEM */ state.limit_occupancy_shmem,
/*12 FaTypeK */ static_cast<uint32_t>(state.k_type),
/*13 FaTypeV */ static_cast<uint32_t>(state.v_type),
/*14 FaBlockBytesK */ fa_block_bytes(state.k_type),
/*15 FaBlockBytesV */ fa_block_bytes(state.v_type),
};
}

static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
Expand Down Expand Up @@ -3474,16 +3507,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
#define CREATE_FA_CM2_MIXED() \
for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \
for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \
FaCodePath path = fa.first.path; \
uint32_t Br = fa.first.Br; \
uint32_t Bc = fa.first.Bc; \
bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \
if (path == FA_COOPMAT2) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
} \
} \
} \
} \
}
if (device->coopmat2) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
CREATE_FA_CM2_MIXED();
}
#undef CREATE_FA_CM2_MIXED
#endif
#undef CREATE_FA

Expand Down Expand Up @@ -8902,8 +8954,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

assert(dst->type == GGML_TYPE_F32);
assert(q->type == GGML_TYPE_F32);
assert(k->type == v->type);

uint32_t gqa_ratio = 1;
uint32_t qk_ratio = neq2 / nek2;
uint32_t workgroups_x = (uint32_t)neq1;
Expand All @@ -8914,7 +8964,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc);
const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);

if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
Expand All @@ -8927,7 +8977,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
workgroups_y /= gqa_ratio;
}

tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc);

if (tuning_params.path != FA_COOPMAT2) {
GGML_ASSERT(k->type == v->type);
}

const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
Expand Down Expand Up @@ -8966,7 +9020,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
mask != nullptr, use_mask_opt, logit_softcap != 0);
mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type);

vk_pipeline pipeline = nullptr;

Expand Down Expand Up @@ -15366,38 +15420,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false;
}
// It's straightforward to support different K/V dequant, but would
// significantly increase the number of pipelines
if (op->src[1]->type != op->src[2]->type) {
// Mismatched K/V types are currently supported on the coopmat2 path only.
if (op->src[1]->type != op->src[2]->type && !coopmat2) {
return false;
}
switch (op->src[1]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_IQ4_NL:
// supported in scalar and coopmat2 paths
break;
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
//case GGML_TYPE_Q2_K:
//case GGML_TYPE_Q3_K:
//case GGML_TYPE_Q4_K:
//case GGML_TYPE_Q5_K:
//case GGML_TYPE_Q6_K:
//case GGML_TYPE_IQ1_S:
//case GGML_TYPE_IQ1_M:
//case GGML_TYPE_IQ2_XXS:
//case GGML_TYPE_IQ2_XS:
//case GGML_TYPE_IQ2_S:
//case GGML_TYPE_IQ3_XXS:
//case GGML_TYPE_IQ3_S:
//case GGML_TYPE_IQ4_XS:

default:
auto fa_kv_ok = [coopmat2](ggml_type t) {
switch (t) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
return true;
case GGML_TYPE_Q1_0:
return coopmat2;
default:
return false;
}
};
if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) {
return false;
}
if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
Expand Down
6 changes: 6 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ layout (constant_id = 8) const uint32_t SubGroupSize = 32;
layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0;
layout (constant_id = 10) const uint32_t Flags = 0;
layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
// ggml_type enumerant for K/V
layout (constant_id = 12) const uint32_t FaTypeK = 0;
layout (constant_id = 13) const uint32_t FaTypeV = 0;
// sizeof(decode buffer): quants -> ggml block size; F32 -> 16 (decodeBufF32 vec4).
layout (constant_id = 14) const uint32_t FaBlockBytesK = 2;
layout (constant_id = 15) const uint32_t FaBlockBytesV = 2;

const bool USE_MASK_OPT = (Flags & 1) != 0;
const bool MASK_ENABLE = (Flags & 2) != 0;
Expand Down
Loading
Loading