Skip to content

Commit edd4d9b

Browse files
authored
vulkan: add FA dequant for q4_1, q5_0, q5_1, iq4_nl (ggml-org#21029)
Add dequantize4() implementations for Q4_1, Q5_0, Q5_1, and IQ4_NL in the flash attention base shader. Register them in the shader generator, pipeline creation, and enable in the scalar/coopmat1 FA support check.
1 parent: 482192f · commit: edd4d9b

3 files changed

Lines changed: 118 additions & 12 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 16 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -3447,18 +3447,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
34473447
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
34483448
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
34493449
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
3450+
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
3451+
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
3452+
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
3453+
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
34503454
} else {
34513455
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
34523456
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
34533457
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
34543458
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
3459+
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
3460+
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
3461+
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
3462+
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
34553463
}
34563464
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
34573465
if (device->coopmat1_fa_support) {
34583466
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
34593467
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
34603468
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
34613469
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
3470+
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1)
3471+
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1)
3472+
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1)
3473+
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1)
34623474
}
34633475
#endif
34643476
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -15331,11 +15343,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1533115343
case GGML_TYPE_F32:
1533215344
case GGML_TYPE_Q4_0:
1533315345
case GGML_TYPE_Q8_0:
15334-
// supported in scalar and coopmat2 paths
15335-
break;
1533615346
case GGML_TYPE_Q4_1:
1533715347
case GGML_TYPE_Q5_0:
1533815348
case GGML_TYPE_Q5_1:
15349+
case GGML_TYPE_IQ4_NL:
15350+
// supported in scalar and coopmat2 paths
15351+
break;
1533915352
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
1534015353
//case GGML_TYPE_Q2_K:
1534115354
//case GGML_TYPE_Q3_K:
@@ -15350,12 +15363,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1535015363
//case GGML_TYPE_IQ3_XXS:
1535115364
//case GGML_TYPE_IQ3_S:
1535215365
//case GGML_TYPE_IQ4_XS:
15353-
case GGML_TYPE_IQ4_NL:
15354-
// currently supported only in coopmat2 path
15355-
if (!coopmat2) {
15356-
return false;
15357-
}
15358-
break;
15366+
1535915367
default:
1536015368
return false;
1536115369
}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl

Lines changed: 100 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -110,7 +110,11 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
110110

111111
#if defined(DATA_A_Q4_0)
112112
#define BLOCK_BYTE_SIZE 18
113+
#elif defined(DATA_A_Q4_1)
114+
#define BLOCK_BYTE_SIZE 20
115+
#endif
113116

117+
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
114118
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
115119
if (binding_idx == BINDING_IDX_K) {
116120
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
@@ -119,19 +123,113 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
119123
vui_lo >>= shift;
120124
vui_hi >>= shift;
121125

122-
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
126+
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
127+
#ifdef DATA_A_Q4_1
128+
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m);
129+
#else
130+
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f));
131+
#endif
123132
} else {
124133
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
125134
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
126135
uint shift = (iqs & 0x10) >> 2;
127136
vui_lo >>= shift;
128137
vui_hi >>= shift;
129138

130-
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
139+
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
140+
#ifdef DATA_A_Q4_1
141+
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m);
142+
#else
143+
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f));
144+
#endif
131145
}
132146
}
133147
#endif
134148

149+
#if defined(DATA_A_Q5_0)
150+
#define BLOCK_BYTE_SIZE 22
151+
#elif defined(DATA_A_Q5_1)
152+
#define BLOCK_BYTE_SIZE 24
153+
#endif
154+
155+
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
156+
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
157+
if (binding_idx == BINDING_IDX_K) {
158+
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
159+
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
160+
uint shift = (iqs & 0x10) >> 2;
161+
vui_lo >>= shift;
162+
vui_hi >>= shift;
163+
164+
#ifdef DATA_A_Q5_1
165+
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
166+
#else
167+
uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16);
168+
#endif
169+
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f);
170+
171+
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
172+
#ifdef DATA_A_Q5_1
173+
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m);
174+
#else
175+
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f));
176+
#endif
177+
} else {
178+
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
179+
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
180+
uint shift = (iqs & 0x10) >> 2;
181+
vui_lo >>= shift;
182+
vui_hi >>= shift;
183+
184+
#ifdef DATA_A_Q5_1
185+
uint qh = v_packed.v_data_packed16[a_offset + ib].qh;
186+
#else
187+
uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16);
188+
#endif
189+
FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f);
190+
191+
FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF);
192+
#ifdef DATA_A_Q5_1
193+
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m);
194+
#else
195+
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f));
196+
#endif
197+
}
198+
}
199+
#endif
200+
201+
202+
#if defined(DATA_A_IQ4_NL)
203+
#define BLOCK_BYTE_SIZE 18
204+
205+
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
206+
if (binding_idx == BINDING_IDX_K) {
207+
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
208+
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
209+
uint shift = (iqs & 0x10) >> 2;
210+
vui_lo >>= shift;
211+
vui_hi >>= shift;
212+
213+
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(
214+
kvalues_iq4nl[vui_lo & 0xF],
215+
kvalues_iq4nl[(vui_lo >> 8) & 0xF],
216+
kvalues_iq4nl[vui_hi & 0xF],
217+
kvalues_iq4nl[(vui_hi >> 8) & 0xF]);
218+
} else {
219+
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
220+
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
221+
uint shift = (iqs & 0x10) >> 2;
222+
vui_lo >>= shift;
223+
vui_hi >>= shift;
224+
225+
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(
226+
kvalues_iq4nl[vui_lo & 0xF],
227+
kvalues_iq4nl[(vui_lo >> 8) & 0xF],
228+
kvalues_iq4nl[vui_hi & 0xF],
229+
kvalues_iq4nl[(vui_hi >> 8) & 0xF]);
230+
}
231+
}
232+
#endif
135233
#if defined(DATA_A_Q8_0)
136234
#define BLOCK_BYTE_SIZE 34
137235
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -655,7 +655,7 @@ void process_shaders() {
655655
if (tname == "f16") {
656656
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
657657
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
658-
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
658+
} else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") {
659659
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
660660
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
661661
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
@@ -666,7 +666,7 @@ void process_shaders() {
666666
if (tname == "f16") {
667667
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
668668
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
669-
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
669+
} else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") {
670670
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
671671
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
672672
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);

0 commit comments

Comments (0)