129 changes: 90 additions & 39 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp
@@ -7,35 +7,86 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
// ------------------ calc_superblock (final optimized version) ------------------
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i,
const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
// Starting column index in matrix B for this 32-wide sub-block of superblock i
const uint y_idx = i * QUANT_K + 32 * ib32;

uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;

// Precompute indices for quantization lookup tables
const uint qh_base = 2 * ib32;
const uint qs_base = 4 * ib32;
const uint sc_index = ib32 / 2;
const uint sc_shift = 6 * (ib32 & 1);

// Loop over the output rows handled by this invocation
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
// Decode the per-superblock FP16 scale d from the top 4 bits of the four scale words
const uint16_t[4] scales = data_a[ibi].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);

const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1));
const uint sc = data_a[ibi].scales[sc_index] >> sc_shift;

// Temporary caches for decoding
FLOAT_TYPE dl_cache[4];
int16_t gvf_cache[4]; // signed, so the packed 2-bit grid values sign-extend as in the original int16_t path
float delta_cache[4];

// Precompute the multiplier and lookup values for 4 sub-blocks
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1));
const uint qs = data_a[ibi].qs[4 * ib32 + l];
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1);

const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);

FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int k = 0; k < 4; ++k) {
sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
}
temp[j][n] = fma(dl, sum, temp[j][n]);
dl_cache[l] = FLOAT_TYPE(d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1));
const uint qh = data_a[ibi].qh[qh_base + l / 2] >> (4 * (l & 1));
const uint qs = data_a[ibi].qs[qs_base + l];
gvf_cache[l] = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
delta_cache[l] = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
}

// Loop over columns of the output
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
// Compute base index for matrix B
const uint base_b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4;
vec4 b_vals[8];

// Load 8 vec4 values from matrix B
[[unroll]] for (int idx = 0; idx < 8; ++idx) {
b_vals[idx] = vec4(data_b_v4[base_b_idx + idx]);
}

FLOAT_TYPE col_sum = FLOAT_TYPE(0.0);

// Loop over sub-blocks
[[unroll]] for (uint l = 0; l < 4; ++l) {
const int16_t grid = gvf_cache[l];
const float dl = dl_cache[l];

// Decode the eight signed 2-bit grid values from gvf_cache
float f0 = float(bitfieldExtract(grid, 0, 2));
float f1 = float(bitfieldExtract(grid, 2, 2));
float f2 = float(bitfieldExtract(grid, 4, 2));
float f3 = float(bitfieldExtract(grid, 6, 2));
float f4 = float(bitfieldExtract(grid, 8, 2));
float f5 = float(bitfieldExtract(grid, 10, 2));
float f6 = float(bitfieldExtract(grid, 12, 2));
float f7 = float(bitfieldExtract(grid, 14, 2));

// Pack into vec4 for vectorized FMA
const vec4 fbits_v0 = vec4(f0, f1, f2, f3);
const vec4 fbits_v1 = vec4(f4, f5, f6, f7);
const vec4 delta_v = vec4(delta_cache[l]);

// Vectorized fused multiply-add
vec4 sum_v = fma(b_vals[2*l + 0], fbits_v0 + delta_v, vec4(0.0));
sum_v = fma(b_vals[2*l + 1], fbits_v1 + delta_v, sum_v);

// Horizontal add to get scalar sum
FLOAT_TYPE sum = sum_v.x + sum_v.y + sum_v.z + sum_v.w;

// Accumulate to column sum
col_sum = fma(dl, sum, col_sum);
}
// Accumulate the column result into the temporary buffer
temp[j][n] += col_sum;
}
ibi += num_blocks_per_row;
}
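For reference, the per-sub-block dequantization that both the old and the new inner loops implement can be written as one standalone helper. This is an illustrative sketch only: decode_iq1m_subblock is not part of the PR, and it assumes the same data_a, iq1s_grid, and IQ1M_DELTA definitions the shader already uses.

// Hypothetical helper (not in the PR): decode the 8 weights of sub-block l
// of 32-wide group ib32 inside superblock ib, mirroring calc_superblock above.
void decode_iq1m_subblock(const uint ib, const uint ib32, const uint l, out float w[8]) {
    // Superblock scale d: the top 4 bits of each of the four 16-bit scale words
    // are concatenated into an FP16 bit pattern.
    const uint16_t[4] scales = data_a[ib].scales;
    const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
    const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);

    // Per-sub-block multiplier dl from a 3-bit scale field.
    const uint sc = data_a[ib].scales[ib32 / 2] >> (6 * (ib32 & 1));
    const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1);

    // Grid lookup: 8 low index bits come from qs, 3 high bits from qh;
    // bit 3 of qh selects the sign of the constant offset IQ1M_DELTA.
    const uint qh = data_a[ib].qh[2 * ib32 + l / 2] >> (4 * (l & 1));
    const uint qs = data_a[ib].qs[4 * ib32 + l];
    const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
    const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;

    // Each weight is a signed 2-bit grid value plus delta, scaled by dl.
    [[unroll]] for (int k = 0; k < 8; ++k) {
        w[k] = dl * (float(bitfieldExtract(grid, 2 * k, 2)) + delta);
    }
}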
@@ -44,39 +95,39 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);

const uint num_blocks_per_row = p.ncols / QUANT_K;

// 8 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/8;
const uint blocks_per_wg = gl_WorkGroupSize.x / 8;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 8; // 0...7
const uint itid = tid % 8;
const uint ix = tid / 8;

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
// Initialize temporary storage
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j)
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i)
temp[j][i] = FLOAT_TYPE(0);
}
}

// Loop over blocks assigned to this thread
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);

// Reduce results from temporary buffer to output
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
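
// To make the work split concrete, a small worked example of the thread-to-data
// mapping; the workgroup size of 32 below is only an assumption for illustration
// (the real value is the spec constant bound to local_size_x_id):
//   blocks_per_wg = 32 / 8 = 4 superblocks are walked concurrently per workgroup.
//   For tid = 13:
//     itid = 13 % 8 = 5  -> this thread always handles sub-block ib32 = 5,
//                           i.e. weights 160..191 of each superblock it visits;
//     ix   = 13 / 8 = 1  -> it starts at superblock i = 1 and strides by 4:
//                           i = 1, 5, 9, ... while i < num_blocks_per_row.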

void main() {
// Compute first row for this workgroup
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

// Initialize shared memory for quantization lookups
init_iq_shmem(gl_WorkGroupSize);

// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
// Early exit if out-of-bounds
if (first_row >= p.stride_d)
return;

// Number of rows to process for this workgroup
const uint rows_to_process = min(NUM_ROWS, p.stride_d - first_row);
Collaborator:
I'm pretty surprised if it helped to make the changes in this function - this will prevent the compiler from unrolling loops.


// Compute outputs for assigned rows
compute_outputs(first_row, rows_to_process);
}
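
Regarding the review comment above: in the original control flow (the removed lines), the common case called compute_outputs with the compile-time constant NUM_ROWS, so loops bounded by num_rows could be fully unrolled; rows_to_process = min(NUM_ROWS, p.stride_d - first_row) is only known at runtime. A minimal sketch of that original tail handling, mirroring the removed lines, for comparison:

// Fast path passes the literal NUM_ROWS, keeping the [[unroll]] loops bounded by
// num_rows fully unrollable; only the tail workgroup takes the runtime-sized path.
if (first_row + NUM_ROWS <= p.stride_d) {
    compute_outputs(first_row, NUM_ROWS);
} else {
    if (first_row >= p.stride_d) {
        return;
    }
    compute_outputs(first_row, p.stride_d - first_row);
}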