Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 199 additions & 58 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

7 changes: 0 additions & 7 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,6 @@

#include "types.glsl"

#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ layout (push_constant) uniform parameter
} p;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

Expand Down
7 changes: 7 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ layout (push_constant) uniform parameter
} p;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

uint get_idx() {
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.glsl"
#include "dequant_funcs.glsl"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

Expand Down
10 changes: 5 additions & 5 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@

#include "types.glsl"

#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
Expand All @@ -35,8 +37,6 @@ layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
layout (binding = 4) readonly buffer IDS {int data_ids[];};
#endif

#include "dequant_funcs.glsl"

layout (push_constant) uniform parameter
{
uint ncols;
Expand Down
62 changes: 27 additions & 35 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
Original file line number Diff line number Diff line change
Expand Up @@ -10,60 +10,56 @@

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
#define K_PER_ITER 8

#include "mul_mmq_funcs.glsl"
#elif defined(DATA_A_QUANT_K)
#define K_PER_ITER 16
#else
#error unimplemented
#endif

uint a_offset, b_offset, d_offset;

int32_t cache_b_qs[2];
int32_t cache_b_qs[K_PER_ITER / 4];
vec2 cache_b_ds;

#include "mul_mat_vecq_funcs.glsl"

void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;

// Preload data_b block
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
const uint b_qs_idx = tid % 4;
const uint b_qs_idx = tid % (32 / K_PER_ITER);
const uint b_block_idx_outer = b_block_idx / 4;
const uint b_block_idx_inner = b_block_idx % 4;
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);

#if QUANT_R == 2
// Assumes K_PER_ITER == 8
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
#else
#if K_PER_ITER == 8
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
#elif K_PER_ITER == 16
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
#else
#error unimplemented
#endif
#endif

uint ibi = first_row*p.ncols;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;
ibi += p.ncols;

int32_t q_sum = 0;
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
q_sum += dotPacked4x8EXT(data_a_qs.x,
cache_b_qs[0]);
q_sum += dotPacked4x8EXT(data_a_qs.y,
cache_b_qs[1]);
#else
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[0]);
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[1]);
#endif

#if QUANT_AUXF == 1
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
#else
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
#endif
temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
}
}
}
Expand All @@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;

get_offsets(a_offset, b_offset, d_offset);
a_offset /= QUANT_K;
a_offset /= QUANT_K_Q8_1;
b_offset /= QUANT_K_Q8_1;

FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
Expand Down Expand Up @@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);

#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif

while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
Expand All @@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif

// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
Expand Down
Loading
Loading