Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,22 @@ vec2 get_dm_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint is = iqs_k / 8;
u8vec2 scale_dm;
if (is < 4) {
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
} else {
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}

const uvec3 scales = uvec3(data_a_packed32[ib_k].scales[0],
data_a_packed32[ib_k].scales[1],
data_a_packed32[ib_k].scales[2]);
const uint scalesoffs = (is & 3) * 8;

const uint scidx0 = (is < 4) ? 0 : 2;
const uint scidxshift0 = scalesoffs;
const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint mbidx0 = (is < 4) ? 1 : 2;
const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4;
const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;

const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30));
const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30));
u8vec2 scale_dm = u8vec2(sc, mbyte);

return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm);
}
Expand Down
54 changes: 28 additions & 26 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin

const vec2 loadd = vec2(data_a[ib].dm);

const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;

const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const uvec3 scales = uvec3(data_a_packed32[ib].scales[0],
data_a_packed32[ib].scales[1],
data_a_packed32[ib].scales[2]);
const uint scalesoffs = (is & 3) * 8;

const uint scidx0 = (is < 4) ? 0 : 2;
const uint scidxshift0 = scalesoffs;
const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint mbidx0 = (is < 4) ? 1 : 2;
const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4;
const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;

const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30));
const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30));

const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
Expand All @@ -237,19 +238,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin

const vec2 loadd = vec2(data_a[ib].dm);

const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;

const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const uvec3 scales = uvec3(data_a_packed32[ib].scales[0],
data_a_packed32[ib].scales[1],
data_a_packed32[ib].scales[2]);
const uint scalesoffs = (is & 3) * 8;

const uint scidx0 = (is < 4) ? 0 : 2;
const uint scidxshift0 = scalesoffs;
const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint mbidx0 = (is < 4) ? 1 : 2;
const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4;
const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;

const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30));
const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30));

const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
Expand Down