diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 977aff62d81..e0f17eef8bb 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -426,10 +426,12 @@ struct vk_fa_pipeline_state { bool f32acc; uint32_t flags; uint32_t limit_occupancy_shmem; + ggml_type k_type; + ggml_type v_type; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) < - std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem); + return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem, k_type, v_type) < + std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem, b.k_type, b.v_type); } }; @@ -2967,7 +2969,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device return result; } -static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { GGML_UNUSED(n_kv); GGML_UNUSED(f32acc); @@ -2981,7 +2983,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device if (small_rows) { result.block_rows = 32; result.block_cols = 32; - } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) { + } else if (ggml_is_quantized(k_type) || ggml_is_quantized(v_type) || hsk >= 256 || hsv >= 256) { 
result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64; result.block_cols = 32; } else { @@ -2995,7 +2997,13 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device return result; } -static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { + // Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1. + if (k_type != v_type) { + GGML_ASSERT(device->coopmat2); + return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); + } + FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; @@ -3007,7 +3015,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ if (path == FA_COOPMAT1) { bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || (!f32acc && device->coopmat_support_16x16x16_f16acc); - const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); if (!shape_ok || !shmem_ok) { @@ -3020,20 +3028,25 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ path = FA_SCALAR; } + // Q1_0 K/V is only implemented on coopmat2 (flash_attn_cm2); there is no scalar FA shader for it. 
+ if ((k_type == GGML_TYPE_Q1_0 || v_type == GGML_TYPE_Q1_0) && device->coopmat2) { + path = FA_COOPMAT2; + } + switch (path) { case FA_SCALAR: - return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); case FA_COOPMAT1: - return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); case FA_COOPMAT2: - return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); default: throw std::runtime_error("unsupported FaCodePath"); } } static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc, - bool use_mask, bool use_mask_opt, bool use_logit_softcap) { + bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type) { const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary && (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2); @@ -3044,12 +3057,32 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const const uint32_t subgroup_size = params.disable_subgroups ? 
0 : params.subgroup_size; - return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem}; + return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem, k_type, v_type}; } static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) { - return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split, - state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem}; + const auto fa_block_bytes = [](ggml_type t) -> uint32_t { + // decodeBufF32 uses a block of vec4s for a better memory access pattern. + return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t); + }; + return { + /* 0 WorkGroupSize */ state.workgroup_size, + /* 1 Br */ state.Br, + /* 2 Bc */ state.Bc, + /* 3 HSK */ state.HSK, + /* 4 HSV */ state.HSV, + /* 5 Clamp */ static_cast<uint32_t>(!state.aligned), + /* 6 D_split */ state.D_split, + /* 7 row_split */ state.row_split, + /* 8 SubGroupSize */ state.subgroup_size, + /* 9 SHMEM_STAGING */ state.shmem_staging ? 
1u : 0u, + /*10 Flags */ state.flags, + /*11 LIMIT_OCCUPANCY_SHMEM */ state.limit_occupancy_shmem, + /*12 FaTypeK */ static_cast<uint32_t>(state.k_type), + /*13 FaTypeV */ static_cast<uint32_t>(state.v_type), + /*14 FaBlockBytesK */ fa_block_bytes(state.k_type), + /*15 FaBlockBytesV */ fa_block_bytes(state.v_type), + }; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -3474,16 +3507,35 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +#define CREATE_FA_CM2_MIXED() \ + for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \ + for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \ + FaCodePath path = fa.first.path; \ + uint32_t Br = fa.first.Br; \ + uint32_t Bc = fa.first.Bc; \ + bool aligned = fa.first.aligned; \ + bool f32acc = fa.first.f32acc; \ + if (path == FA_COOPMAT2) { \ + if (aligned) { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ + } \ + } else { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, 
"flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ + } \ + } \ + } \ + } \ + } if (device->coopmat2) { - CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) + CREATE_FA_CM2_MIXED(); } +#undef CREATE_FA_CM2_MIXED #endif #undef CREATE_FA @@ -8902,8 +8954,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(dst->type == GGML_TYPE_F32); assert(q->type == GGML_TYPE_F32); - assert(k->type == v->type); - uint32_t gqa_ratio = 1; uint32_t qk_ratio = neq2 / nek2; uint32_t workgroups_x = (uint32_t)neq1; @@ -8914,7 +8964,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). 
- vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc); + vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc); const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u); if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && @@ -8927,7 +8977,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_y /= gqa_ratio; } - tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc); + tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc); + + if (tuning_params.path != FA_COOPMAT2) { + GGML_ASSERT(k->type == v->type); + } const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); @@ -8966,7 +9020,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16; vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc, - mask != nullptr, use_mask_opt, logit_softcap != 0); + mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type); vk_pipeline pipeline = nullptr; @@ -15366,38 +15420,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } - // It's straightforward to support different K/V dequant, but would - // significantly increase the number of pipelines - if (op->src[1]->type != op->src[2]->type) { + // mismatching K/V type is currently supported for coopmat2 only. 
+ if (op->src[1]->type != op->src[2]->type && !coopmat2) { return false; } - switch (op->src[1]->type) { - case GGML_TYPE_F16: - case GGML_TYPE_F32: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_IQ4_NL: - // supported in scalar and coopmat2 paths - break; - // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently - //case GGML_TYPE_Q2_K: - //case GGML_TYPE_Q3_K: - //case GGML_TYPE_Q4_K: - //case GGML_TYPE_Q5_K: - //case GGML_TYPE_Q6_K: - //case GGML_TYPE_IQ1_S: - //case GGML_TYPE_IQ1_M: - //case GGML_TYPE_IQ2_XXS: - //case GGML_TYPE_IQ2_XS: - //case GGML_TYPE_IQ2_S: - //case GGML_TYPE_IQ3_XXS: - //case GGML_TYPE_IQ3_S: - //case GGML_TYPE_IQ4_XS: - - default: + auto fa_kv_ok = [coopmat2](ggml_type t) { + switch (t) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + return true; + case GGML_TYPE_Q1_0: + return coopmat2; + default: + return false; + } + }; + if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) { return false; } if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index b30dee86871..f3999603b76 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -13,6 +13,12 @@ layout (constant_id = 8) const uint32_t SubGroupSize = 32; layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0; layout (constant_id = 10) const uint32_t Flags = 0; layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0; +// ggml_type enumerant for K/V +layout (constant_id = 12) const uint32_t FaTypeK = 0; +layout (constant_id = 13) const uint32_t FaTypeV = 0; +// sizeof(decode buffer): quants -> ggml 
block size; F32 -> 16 (decodeBufF32 vec4). +layout (constant_id = 14) const uint32_t FaBlockBytesK = 2; +layout (constant_id = 15) const uint32_t FaBlockBytesV = 2; const bool USE_MASK_OPT = (Flags & 1) != 0; const bool MASK_ENABLE = (Flags & 2) != 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 0ea181342ce..8a7bbaeb92c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -17,8 +17,57 @@ #extension GL_EXT_null_initializer : enable #include "types.glsl" -#include "dequant_funcs_cm2.glsl" #include "flash_attn_base.glsl" +#include "dequant_funcs_cm2.glsl" + +// buffer_reference stride = sizeof(struct) = FaBlockBytesK/V. +layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K { + uint8_t raw[FaBlockBytesK]; +}; +layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_V { + uint8_t raw[FaBlockBytesV]; +}; + +uint fa_block_elems(uint ty) { + switch (ty) { + case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32) + case 1u: return 1u; // GGML_TYPE_F16 + case 2u: return uint(QUANT_K_Q4_0); + case 3u: return uint(QUANT_K_Q4_1); + case 6u: return uint(QUANT_K_Q5_0); + case 7u: return uint(QUANT_K_Q5_1); + case 8u: return uint(QUANT_K_Q8_0); + case 41u: return uint(QUANT_K_Q1_0); + default: + return 1u; + } +} + +float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeK) { + case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); + case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), 
blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return float16_t(0); + } +} + +float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeV) { + case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); + case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return float16_t(0); + } +} layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; @@ -55,12 +104,6 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele return max(elem0, elem1); } -#if BLOCK_SIZE > 1 -#define DECODEFUNC , DEQUANTFUNC -#else -#define DECODEFUNC -#endif - // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. 
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) @@ -95,10 +138,6 @@ ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c } void main() { -#ifdef NEEDS_INIT_IQ_SHMEM - init_iq_shmem(gl_WorkGroupSize); -#endif - init_indices(); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); @@ -107,10 +146,10 @@ void main() { tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); -#if BLOCK_SIZE > 1 - tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); - tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); -#endif + const uint bs_k = fa_block_elems(FaTypeK); + const uint bs_v = fa_block_elems(FaTypeV); + tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, bs_k); + tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, bs_v); tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK); tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK); @@ -120,10 +159,12 @@ void main() { if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { q_stride &= ~7; -#if BLOCK_SIZE == 1 - k_stride &= ~7; - v_stride &= ~7; -#endif + if (bs_k == 1u) { + k_stride &= ~7; + } + if (bs_v == 1u) { + v_stride &= ~7; + } m_stride &= ~7; } tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); @@ -230,7 +271,13 @@ void main() { coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); + // F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128. 
+ const bool k_use_decode = (bs_k > 1u); + if (k_use_decode) { + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK); + } else { + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); + } S = coopMatMulAdd(Qf16, K_T, S); if (LOGIT_SOFTCAP) { @@ -291,7 +338,12 @@ void main() { coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC); + const bool v_use_decode = (bs_v > 1u); + if (v_use_decode) { + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV); + } else { + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); + } L = eM*L + rowsum; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 77a55ea812b..04d8ccd7da1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -639,20 +639,17 @@ void process_shaders() { fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; } + if (fp16) { +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); +#endif + } + for (const auto& tname : type_names) { if (tname == "bf16") continue; if (fp16) { -#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); - } else { - std::string data_a_key = "DATA_A_" + 
to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc); - } -#endif #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e2efae94bcf..20a681eb980 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6151,11 +6151,12 @@ struct test_flash_attn_ext : public test_case { const float logit_softcap; // Gemma 2 const ggml_prec prec; - const ggml_type type_KV; + const ggml_type type_K; + const ggml_type type_V; std::array<int32_t, 4> permute; std::string vars() override { - return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute); + return VARS_TO_STR14(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_K, type_V, permute); } double max_nmse_err() override { @@ -6171,12 +6172,13 @@ struct test_flash_attn_ext : public test_case { test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8, bool mask = true, bool sinks = false, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32, - ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3}) - : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {} + ggml_type type_K = GGML_TYPE_F16, ggml_type type_V = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3}) + : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), 
logit_softcap(logit_softcap), prec(prec), + type_K(type_K), type_V(type_V), permute(permute) {} ggml_tensor * build_graph(ggml_context * ctx) override { - const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV)); - const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV)); + const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_K)); + const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_V)); auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view) -> ggml_tensor * { int64_t ne[4] = {ne0, ne1, ne2, ne3}; @@ -6200,11 +6202,11 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1], false); ggml_set_name(q, "q"); - ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache + ggml_tensor * k = create_permuted(type_K, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache ggml_set_name(k, "k"); ggml_tensor * v = nullptr; - if (hsk_padded == 576 && hsv_padded == 512) { + if (type_K == type_V && hsk_padded == 576 && hsv_padded == 512) { // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes // in this branch, the V cache is sub-view of the K cache. 
this is used by some MLA-based models @@ -6214,7 +6216,7 @@ struct test_flash_attn_ext : public test_case { // - https://github.com/ggml-org/llama.cpp/pull/18986 v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0); } else { - v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache + v = create_permuted(type_V, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache } ggml_set_name(v, "v"); @@ -8582,11 +8584,11 @@ static std::vector> make_test_cases_eval() { for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue; test_cases.emplace_back(new test_flash_attn_ext( - hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV)); + hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, type_KV)); // run fewer test cases permuted if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) { test_cases.emplace_back(new test_flash_attn_ext( - hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3})); + hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, type_KV, {0, 2, 1, 3})); } } } @@ -8602,6 +8604,16 @@ static std::vector> make_test_cases_eval() { } } + // mixed quant and Q1_0 test cases + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 4, {1, 1}, 128, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)); + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 4, {1, 1}, 128, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q4_0, GGML_TYPE_F16)); + test_cases.emplace_back(new test_flash_attn_ext(72, 72, 4, {1, 1}, 96, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)); + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 4, {1, 1}, 96, 2, 
true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16, GGML_TYPE_F32)); + test_cases.emplace_back(new test_flash_attn_ext(128, 128, 4, {1, 1}, 96, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q1_0, GGML_TYPE_Q1_0)); + test_cases.emplace_back(new test_flash_attn_ext(128, 64, 4, {1, 1}, 128, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q1_0, GGML_TYPE_Q4_0)); + test_cases.emplace_back(new test_flash_attn_ext(64, 128, 4, {1, 1}, 128, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q4_0, GGML_TYPE_Q1_0)); + test_cases.emplace_back(new test_flash_attn_ext(128, 64, 4, {1, 1}, 64, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q1_0, GGML_TYPE_F16)); + test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10, 5, 4, 3})); test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000, 1, 1, 1})); test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, { 10, 5, 4, 3})); @@ -8846,15 +8858,19 @@ static std::vector> make_test_cases_perf() { } // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012 - test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16, GGML_TYPE_F16)); - test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); - test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 4, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16, GGML_TYPE_F16)); + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 4, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16, GGML_TYPE_F16)); + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, 
GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)); + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 512, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)); + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)); + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 512, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)); for (int kv : { 4096, 8192, 16384, }) { for (int hs : { 64, 128, }) { for (int nr : { 1, 4, }) { - test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16, GGML_TYPE_F16)); } } }