vulkan: further optimize mul_mat_vec using larger loads #10387

Merged · 4 commits · Nov 20, 2024
101 changes: 63 additions & 38 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

48 changes: 48 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
@@ -2,6 +2,15 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#endif

#include "types.comp"

#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
@@ -20,6 +29,11 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f) * d;
}
#endif

#if defined(DATA_A_Q4_1)
@@ -29,6 +43,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(vui & 0xF, vui >> 4) * d + m;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const float m = float(data_a_packed16[a_offset + ib].m);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) * d + m;
}
#endif

#if defined(DATA_A_Q5_0)
@@ -39,6 +59,14 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0];
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) - 16.0f) * d;
}
#endif

#if defined(DATA_A_Q5_1)
@@ -50,13 +78,28 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const float m = float(data_a_packed16[a_offset + ib].m);
const uint uint_qh = data_a_packed16[a_offset + ib].qh;
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * d + m;
}
#endif

#if defined(DATA_A_Q8_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a[a_offset + ib].d);
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
return vec4(int8_t(v0 & 0xFF), int8_t((v0 >> 8) & 0xFF), int8_t(v1 & 0xFF), int8_t((v1 >> 8) & 0xFF)) * d;
}
#endif

#if defined(DATA_A_IQ4_NL)
@@ -65,4 +108,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[(vui >> 12) & 0xF]) * d;
}
#endif
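
A note on the new `dequantize4` helpers: each reads one packed 16-bit word (via the `data_a_packed16` alias) in place of two byte loads, then unpacks four 4-bit quants with shifts and masks; the Q5_0/Q5_1 variants additionally reassemble the high-bit word `qh` to supply the fifth bit of each quant. Below is a minimal standalone C++ sketch of just the Q4_0 arithmetic — the function name and inputs are illustrative, not from the PR:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative C++ mirror of the Q4_0 dequantize4 shader path: one 16-bit
// word holds four 4-bit quants; each is unpacked, recentered by 8, scaled by d.
void dequantize4_q4_0(uint16_t packed, float d, float out[4]) {
    out[0] = (float((packed >>  0) & 0xF) - 8.0f) * d;
    out[1] = (float((packed >>  4) & 0xF) - 8.0f) * d;
    out[2] = (float((packed >>  8) & 0xF) - 8.0f) * d;
    out[3] = (float((packed >> 12) & 0xF) - 8.0f) * d;
}

int main() {
    float v[4];
    dequantize4_q4_0(0xCBA9, 0.5f, v);                 // quants 9, 10, 11, 12
    printf("%g %g %g %g\n", v[0], v[1], v[2], v[3]);   // 0.5 1 1.5 2
    return 0;
}
```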
82 changes: 72 additions & 10 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
@@ -3,7 +3,7 @@
#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
-#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types : require

#include "mul_mat_vec_base.comp"

@@ -12,16 +12,48 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;

#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#define K_PER_ITER 8
#else
#define K_PER_ITER 2
#endif


uint a_offset, b_offset, d_offset, y_offset;

shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];

void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
-const uint col = i*BLOCK_SIZE + 2*tid;
+const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
const uint iybs = col - col%QUANT_K; // y block start index

#if K_PER_ITER == 8
#if QUANT_R == 2
B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
FLOAT_TYPE b0 = FLOAT_TYPE(bv02.x);
FLOAT_TYPE b1 = FLOAT_TYPE(bv13.x);
FLOAT_TYPE b2 = FLOAT_TYPE(bv02.y);
FLOAT_TYPE b3 = FLOAT_TYPE(bv13.y);
FLOAT_TYPE b4 = FLOAT_TYPE(bv02.z);
FLOAT_TYPE b5 = FLOAT_TYPE(bv13.z);
FLOAT_TYPE b6 = FLOAT_TYPE(bv02.w);
FLOAT_TYPE b7 = FLOAT_TYPE(bv13.w);
#else
B_TYPE_VEC4 bv0 = data_b_v4[(b_offset + iybs + iqs) / 4];
B_TYPE_VEC4 bv1 = data_b_v4[(b_offset + iybs + iqs) / 4 + 1];
FLOAT_TYPE b0 = FLOAT_TYPE(bv0.x);
FLOAT_TYPE b1 = FLOAT_TYPE(bv0.y);
FLOAT_TYPE b2 = FLOAT_TYPE(bv0.z);
FLOAT_TYPE b3 = FLOAT_TYPE(bv0.w);
FLOAT_TYPE b4 = FLOAT_TYPE(bv1.x);
FLOAT_TYPE b5 = FLOAT_TYPE(bv1.y);
FLOAT_TYPE b6 = FLOAT_TYPE(bv1.z);
FLOAT_TYPE b7 = FLOAT_TYPE(bv1.w);
#endif
#else
// Check if the second of the pair of elements is OOB, and don't fetch B or
// accumulate it. We still fetch a pair of elements for A, which is fine for
// quantized formats since they'll be within the same block. We should
@@ -34,16 +66,32 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
if (!OOB) {
b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
}
#endif
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index

#if K_PER_ITER == 8
const vec4 v = dequantize4(ib, iqs, a_offset);
const vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);

// matrix multiplication
temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
temp[n] = fma(FLOAT_TYPE(v.z), b2, temp[n]);
temp[n] = fma(FLOAT_TYPE(v.w), b3, temp[n]);
temp[n] = fma(FLOAT_TYPE(v2.x), b4, temp[n]);
temp[n] = fma(FLOAT_TYPE(v2.y), b5, temp[n]);
temp[n] = fma(FLOAT_TYPE(v2.z), b6, temp[n]);
temp[n] = fma(FLOAT_TYPE(v2.w), b7, temp[n]);
#else
const vec2 v = dequantize(ib, iqs, a_offset);

// matrix multiplication
temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
if (!OOB) {
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
}
#endif
}
}

@@ -61,22 +109,33 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
temp[i] = FLOAT_TYPE(0);
}

-const int unroll_count = 8;
-const uint num_iters = (p.ncols >= 2*tid) ? ((p.ncols - 2*tid + BLOCK_SIZE - 1) / BLOCK_SIZE) : 0;
-const uint unrolled_iters = num_iters & ~(2*unroll_count - 1);
+uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
+if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
+    num_iters++;
+}
+int unroll_count = 4;
+uint unrolled_iters = num_iters & ~(unroll_count - 1);

uint i = 0;
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
-iter(temp, first_row, num_rows, tid, i, false);
-i += 2;
+iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
+i++;
}
}
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
i++;
}
}
while (i < num_iters) {
-iter(temp, first_row, num_rows, tid, i, true);
-i += 2;
+iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
+i++;
}

// sum up partial sums and write back result
@@ -106,6 +165,9 @@ void main() {
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
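
The driver loop in `compute_outputs` changes shape to match the wider loads: `num_iters` is rounded up so a thread whose final K_PER_ITER-wide stripe is partial still runs it (that last call passes `lastiter=true` so out-of-bounds columns can be guarded), and the loop is drained at unroll factor 4, then 2, then 1. A C++ sketch of just that control flow, under assumed values for `ncols`, `BLOCK_SIZE`, `K_PER_ITER`, and `tid`:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative mirror of the shader's iteration count and staged unrolling.
// The constants are example values, not fixed by the PR.
int main() {
    const uint32_t BLOCK_SIZE = 32, K_PER_ITER = 8, ncols = 4100, tid = 0;

    // Round up so the last partial stripe of columns is still processed.
    uint32_t num_iters = ncols / (K_PER_ITER * BLOCK_SIZE);          // 16
    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER * tid < ncols) {
        num_iters++;                                                 // 17
    }

    uint32_t i = 0;
    uint32_t unrolled = num_iters & ~(4u - 1);   // multiple-of-4 prefix: 16
    while (i < unrolled) {
        for (int k = 0; k < 4; ++k) { /* iter(..., i*K_PER_ITER, false) */ i++; }
    }
    unrolled = num_iters & ~(2u - 1);            // then a multiple-of-2 prefix
    while (i < unrolled) {
        for (int k = 0; k < 2; ++k) { /* iter(..., i*K_PER_ITER, false) */ i++; }
    }
    while (i < num_iters) {
        /* iter(..., i*K_PER_ITER, true) — lastiter enables the bounds check */
        i++;
    }

    printf("num_iters=%u, processed=%u\n", num_iters, i);  // 17, 17
    return 0;
}
```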
3 changes: 3 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp
@@ -12,6 +12,9 @@

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};

layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
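
Binding the same buffer at binding 1 three times, as `B`, `BV2`, and `BV4`, gives the shader aliased scalar, two-wide, and four-wide views of identical bytes: `data_b_v4[i]` covers the same elements as `data_b[4*i .. 4*i+3]`, so one 16-byte load replaces four scalar loads. A rough C++ analogue of the index mapping, with illustrative names and values:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Illustrative analogue of aliasing one buffer at several granularities.
struct vec4 { float x, y, z, w; };

int main() {
    float data_b[8] = {0, 1, 2, 3, 4, 5, 6, 7};

    const uint32_t idx = 6;   // scalar index into data_b
    vec4 v;                   // one 16-byte "load" instead of four 4-byte loads
    std::memcpy(&v, &data_b[(idx / 4) * 4], sizeof v);  // elements 4..7

    printf("%g == %g\n", data_b[idx], v.z);  // element 6 lands in component .z
    return 0;
}
```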
4 changes: 4 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
@@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

if (row >= p.stride_d) {
return;
}

uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);

4 changes: 4 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
@@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

if (row >= p.stride_d) {
return;
}

uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);

34 changes: 9 additions & 25 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
@@ -8,30 +8,14 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

shared FLOAT_TYPE tmp[32];

-// Declare aliased versions of A and B bindings that can use 16b/32b loads for
-// the quantized values, and vec4 loads for B.
-struct block_q4_K_u32
-{
-    f16vec2 d;
-    uint32_t scales[3*QUANT_K/64/4];
-    uint32_t qs[QUANT_K/2/4];
-};
-
-struct block_q4_K_u16
-{
-    f16vec2 d;
-    uint16_t scales[3*QUANT_K/64/2];
-    uint16_t qs[QUANT_K/2/2];
-};
-
-layout (binding = 0) readonly buffer A_u32 {block_q4_K_u32 data_a_u32[];};
-layout (binding = 0) readonly buffer A_u16 {block_q4_K_u16 data_a_u16[];};
-layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};

// This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

if (row >= p.stride_d) {
return;
}

uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);

@@ -64,9 +48,9 @@ void main() {
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

-uint32_t scale0_u32 = data_a_u16[ib0 + i].scales[v_im ];
-uint32_t scale4_u32 = data_a_u16[ib0 + i].scales[v_im + 2];
-uint32_t scale8_u32 = data_a_u16[ib0 + i].scales[v_im + 4];
+uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
+uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
+uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
uvec4 scale0 = uvec4(unpack8(scale0_u32));
uvec4 scale4 = uvec4(unpack8(scale4_u32));
uvec4 scale8 = uvec4(unpack8(scale8_u32));
@@ -80,8 +64,8 @@
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));

-uint32_t qs0_u32 = data_a_u32[ib0 + i].qs[q_offset / 4];
-uint32_t qs64_u32 = data_a_u32[ib0 + i].qs[q_offset / 4 + 16];
+uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
+uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];

uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
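
The `0x0F0F0F0F` masks are the SWAR step behind these wider loads: one 32-bit word carries eight 4-bit quants, and two mask/shift operations split it into the four low nibbles and four high nibbles, which `unpack8` then widens byte-wise. A small C++ sketch with an arbitrary example word:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // One 32-bit word holds 8 quants of 4 bits (4 bytes, low/high nibble each).
    uint32_t qs0_u32 = 0xD4C3B2A1;

    uint32_t lo4 = qs0_u32 & 0x0F0F0F0F;         // low nibble of each byte
    uint32_t hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;  // high nibble of each byte

    // lo4 = 0x04030201, hi4 = 0x0D0C0B0A: eight quants extracted in two ops,
    // ready to be widened byte-wise (the shader uses unpack8 for this).
    printf("lo4=%08X hi4=%08X\n", lo4, hi4);
    return 0;
}
```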