diff --git a/ggml-metal.m b/ggml-metal.m index 4267db9be3e61..f0d10e19720b0 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -76,14 +76,15 @@ GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_f16_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); @@ -205,14 +206,15 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); @@ -270,14 +272,15 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(norm); GGML_METAL_DEL_KERNEL(mul_mat_f16_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_f16_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); @@ -844,97 +847,45 @@ void ggml_metal_graph_compute( [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - int nth0 = 32; - int nth1 = 1; - + } else if (ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && + src1t == GGML_TYPE_F32 && + 
ne00%32 == 0) {
                // use custom matrix x vector kernel
-                switch (src0t) {
-                    case GGML_TYPE_F16:
-                        {
-                            nth0 = 32;
-                            nth1 = 1;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
-                        } break;
-                    case GGML_TYPE_Q4_0:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 8;
-                            nth1 = 8;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
-                        } break;
-                    case GGML_TYPE_Q4_1:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 8;
-                            nth1 = 8;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
-                        } break;
-                    case GGML_TYPE_Q8_0:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 8;
-                            nth1 = 8;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
-                        } break;
-                    case GGML_TYPE_Q2_K:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 2;
-                            nth1 = 32;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
-                        } break;
-                    case GGML_TYPE_Q3_K:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 2;
-                            nth1 = 32;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
-                        } break;
-                    case GGML_TYPE_Q4_K:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 2;
-                            nth1 = 32;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
-                        } break;
-                    case GGML_TYPE_Q5_K:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 2;
-                            nth1 = 32;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
-                        } break;
-                    case GGML_TYPE_Q6_K:
-                        {
-                            GGML_ASSERT(ne02 == 1);
-                            GGML_ASSERT(ne12 == 1);
-
-                            nth0 = 2;
-                            nth1 = 32;
-                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
-                        } break;
-                    default:
-                        {
-                            metal_printf("Asserting on type %d\n",(int)src0t);
-                            GGML_ASSERT(false && "not implemented");
-                        }
-                };
-
+                switch (src0->type) {
+                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];  break;
+                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32]; break;
+                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; break;
+                    case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32]; break;
+                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32]; break;
+                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32]; break;
+                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32]; break;
+                    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32]; break;
+                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32]; break;
+                    default: GGML_ASSERT(false && "MUL MAT-VEC not implemented");
+                }
+                int buffer_size_aligned = (512 / ggml_blck_size(src0t) * ggml_element_size(src0) + 31) / 32 * 32;
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:6];
+                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:8];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:9];
+                [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:10];
+                // threadgroup memory is used only for the k-quants
+                if (ggml_blck_size(src0t) >= 64) {
+                    [encoder setThreadgroupMemoryLength:8 * buffer_size_aligned atIndex:0];
+                }
+                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)];
+            } else {
+                switch (src0->type) {
+                    case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; break;
+                    default: GGML_ASSERT(false && "not implemented");
+                }
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -953,27 +904,8 @@ void ggml_metal_graph_compute(
                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
                [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
                [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
-
-                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
-                    src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_Q3_K) {
-#ifdef GGML_QKK_64
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#else
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#endif
-                }
-                else if (src0t == GGML_TYPE_Q5_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src0t == GGML_TYPE_Q6_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                } else {
-                    [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
            }
        } break;
    case GGML_OP_GET_ROWS:
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 8cdf0b9d2ba0a..f4a820f872766 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -3,26 +3,7 @@
 using namespace metal;

 #define MAX(x, y) ((x) > (y) ? (x) : (y))
-
-#define QK4_0 32
-#define QR4_0 2
-typedef struct {
-    half d;                // delta
-    uint8_t qs[QK4_0 / 2]; // nibbles / quants
-} block_q4_0;
-
-#define QK4_1 32
-typedef struct {
-    half d;                // delta
-    half m;                // min
-    uint8_t qs[QK4_1 / 2]; // nibbles / quants
-} block_q4_1;
-
-#define QK8_0 32
-typedef struct {
-    half d;           // delta
-    int8_t qs[QK8_0]; // quants
-} block_q8_0;
+#define MIN(x, y) ((x) < (y) ?
(x) : (y)) kernel void kernel_add( device const float4 * src0, @@ -309,194 +290,6 @@ kernel void kernel_rms_norm( } } -// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float2 acc = 0.f; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (sumy * -8.f + acc[0] + acc[1]); -} - -// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); - float2 acc = 0.f; - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (acc[0] + acc[1]) + sumy * m; -} - -// putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 -//Note: This is a template, but strictly speaking it only applies to -// quantizations where the block size is 32. It also does not -// giard against the number of rows not being divisible by -// N_DST, so this is another explicit assumption of the implementation. -template -void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, - int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, - uint3 tgpig, uint tiisg, uint sgitg) { - const int nb = ne00/QK4_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr; - const uint offset0 = first_row * nb + im/gqa*(nb*ne0); - device const block_q_type * x = (device const block_q_type *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[16]; // src1 vector cache - float sumf[nr]={0.f}; - - const int ix = tiisg/2; - const int il = 8*(tiisg%2); - - device const float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. 
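// [editor's note] a minimal sketch of the shift-free nibble trick used by the block_q_n_dot_y
// helpers above (the new drivers below keep the same scheme): each uint16_t read packs four
// 4-bit quants at bit offsets 0, 4, 8 and 12, and the yl values were pre-scaled by
// 1, 1/16, 1/256 and 1/4096 so the masked quants never need to be shifted down. With made-up values:
//
//     uint16_t qs = 0x4321;           // four quants: 1, 2, 3, 4 at bits 0/4/8/12
//     float y0 = 2.f, y1 = 3.f;       // pair with the quants at bits 0-3 and 8-11
//     float acc = y0       * (qs & 0x000F)    // 2 * 1             = 2
//               + y1/256.f * (qs & 0x0F00);   // 3 * (3*256)/256   = 9
//     // acc == y0*1 + y1*3 -- the exact dot product, with no shifts in the hot loop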
- for (int ib = ix; ib < nb; ib += nw/2) { - float sumy = 0; - for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; - } - - for (int row = 0; row < nr; row++) { - sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); - } - - yb += QK4_0 * 16; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; - } - } -} - -kernel void kernel_mul_mat_q4_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mat_q4_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mat_q8_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int nr = N_DST; - const int nsg = N_SIMDGROUP; - const int nw = N_SIMDWIDTH; - - const int nb = ne00/QK8_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr; - const uint offset0 = first_row * nb + im/gqa*(nb*ne0); - device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; - float sumf[nr]={0.f}; - - const int ix = tiisg/2; - const int il = tiisg%2; - - device const float * yb = y + ix * QK8_0 + 16*il; - - // each thread in a SIMD group deals with half a block. 
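// [editor's note] how "half a block" is split across lanes here, assuming the stated
// SIMD width of 32 (N_SIMDWIDTH):
//
//     const int ix = tiisg/2;                         // lanes 2k and 2k+1 both start at block k
//     const int il = tiisg%2;                         // even lane: floats 0..15, odd lane: 16..31
//     device const float * yb = y + ix*QK8_0 + 16*il; // this lane's 16 floats
//     for (int ib = ix; ib < nb; ib += nw/2) { ... }  // 32 lanes advance 16 blocks per pass
//
// so each pass of the SIMD group consumes 16 blocks, and the per-lane partial sums are
// folded together by simd_sum() before lane 0 writes the result.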
-    for (int ib = ix; ib < nb; ib += nw/2) {
-        for (int i = 0; i < 16; ++i) {
-            yl[i] = yb[i];
-        }
-
-        for (int row = 0; row < nr; row++) {
-            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
-            float sumq = 0.f;
-            for (int iq = 0; iq < 16; ++iq) {
-                sumq += qs[iq] * yl[iq];
-            }
-            sumf[row] += sumq*x[ib+row*nb].d;
-        }
-
-        yb += QK8_0 * 16;
-    }
-
-    for (int row = 0; row < nr; ++row) {
-        const float tot = simd_sum(sumf[row]);
-        if (tiisg == 0 && first_row + row < ne01) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
-        }
-    }
-}
-
 kernel void kernel_mul_mat_f16_f32(
        device const char * src0,
        device const char * src1,
@@ -808,7 +601,27 @@ kernel void kernel_cpy_f32_f32(
    }
 }

-//============================================ k-quants ======================================================
+//============================================ quant blocks ======================================================
+
+#define QK4_0 32
+#define QR4_0 2
+typedef struct {
+    half d;                // delta
+    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+
+#define QK4_1 32
+typedef struct {
+    half d;                // delta
+    half m;                // min
+    uint8_t qs[QK4_1 / 2]; // nibbles / quants
+} block_q4_1;
+
+#define QK8_0 32
+typedef struct {
+    half d;           // delta
+    int8_t qs[QK8_0]; // quants
+} block_q8_0;

 #ifndef QK_K
 #define QK_K 256
@@ -882,1012 +695,547 @@ typedef struct {
 } block_q6_K; // 210 bytes / block

-static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
-    uchar4 r;
-    if (j < 4) {
-        r[0] = q[j+0] & 63;
-        r[2] = q[j+1] & 63;
-        r[1] = q[j+4] & 63;
-        r[3] = q[j+5] & 63;
-    } else {
-        r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
-        r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
-        r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
-        r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
+//============================================ quant drivers ======================================================
+// load quantized blocks from device/threadgroup memory and dequantize 16 weights into a half4x4 or float4x4
+// init(il)                  : prepare values that can be reused as long as il doesn't change.
+// dequantize(...)           : dequantize 16 contiguous weights.
+// inner_product_pre(il, yl) : multiply the yl elements by a factor to speed up the inner product.
+// inner_product(...)        : compute the inner product; may not read the weights in contiguous order.
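// [editor's note] a sketch of how these drivers are consumed by the templated kernels later
// in this file; the template arguments shown are illustrative only (the exact parameter list
// is not visible in this hunk):
//
//     q4_0_driver<float4x4, device const block_q4_0 *, device const uint16_t *> dri;
//     dri.init(il);                        // precompute masks/coefficients for this il
//     dri.inner_product_pre(il, yl);       // rescale the 16 cached y values once per load
//     float sum = 0.f;
//     dri.inner_product(xb, il, yl, sum);  // fused dequantize + dot over 16 weights
//     // ... or, on the dequantize path (mul_mm / get_rows):
//     float4x4 w;
//     dri.dequantize(xb, il, w);           // 16 weights delivered as a float4x4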
+ +static inline void fix_y_v1(thread float & sumy, thread float4x4 & yl) { + sumy = 0.f; + for (int i = 0; i < 8; i += 2) { + sumy += yl[ i/4][i%4]; sumy += yl[ i/4][i%4+1]; + sumy += yl[2+i/4][i%4]; sumy += yl[2+i/4][i%4+1]; + yl[i/4 ][i%4 ] = yl[ i/4][i%4]; + yl[i/4 ][i%4+1] = 1/256.f * yl[ i/4][i%4+1]; + yl[i/4+2][i%4 ] = 1/16.f * yl[2+i/4][i%4]; + yl[i/4+2][i%4+1] = 1/4096.f * yl[2+i/4][i%4+1]; } - return r; } -//====================================== dot products ========================= - -kernel void kernel_mul_mat_q2_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int r2 = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; +static inline void fix_y_v2(thread float & coef1, thread float & coef2, thread float & sumy, thread float4x4 & yl) { + sumy = 0.f; + for (int i = 0; i < 16; i += 2) { + sumy += yl[i/4][i%4]; + sumy += yl[i/4][i%4+1]; + yl[i/4][i%4] = coef1 * yl[i/4][i%4]; + yl[i/4][i%4+1] = coef2 * yl[i/4][i%4+1]; + } +} - const int step = sizeof(block_q2_K) * nb; +template +class q4_0_driver { + public: + uint16_t mask1, mask2, q_offset; + float coef1, coef2, sumy; -#if QK_K == 256 - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int im = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 - - device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; + void init(int il) { + mask1 = il ? 0x00F0 : 0x000F; mask2 = mask1 << 8; + coef1 = il ? 1/16.f : 1.f; coef2 = coef1 / 256.f; + q_offset = il ? 
4 : 0; } - device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { + void inner_product_pre(int il, thread float4x4 & yl){ + fix_y_v1(sumy, yl); + } - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){ + const half d = xb->d; + addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset; for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + sum += yl[i/4 ][i%4] * (q[i/2] & 0x000F); + sum += yl[i/4 ][i%4+1] * (q[i/2] & 0x0F00); + sum += yl[i/4+2][i%4] * (q[i/2] & 0x00F0); + sum += yl[i/4+2][i%4+1] * (q[i/2] & 0xF000); } - float dall = dh[0]; - float dmin = dh[1] * 1.f/16.f; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - - qs += step/2; - sc += step; - dh += step/2; + sum = d * (sum - 8.f * sumy); } - y4 += 4 * QK_K; - } -#else - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0...1 - - device const float * y4 = y + ix * QK_K + 8 * it; - - for (int ib = ix; ib < nb; ib += 16) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; + void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) { + const half d = xb->d; + addr_uint16_p q = (addr_uint16_p)xb->qs; + for (int i = 0; i < 16; i += 2) { + reg[i/4][i%4] = d * (coef1 * (q[i/2] & mask1) - 8.f); + reg[i/4][i%4+1] = d * (coef2 * (q[i/2] & mask2) - 8.f); + } + } +}; + +template +class q4_1_driver { + public: + uint16_t mask1, mask2, q_offset; + float coef1, coef2, sumy; + + void init(int il) { + mask1 = il ? 0x00F0 : 0x000F; mask2 = mask1 << 8; + coef1 = il ? 1/16.f : 1.f; coef2 = coef1 / 256.f; + q_offset = il ? 
4 : 0; } - device const uint8_t * sc = (device const uint8_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { + void inner_product_pre(int il, thread float4x4 & yl){ + fix_y_v1(sumy, yl); + } - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){ + const half d = xb->d; + const half m = xb->m; + addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset; for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + sum += yl[i/4 ][i%4] * (q[i/2] & 0x000F); + sum += yl[i/4 ][i%4+1] * (q[i/2] & 0x0F00); + sum += yl[i/4+2][i%4] * (q[i/2] & 0x00F0); + sum += yl[i/4+2][i%4+1] * (q[i/2] & 0xF000); } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); - - qs += step/2; - sc += step; - dh += step/2; + sum = d * sum + m * sumy; } - y4 += 16 * QK_K; - } -#endif - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; + void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) { + const half d = xb->d; + const half m = xb->m; + addr_uint16_p q = (addr_uint16_p)xb->qs; + for (int i = 0; i < 16; i += 2) { + reg[i/4][i%4] = d * coef1 * (q[i/2] & mask1) + m; + reg[i/4][i%4+1] = d * coef2 * (q[i/2] & mask2) + m; + } } - } -} - -#if QK_K == 256 -kernel void kernel_mul_mat_q3_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t r2 = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - - float yl[16]; - - const uint16_t kmask1 = 0x0303; - const uint16_t kmask2 = 0x0f0f; - - const int tid = tiisg/2; - const int ix = tiisg%2; - const int ip = tid/8; // 0 or 1 - const int il = tid/2 - 4*ip; // 0...3 - const int ir = tid%2; - const int n = 8; - const int l0 = n*ir; - - const uint16_t m1 = 1 << (4*ip + il); - const uint16_t m2 = m1 << 8; +}; - const int shift = 2*il; - const uint16_t qm1 = 0x0003 << 
shift; - const uint16_t qm2 = 0x0300 << shift; - const int32_t v1 = 4 << shift; - const int32_t v2 = 1024 << shift; +template +class q8_0_driver { + public: + uint16_t mask1, mask2, q_offset; + float coef1, coef2, sumy; - const uint16_t s_shift1 = 4*ip; - const uint16_t s_shift2 = s_shift1 + 2*(il/2); - const int ik = 4 + (il%2); - - const int q_offset = 32*ip + l0; - const int y_offset = 128*ip + 32*il + l0; - - const int step = sizeof(block_q3_K) * nb / 2; - - device const float * y1 = yy + ix*QK_K + y_offset; + void init(int il) { + q_offset = il * 16; + } - float sumf1[2] = {0.f}, sumf2[2] = {0.f}; - for (int i = ix; i < nb; i += 2) { + void inner_product_pre(int il, thread float4x4 & yl){ + } - for (int l = 0; l < 8; ++l) { - yl[l+0] = y1[l+ 0]; - yl[l+8] = y1[l+16]; + void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){ + const half d = xb->d; + for (int i = 0; i < 16; i++) { + sum += yl[i/4][i%4] * xb->qs[i + q_offset]; + } + sum = d * sum; } - device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); - device const uint16_t * a = (device const uint16_t *)(x[i].scales); - device const half * dh = &x[i].d; + void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) { + const half d = xb->d; + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = (xb->qs[i + q_offset] * d); + } + } +}; - for (int row = 0; row < 2; ++row) { +template +class f16_driver { + public: + void init(int il) {} - const float d_all = (float)dh[0]; - const char2 scales = as_type((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); + void inner_product_pre(int il, thread float4x4 & yl) {} - float s1 = 0, s2 = 0; - for (int l = 0; l < n; l += 2) { - const uint16_t qs = q[l/2]; - s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1)); - s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2)); - } - float d = d_all * (s1 + 1.f/256.f * s2); - sumf1[row] += d * scales[0]; - sumf2[row] += d; - - s1 = s2 = 0; - for (int l = 0; l < n; l += 2) { - const uint16_t qs = q[l/2+8]; - s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1)); - s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2)); + void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){ + half4x4 temp = *xb; + for (int i = 0; i < 16; i++){ + sum += yl[i/4][i%4] * temp[i/4][i%4]; } - d = d_all * (s1 + 1.f/256.f * s2); - sumf1[row] += d * scales[1]; - sumf2[row] += d; - - q += step; - h += step; - a += step; - dh += step; - } - y1 += 2 * QK_K; - - } + void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) { + half4x4 temp = *xb; + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } + } +}; + +template +class q2_K_driver { + public: + uint16_t mask1, mask2, q_offset; + float coef1, coef2, sumy; + + void init(int il) { + #if QK_K == 256 + q_offset = 16*(il/8) + 8*(il&1); + il = (il/2)%4; + #else + q_offset = 0; + #endif + coef1 = il>1 ? (il>2 ? 1/64.f : 1/16.f) : (il>0 ? 1/4.f : 1.f); coef2 = coef1/256.f; + mask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 
12 : 3); mask2 = mask1 << 8; + } - for (int row = 0; row < 2; ++row) { - const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; + void get_scales(addr_block_q_p xb, int il, thread float & dl, thread float & ml) { + const float d = (float)(xb->d); + const float min = (float)(xb->dmin); + dl = d * (xb->scales[il] & 0xF), ml = min * (xb->scales[il] >> 4); } - } -} -#else -kernel void kernel_mul_mat_q3_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int nb = ne00/QK_K; + void inner_product_pre(int il, thread float4x4 & yl){ + fix_y_v2(coef1, coef2, sumy, yl); + } - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t r2 = tgpig.z; - - const int row = 2 * r0 + sgitg; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - const int ix = tiisg/4; - const int il = 4 * (tiisg%4);// 0, 4, 8, 12 - const int im = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - float2 sum = {0.f, 0.f}; - - for (int i = ix; i < nb; i += 8) { - - const float d_all = (float)(x[i].d); - - device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); - device const uint16_t * s = (device const uint16_t *)(x[i].scales); - device const float * y = yy + i * QK_K + il; - - const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); - const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; - const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; - const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; - - for (int l = 0; l < 4; l += 2) { - const uint16_t hm = h[l/2] >> im; - sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) - + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) - + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) - + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); - sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) - + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) - + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) - + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 
0 : 65536)); + void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){ + float dl, ml; + get_scales(xb, il, dl, ml); + addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset; + for (int i = 0; i < 16; i += 2) { + sum += yl[i/4][i%4 ] * (q[i/2] & mask1); + sum += yl[i/4][i%4+1] * (q[i/2] & mask2); + } + sum = dl * sum - ml * sumy; } - } - const float sumf = sum[0] + sum[1] * 1.f/256.f; + void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) { + float dl, ml; + get_scales(xb, il, dl, ml); + addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset; + for (int i = 0; i < 16; i += 2) { + reg[i/4][i%4 ] = coef1 * dl * (q[i/2] & mask1) - ml; + reg[i/4][i%4+1] = coef2 * dl * (q[i/2] & mask2) - ml; + } + } +}; - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + row] = tot; - } +template +class q3_K_driver { + public: + uint16_t m1, m2, kmask1, kmask2 ,mask1, mask2; + float coef1, coef2, sumy; + float4x4 yl_str; + uint16_t q_offset, h_offset, d_loc1, d_loc2; -} + void init(int il) { +#if QK_K == 256 + d_loc1 = 8 + il%4; d_loc2 = il%8; + q_offset = 16 * (il/8) + 8 * (il&1); h_offset = 8 * (il&1); + kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : ((il/4)>0 ? 12 : 3); kmask2 = il/8 ? 0xF0 : 0x0F; + m1 = 1 << (il/2); m2 = m1 << 8; + il = (il/2)%4; +#else + m1 = 1 << (il*2); m2 = m1 << 8; + q_offset = 0; h_offset = 0; + kmask1 = il&1 ? 0xF0 : 0x0F; + d_loc1 = il/2;; #endif + coef1 = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); coef2 = coef1 / 256.h; + mask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); mask2 = mask1 << 8; + } + void get_scales(addr_block_q_p xb, int il, thread float & dl) { + const half d_all = xb->d; #if QK_K == 256 -kernel void kernel_mul_mat_q4_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int im = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int r2 = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - float yl[16]; - float yh[16]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir; - - uint16_t sc16[4]; - thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; - yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; - yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; - yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; + uint16_t scale_1 = xb->scales[d_loc1], scale_2 = xb->scales[d_loc2]; 
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \ + (scale_2&kmask2) | ((scale_1&kmask1) << 4); + dl = il < 8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); +#else + float kcoef = il&1 ? 1.f/16.f : 1.f; + dl = d_all * ((xb->scales[d_loc1] & kmask1) * kcoef - 8); +#endif } - device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im; - device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & kmask1; - sc16[1] = sc[2] & kmask1; - sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); - sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); - - device const uint16_t * q2 = q1 + 32; + void inner_product_pre(int il, thread float4x4 & yl){ + fix_y_v2(coef1, coef2, sumy, yl); + } - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){ + float dl; + get_scales(xb, il, dl); + addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset; + addr_uint16_p h = (addr_uint16_p)xb->hmask + h_offset; +#if QK_K == 256 + for (int i = 0; i < 16; i += 2) { + sum += yl[i/4][i%4 ] * ((q[i/2] & mask1) - ((h[i/2] & m1) ? 0 : 4.f/coef1)); + sum += yl[i/4][i%4+1] * ((q[i/2] & mask2) - ((h[i/2] & m2) ? 0 : 4.f/coef2)); + } +#else for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); - acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); - acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); - acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); - acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); - acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); - acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); - acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + sum +=yl[i/4 ][i%4 ] * ((q[i/2 ] & mask1) - (h[i/2] & m1 ? 0 : 4.f/coef1)); + sum +=yl[i/4 ][i%4+1] * ((q[i/2 ] & mask2) - (h[i/2] & m2 ? 0 : 4.f/coef2)); + sum +=yl[i/4+2][i%4 ] * ((q[i/2+4] & mask1) - (h[i/2] & (2*m1) ? 0 : 4.f/coef1)); + sum +=yl[i/4+2][i%4+1] * ((q[i/2+4] & mask2) - (h[i/2] & (2*m2) ? 0 : 4.f/coef2)); } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + - (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + - (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + - (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - - q1 += step; - sc += step; - dh += step; +#endif + sum = dl * sum; } - y4 += 4 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; - } - } -} + void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) { + float dl; + get_scales(xb, il, dl); + addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset; + addr_uint16_p h = (addr_uint16_p)xb->hmask + h_offset; +#if QK_K == 256 + for (int i = 0; i < 16; i += 2) { + reg[i/4][i%4] = coef1 * dl * ((q[i/2] & mask1) - ((h[i/2] & m1) ? 0 : 4.f/coef1)); + reg[i/4][i%4+1] = coef2 * dl * ((q[i/2] & mask2) - ((h[i/2] & m2) ? 
0 : 4.f/coef2)); + } #else -kernel void kernel_mul_mat_q4_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int ix = tiisg/4; // 0...7 - const int it = tiisg%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int r2 = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - float yl[8]; - float yh[8]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 8 * it; - - uint16_t sc16[4]; - - for (int ib = ix; ib < nb; ib += 8) { - - float2 sumy = {0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i] = y4[i+ 0]; sumy[0] += yl[i]; - yh[i] = y4[i+32]; sumy[1] += yh[i]; - } - - device const uint16_t * sc = (device const uint16_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & 0x000f; - sc16[1] = sc[0] & 0x0f00; - sc16[2] = sc[0] & 0x00f0; - sc16[3] = sc[0] & 0xf000; - - float2 acc1 = {0.f, 0.f}; - float2 acc2 = {0.f, 0.f}; for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); - acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); - acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); - acc2[1] += yh[i+1] * (qs[i/2] & 0xF000); + reg[i/4 ][i%4 ] = coef1 * dl * ((q[i/2 ] & mask1) - (h[i/2] & m1 ? 0 : 4.f/coef1)); + reg[i/4 ][i%4+1] = coef2 * dl * ((q[i/2 ] & mask2) - (h[i/2] & m2 ? 0 : 4.f/coef2)); + reg[i/4+2][i%4 ] = coef1 * dl * ((q[i/2+4] & mask1) - (h[i/2] & (2*m1) ? 0 : 4.f/coef1)); + reg[i/4+2][i%4+1] = coef2 * dl * ((q[i/2+4] & mask2) - (h[i/2] & (2*m2) ? 0 : 4.f/coef2)); } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + - (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - - dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); - - qs += step; - sc += step; - dh += step; +#endif } - - y4 += 8 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum; +}; + +template +class q4_K_driver { + public: + uint16_t d_mask1, d_mask2, m_mask1, mask1, mask2; + float coef1, coef2, sumy1, sumy2; + uint16_t d_loc1, d_loc2, m_loc1, m_loc2, q_offset; + + void init(int il) { + q_offset = (il/4) * 16 + 4 * (il%4); + d_mask1 = il < 8 ? 0x3F3F : 0x0F0F; d_mask2 = il < 8 ? 0x0000 : 0xC0C0; + d_loc1 = il < 8 ? il/4 : il/4 + 2; d_loc2 = il < 8 ? il/4 : il/4 - 2; + m_mask1 = il < 8 ? 
0x3F3F : 0xF0F0; + m_loc1 = il/4 + 2; m_loc2 = il/4; } - } -} -#endif - -kernel void kernel_mul_mat_q5_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int r2 = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - - float sumf[2]={0.f}; - - const int step = sizeof(block_q5_K) * nb; + void get_scales(addr_block_q_p xb, int il, thread float & dl1, thread float & ml1, thread float & dl2, thread float & ml2) { + #if QK_K == 256 + const float d = (float)(xb->d); + const float min = (float)(xb->dmin); + addr_uint16_p sc = (addr_uint16_p)xb->scales; + uint16_t d_int = (sc[d_loc1] & d_mask1) | ((sc[d_loc2] & d_mask2) >> 2); + uint16_t m_int = il < 8 ? (sc[m_loc1] & m_mask1) : ((sc[m_loc1] & m_mask1) >> 4); + m_int = m_int | ((sc[m_loc2] & d_mask2) >> 2); + dl1 = as_type(d_int)[0] * d, ml1 = as_type(m_int)[0] * min; + dl2 = as_type(d_int)[1] * d, ml2 = as_type(m_int)[1] * min; +#else + dl1 = (float)(xb->d[0]) * (xb->scales[0]&0xF); dl2 = (float)(xb->d[0]) * (xb->scales[1]&0xF); + ml1 = (float)(xb->d[1]) * (xb->scales[0]>>4); ml2 = (float)(xb->d[1]) * (xb->scales[1]>>4); +#endif + } + void get_scales2(addr_block_q_p xb, int il, thread float & dl, thread float & ml) { + q_offset = (il/4) * 16 + 8 * (il&1); + mask1 = (il%4) < 2 ? 0x000F : 0x00F0; mask2 = mask1 << 8; + coef1 = (il%4) < 2 ? 1.f : 1/16.f; coef2 = coef1 / 256.f; #if QK_K == 256 -# - float yl[16], yh[16]; - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = tiisg/4; - const int ix = tiisg%4; - const int im = tid/4; - const int ir = tid%4; - const int n = 8; - - const int l0 = n*ir; - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; - - const uint8_t hm1 = 1u << (2*im); - const uint8_t hm2 = hm1 << 1; - const uint8_t hm3 = hm1 << 4; - const uint8_t hm4 = hm2 << 4; - - uint16_t sc16[4]; - thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - - device const float * y1 = yy + ix*QK_K + y_offset; - - for (int i = ix; i < nb; i += 4) { - - device const uint8_t * q1 = x[i].qs + q_offset; - device const uint8_t * qh = x[i].qh + l0; - device const half * dh = &x[i].d; - device const uint16_t * a = (device const uint16_t *)x[i].scales + im; - - device const float * y2 = y1 + 128; - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 8; ++l) { - yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; - yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; - yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; - yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; + d_mask1 = il < 8 ? 63 : 0x0F; d_mask2 = il < 8 ? 0 : 192; + d_loc1 = il < 8 ? il/2 : 4 + il/2; d_loc2 = il < 8 ? il/2 : il/2 - 4; + m_mask1 = il < 8 ? 
63 : 0xF0; + m_loc1 = il/2 + 4; m_loc2 = il/2; + const float d = (float)(xb->d); + const float min = (float)(xb->dmin); + uint16_t d_int = (xb->scales[d_loc1] & d_mask1) | ((xb->scales[d_loc2] & d_mask2) >> 2); + uint16_t m_int = il < 8 ? (xb->scales[m_loc1] & m_mask1) : ((xb->scales[m_loc1] & m_mask1) >> 4); + m_int = m_int | ((xb->scales[m_loc2] & d_mask2) >> 2); + dl = d_int * d, ml = m_int * min; +#else + dl = il<2 ? (float)(xb->d[0]) * (xb->scales[0]&0xF) : (float)(xb->d[0]) * (xb->scales[1]&0xF); + ml = il<2 ? (float)(xb->d[1]) * (xb->scales[0]>>4) : (float)(xb->d[1]) * (xb->scales[1]>>4); +#endif } - for (int row = 0; row < 2; ++row) { - - device const uint8_t * q2 = q1 + 64; - - sc16[0] = a[0] & kmask1; - sc16[1] = a[2] & kmask1; - sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); - sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); - - float4 acc = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - uint8_t h = qh[l]; - acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0)); - acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0)); - acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0)); - acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0)); + void inner_product_pre(int il, thread float4x4 & yl){ + sumy1 = 0.f; sumy2 = 0.f; + for (int i = 0; i < 8; i += 2) { + sumy1 += yl[i/4 ][i%4]; sumy1 += yl[i/4 ][i%4+1]; + sumy2 += yl[2+i/4][i%4]; sumy2 += yl[2+i/4][i%4+1]; + yl[i/4 ][i%4 ] = yl[i/4][i%4]; + yl[i/4 ][i%4+1] = 1/256.f * yl[i/4][i%4+1]; + yl[i/4+2][i%4 ] = 1/16.f * yl[2+i/4][i%4]; + yl[i/4+2][i%4+1] = 1/4096.f * yl[2+i/4][i%4+1]; } - const float dall = dh[0]; - const float dmin = dh[1]; - sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - - q1 += step; - qh += step; - dh += step/2; - a += step/2; - } - y1 += 4 * QK_K; - - } -#else - float yl[8], yh[8]; - - const int il = 4 * (tiisg/8); // 0, 4, 8, 12 - const int ix = tiisg%8; - const int im = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - device const float * y = yy + ix*QK_K + il; - - for (int i = ix; i < nb; i += 8) { - - for (int l = 0; l < 4; ++l) { - yl[l+0] = y[l+ 0]; - yl[l+4] = y[l+16]; - yh[l+0] = y[l+32]; - yh[l+4] = y[l+48]; + void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){ + float dl1, ml1, dl2, ml2; + float sum2 = 0.f; + get_scales(xb, il, dl1, ml1, dl2, ml2); + addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset; + for (int i = 0; i < 8; i += 2) { + sum += yl[i/4 ][i%4 ] * ((q[i/2]&0x000F)); + sum += yl[i/4 ][i%4+1] * ((q[i/2]&0x0F00)); + sum2 += yl[i/4+2][i%4 ] * ((q[i/2]&0x00F0)); + sum2 += yl[i/4+2][i%4+1] * ((q[i/2]&0xF000)); + } + sum = dl1 * sum - ml1 * sumy1 + dl2 * sum2 - ml2 * sumy2; } - device const half * dh = &x[i].d; - device const uint8_t * q = x[i].qs + il; - device const uint8_t * h = x[i].qh + in; - device const int8_t * s = x[i].scales; - - for (int row = 0; row < 2; ++row) { - - const float d = dh[0]; - - float2 acc = {0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - const uint8_t hl = h[l] >> im; - acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) - + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); - acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) - + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 
0 : 256)); + void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) { + float dl, ml; + get_scales2(xb, il, dl, ml); + addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset; + for (int i = 0; i < 16; i += 2) { + reg[i/4][i%4] = coef1 * dl * (q[i/2] & mask1) - ml; + reg[i/4][i%4+1] = coef2 * dl * (q[i/2] & mask2) - ml; } - sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); - - q += step; - h += step; - s += step; - dh += step/2; - } - - y += 8 * QK_K; - } +}; + +template +class q5_K_driver { + public: + uint16_t m1, m2, d_mask1, d_mask2, m_mask1, mask1, mask2; + float coef1, coef2, sumy; + uint16_t d_loc1, d_loc2, m_loc1, m_loc2, q_offset, h_offset; + + void init(int il) { + d_mask1 = il < 8 ? 63 : 0x0F; d_mask2 = il < 8 ? 0 : 192; + d_loc1 = il < 8 ? il/2 : 4 + il/2; d_loc2 = il < 8 ? il/2 : il/2 - 4; + m_mask1 = il < 8 ? 63 : 0xF0; + m_loc1 = il/2 + 4; m_loc2 = il/2; + mask1 = (il%4) < 2 ? 0x000F : 0x00F0; mask2 = mask1 << 8; + coef1 = (il%4) < 2 ? 1.f : 1/16.f; coef2 = coef1 / 256.f; +#if QK_K == 256 + q_offset = (il/4) * 16 + 8 * (il&1); h_offset = 8 * (il&1); + m1 = 1 << (il/2); m2 = m1 << 8; +#else + q_offset = 8 * (il&1); h_offset = 0; + m1 = 1 << (il*2); m2 = m1 << 8; #endif - - for (int row = 0; row < 2; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; } - } - -} - -kernel void kernel_mul_mat_q6_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const uint8_t kmask1 = 0x03; - const uint8_t kmask2 = 0x0C; - const uint8_t kmask3 = 0x30; - const uint8_t kmask4 = 0xC0; - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int r2 = tgpig.z; - - const int row = 2 * r0 + sgitg; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - - float sumf = 0; + void get_scales(addr_block_q_p xb, int il, thread float & dl, thread float & ml) { #if QK_K == 256 - const int tid = tiisg/2; - const int ix = tiisg%2; - const int ip = tid/8; // 0 or 1 - const int il = tid%8; - const int n = 4; - const int l0 = n*il; - const int is = 8*ip + l0/16; - - const int y_offset = 128*ip + l0; - const int q_offset_l = 64*ip + l0; - const int q_offset_h = 32*ip + l0; - - for (int i = ix; i < nb; i += 2) { - - device const uint8_t * q1 = x[i].ql + q_offset_l; - device const uint8_t * q2 = q1 + 32; - device const uint8_t * qh = x[i].qh + q_offset_h; - device const int8_t * sc = x[i].scales + is; - - device const float * y = yy + i * QK_K + y_offset; - - const float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - 
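// [editor's note] 6-bit reconstruction in the loop above: the low four bits come from ql,
// the top two bits from qh, and 32 is subtracted to re-center. For example, (q1[l] & 0xF) == 15
// with (qh[l] & kmask1) == 2 gives (15 | (2 << 4)) - 32 = 47 - 32 = 15, i.e. a signed weight
// in -32..31 before the dall * sc scaling is applied.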
- sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); - - } - + uint16_t d_int = (xb->scales[d_loc1] & d_mask1) | ((xb->scales[d_loc2] & d_mask2) >> 2); + uint16_t m_int = il < 8 ? (xb->scales[m_loc1] & m_mask1) : ((xb->scales[m_loc1] & m_mask1) >> 4); + m_int = m_int | ((xb->scales[m_loc2] & d_mask2) >> 2); + dl = d_int * xb->d, ml = m_int * xb->dmin; #else - const int ix = tiisg/4; - const int il = 4*(tiisg%4); - - for (int i = ix; i < nb; i += 8) { - device const float * y = yy + i * QK_K + il; - device const uint8_t * ql = x[i].ql + il; - device const uint8_t * qh = x[i].qh + il; - device const int8_t * s = x[i].scales; - - const float d = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32); - sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); - } - + dl = (float)(xb->d) * xb->scales[il]; ml = 0.f; #endif + } - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + row] = tot; - } -} - -//============================= templates and their specializations ============================= - -template -void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { - half4x4 temp = *(((device half4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const half d = il ? (xb->d / 16.h) : xb->d; - const half m = il ? ( -8.h * 16.h) : -8.h; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = il ? 0xF000 : 0x0F00; - - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d; - reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d; - } -} - -template -void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 2); - const half d = il ? (xb->d / 16.h) : xb->d; - const half m = xb->m; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = il ? 
0xF000 : 0x0F00; - - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m; - reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m; - } -} - -template -void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { - device const int8_t * qs = ((device const int8_t *)xb->qs); - const half d = xb->d; - - for (int i=0;i<16;i++) { - reg[i/4][i%4] = (qs[i + 16*il] * d); - } -} - -template -void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { - const half d = xb->d; - const half min = xb->dmin; - device const uint8_t * q = (device const uint8_t *)xb->qs; - half dl, ml; - uint8_t sc = xb->scales[il]; + void inner_product_pre(int il, thread float4x4 & yl){ + fix_y_v2(coef1, coef2, sumy, yl); + } + void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){ + float dl, ml; + get_scales(xb, il, dl, ml); + addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset; + addr_uint16_p h = (addr_uint16_p)xb->qh + h_offset; #if QK_K == 256 - q = q + 32*(il/8) + 16*(il&1); - il = (il/2)%4; + for (int i = 0; i < 16; i += 2) { + sum += yl[i/4][i%4 ] * ((q[i/2] & mask1) + ((h[i/2] & m1) ? 16.f/coef1 : 0)); + sum += yl[i/4][i%4+1] * ((q[i/2] & mask2) + ((h[i/2] & m2) ? 16.f/coef2 : 0)); + } +#else + for (int i = 0; i < 8; i += 2) { + sum += yl[i/4 ][i%4 ] * ((q[i/2 ] & mask1) - (h[i/2] & m1 ? 0 : 16.f/coef1)); + sum += yl[i/4 ][i%4+1] * ((q[i/2 ] & mask2) - (h[i/2] & m2 ? 0 : 16.f/coef2)); + sum += yl[i/4+2][i%4 ] * ((q[i/2+4] & mask1) - (h[i/2] & (2*m1) ? 0 : 16.f/coef1)); + sum += yl[i/4+2][i%4+1] * ((q[i/2+4] & mask2) - (h[i/2] & (2*m2) ? 0 : 16.f/coef2)); + } #endif - half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } -} - -template -void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { - const float d_all = (float)(xb->d); - device const uint8_t * q = (device const uint8_t *)xb->qs; - device const uint8_t * h = (device const uint8_t *)xb->hmask; - device const int8_t * scales = (device const int8_t *)xb->scales; + sum = dl * sum - ml * sumy; + } + void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) { + float dl, ml; + get_scales(xb, il, dl, ml); + addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset; + addr_uint16_p h = (addr_uint16_p)xb->qh + h_offset; #if QK_K == 256 - q = q + 32 * (il/8) + 16 * (il&1); - h = h + 16 * (il&1); - uint8_t m = 1 << (il/2); - uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ - ((il/4)>0 ? 12 : 3); - uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; - uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; - int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \ - (scale_2&kmask2) | ((scale_1&kmask1) << 4); - float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); - - il = (il/2)%4; - float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef)); - } + for (int i = 0; i < 16; i += 2) { + reg[i/4][i%4 ] = coef1 * dl * ((q[i/2] & mask1) + (h[i/2] & m1 ? 16.f/coef1 : 0)) - ml; + reg[i/4][i%4+1] = coef2 * dl * ((q[i/2] & mask2) + (h[i/2] & m2 ? 16.f/coef2 : 0)) - ml; + } #else - float kcoef = il&1 ? 
+    void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) {
+        float dl, ml;
+        get_scales(xb, il, dl, ml);
+        addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset;
+        addr_uint16_p h = (addr_uint16_p)xb->qh + h_offset;
 #if QK_K == 256
-    q = q + 32 * (il/8) + 16 * (il&1);
-    h = h + 16 * (il&1);
-    uint8_t m = 1 << (il/2);
-    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
-                                 ((il/4)>0 ? 12 : 3);
-    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
-    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
-    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
-                                 (scale_2&kmask2) | ((scale_1&kmask1) << 4);
-    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
-
-    il = (il/2)%4;
-    float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
-    }
+        for (int i = 0; i < 16; i += 2) {
+            reg[i/4][i%4  ] = coef1 * dl * ((q[i/2] & mask1) + (h[i/2] & m1 ? 16.f/coef1 : 0)) - ml;
+            reg[i/4][i%4+1] = coef2 * dl * ((q[i/2] & mask2) + (h[i/2] & m2 ? 16.f/coef2 : 0)) - ml;
+        }
 #else
-    float kcoef = il&1 ? 1.f/16.f : 1.f;
-    uint16_t kmask = il&1 ? 0xF0 : 0x0F;
-    float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
-    float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-    uint8_t m = 1<<(il*2);
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
-    }
+        for (int i = 0; i < 8; i += 2) {
+            reg[i/4  ][i%4  ] = coef1 * dl * ((q[i/2  ] & mask1) - (h[i/2] & m1 ? 0 : 16.f/coef1));
+            reg[i/4  ][i%4+1] = coef2 * dl * ((q[i/2  ] & mask2) - (h[i/2] & m2 ? 0 : 16.f/coef2));
+            reg[i/4+2][i%4  ] = coef1 * dl * ((q[i/2+4] & mask1) - (h[i/2] & (2*m1) ? 0 : 16.f/coef1));
+            reg[i/4+2][i%4+1] = coef2 * dl * ((q[i/2+4] & mask2) - (h[i/2] & (2*m2) ? 0 : 16.f/coef2));
+        }
 #endif
-}
+    }
+};
-
-template <typename type4x4>
-void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q = xb->qs;
+template <typename addr_block_q_p, typename addr_uint16_p, typename type4x4>
+class q6_K_driver {
+  public:
+    uint16_t hmask1, hmask2, lmask1, lmask2;
+    float    coef1, coef2, h_coef, sumy1, sumy2;
+    uint16_t d_loc, q_offset, h_offset;
+    void init(int il) {
+        d_loc = il;
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
-    short is = (il/4) * 2;
-    q = q + (il/4) * 32 + 16 * (il&1);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+        q_offset = 32*(il/8) + 8*(il%4); h_offset = 16*(il/8) + 8*(il&1);
+        il = (il/2)%4;
 #else
-    q = q + 16 * (il&1);
-    device const uint8_t * s = xb->scales;
-    device const half2 * dh = (device const half2 *)xb->d;
-    const float2 d = (float2)dh[0];
-    const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
-    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
+        q_offset = 8 * (il&1); h_offset = 0;
 #endif
-    const ushort mask = il<2 ? 0x0F : 0xF0;
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
-    }
-}
+        hmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3), hmask2 = hmask1 << 8;
+        lmask1 = il>1 ? 0xF0 : 0x0F, lmask2 = lmask1 << 8;
+        h_coef = il&1 ? 4.h : 16.h;
+        coef1 = il>1 ? 1.f/16.f : 1.f, coef2 = coef1 / 256.f;
+    }
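init() above depends only on the sub-block index il, so masks, offsets, and coefficients are computed once per thread and then reused for every block of the row. The q6_K encoding it targets stores each weight as 6 bits, low 4 in ql and high 2 in qh, centered at 32. A scalar C++ sketch of the one-weight reconstruction (field packing per ggml's k-quants; the bit extraction shown is illustrative):

    // One q6_K weight: 6 bits split across two arrays, centered at 32.
    #include <cstdint>

    static inline float dequant_one_q6(uint8_t ql_nibble,   // low 4 bits (0..15)
                                       uint8_t qh_two_bits, // high 2 bits (0..3)
                                       int8_t  sc,          // per-sub-block scale
                                       float   d) {         // per-block scale
        const int q = (ql_nibble | (qh_two_bits << 4)) - 32; // 6-bit value in [-32,31]
        return d * sc * q;
    }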
-template <typename type4x4>
-void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q  = xb->qs;
-    device const uint8_t * qh = xb->qh;
+    void get_scales(addr_block_q_p xb, int il, thread float & dl) {
+        dl = (float)(xb->d) * xb->scales[d_loc];
+    }
+    void inner_product_pre(int il, thread float4x4 & yl){
+        sumy1 = 0.f; sumy2 = 0.f;
+        for (int i = 0; i < 8; i += 2) {
+            sumy1 += yl[i/4  ][i%4]; sumy1 += yl[i/4  ][i%4+1];
+            sumy2 += yl[2+i/4][i%4]; sumy2 += yl[2+i/4][i%4+1];
+            yl[i/4  ][i%4  ] =            yl[i/4][i%4];
+            yl[i/4  ][i%4+1] = 1/256.f  * yl[i/4][i%4+1];
+            yl[i/4+2][i%4  ] = 1/16.f   * yl[2+i/4][i%4];
+            yl[i/4+2][i%4+1] = 1/4096.f * yl[2+i/4][i%4+1];
+        }
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
-    short is = (il/4) * 2;
-    q  = q  + 32 * (il/4) + 16 * (il&1);
-    qh = qh + 16 * (il&1);
-    uint8_t ul = 1 << (il/2);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
-
-    const ushort mask = il<2 ? 0x0F : 0xF0;
-    const float qh_val = il<2 ? 16.f : 256.f;
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
-    }
+        q_offset = 32*(il/8) + 4*(il%8); h_offset = 16*(il/8) + 4*(il%4);
+        hmask1 = (il%8)<4 ? 3 : 12; hmask2 = hmask1 << 8;
+        h_coef = (il%8)<4 ? 16.f : 4.f; d_loc = 8*(il/8) + (il%8)/2;
 #else
-    q = q + 16 * (il&1);
-    device const int8_t * s = xb->scales;
-    const float dl = xb->d * s[il];
-    uint8_t m = 1<<(il*2);
-    const float coef = il<2 ? 1.f : 1.f/16.f;
-    const ushort mask = il<2 ? 0x0F : 0xF0;
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
-    }
+        q_offset = 4*il; h_offset = 4*(il&1);
+        hmask1 = il<2 ? 3 : 12; hmask2 = hmask1 << 8;
+        h_coef = il<2 ? 16.f : 4.f; d_loc = il/2;
 #endif
-}
-
-template <typename type4x4>
-void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
-    device const uint8_t * ql = (device const uint8_t *)xb->ql;
-    device const uint8_t * qh = (device const uint8_t *)xb->qh;
-    device const int8_t * scales = (device const int8_t *)xb->scales;
+    }
+    void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){
+        float dl1, dl2;
+        float sum2 = 0.f;
 #if QK_K == 256
-    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
-    qh = qh + 32*(il/8) + 16*(il&1);
-    float sc = scales[(il%2) + 2 * ((il/2))];
-    il = (il/2)%4;
+        dl1 = (float)(xb->d) * xb->scales[d_loc]; dl2 = (float)(xb->d) * xb->scales[d_loc+4];
 #else
-    ql = ql + 16 * (il&1);
-    float sc = scales[il];
+        dl1 = (float)(xb->d) * xb->scales[d_loc]; dl2 = (float)(xb->d) * xb->scales[d_loc+2];
 #endif
-    for (int i = 0; i < 16; ++i) {
-        uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-        uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
-        const float coef = il>1 ? 1.f/16.f : 1.f;
-        float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
-                         ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
-        reg[i/4][i%4] = d_all * sc * q * coef;
-    }
-}
+        addr_uint16_p ql = (addr_uint16_p)xb->ql + q_offset;
+        addr_uint16_p qh = (addr_uint16_p)xb->qh + h_offset;
+        for (int i = 0; i < 8; i+=2) {
+            sum  += yl[i/4  ][i%4  ] * ((ql[i/2]&0x000F) + (qh[i/2]&hmask1)      * h_coef);
+            sum  += yl[i/4  ][i%4+1] * ((ql[i/2]&0x0F00) + (qh[i/2]&hmask2)      * h_coef);
+            sum2 += yl[i/4+2][i%4  ] * ((ql[i/2]&0x00F0) + (qh[i/2]&(hmask1<<4)) * h_coef);
+            sum2 += yl[i/4+2][i%4+1] * ((ql[i/2]&0xF000) + (qh[i/2]&(hmask2<<4)) * h_coef);
+        }
+        sum = dl1 * (sum - 32.h * sumy1) + dl2 * (sum2 - 32.h * sumy2);
+    }
-template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
+    void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) {
+        float dl;
+        get_scales(xb, il, dl);
+        addr_uint16_p ql = (addr_uint16_p)xb->ql + q_offset;
+        addr_uint16_p qh = (addr_uint16_p)xb->qh + h_offset;
+        for (int i = 0; i < 16; i+=2) {
+            reg[i/4][i%4  ] = dl * (((ql[i/2]&lmask1) + (qh[i/2]&hmask1) * h_coef) * coef1 - 32.f);
+            reg[i/4][i%4+1] = dl * (((ql[i/2]&lmask2) + (qh[i/2]&hmask2) * h_coef) * coef2 - 32.f);
+        }
+    }
+};
+
+//============================= templates and their specializations =============================
+#define N_SIMDWIDTH 32
+
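The net effect of the rewrite so far: instead of passing a dequantize function pointer, each kernel below takes a driver class that bundles init/get_scales/inner_product/dequantize and is bound at compile time through a template-template parameter. A minimal C++ analogue of that pattern (names hypothetical, not the PR's driver interface verbatim):

    // Minimal analogue of the driver pattern: the kernel is written once against
    // a driver interface and bound to a concrete quant type at compile time.
    template <typename Block>
    struct example_driver {                 // hypothetical driver, not the PR's
        void  init(int /*il*/) {}           // would precompute masks/offsets
        float dequant_one(const Block *, int) { return 0.f; }
    };

    template <typename Block, template <typename> class Driver>
    float row_dot(const Block *x, const float *y, int n) {
        Driver<Block> drv;
        drv.init(0);
        float sum = 0.f;
        for (int i = 0; i < n; ++i) sum += drv.dequant_one(x, i) * y[i];
        return sum;
    }
    // e.g. row_dot<float, example_driver>(weights, activations, n);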
+template <typename block_q_type, short nl, template <typename> class quant_dri>
 kernel void kernel_get_rows(
     device const void * src0,
     device const int * src1,
@@ -1903,12 +1251,154 @@ kernel void kernel_get_rows(
     for (int ind = tiitg; ind < ne00/16; ind += tptg) {
         float4x4 temp;
-        dequantize_func(
-            ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
+        quant_dri dequan_worker;
+        dequan_worker.init(ind%nl);
+        dequan_worker.dequantize(
+            ((device const block_q_type *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
         *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
     }
 }
 
+// nl: Each block has 16*nl weights
+// n_shift: Each thread deals with 16 dequantized weights. However, the 16 weights may not be contiguous.
+//          n_shift is the difference between the address of the first 8 weights and the last 8 weights.
+//          (i.e. n_shift=8 means 16 contiguous weights)
+template <typename block_q_type, short nl, short nr, short nsg, short n_shift, template <typename> class quant_dri>
+kernel void kernel_mat_mv(device const void * src0,
+                          device const float * src1,
+                          device       float * dst,
+                          constant   int64_t & ne00,
+                          constant   int64_t & ne01,
+                          constant   int64_t & ne02,
+                          constant   int64_t & ne10,
+                          constant   int64_t & ne12,
+                          constant   int64_t & ne0,
+                          constant   int64_t & ne1,
+                          constant      uint & gqa,
+                          threadgroup uint   * shared_memory[[threadgroup(0)]],
+                          uint3 tgpig[[threadgroup_position_in_grid]],
+                          uint  tiisg[[thread_index_in_simdgroup]],
+                          uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nb = ne00/(nl * 16);
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int ix = tiisg / nl;
+    const int il = tiisg % nl;
+    const short blocks_size_aligned = ((N_SIMDWIDTH / nl) * sizeof(block_q_type) + 4 * (N_SIMDWIDTH / nr) - 1) \
+                                      / (4 * (N_SIMDWIDTH / nr)) * (N_SIMDWIDTH / nr);
+    const short need_align_fix = ((sizeof(block_q_type) % 4) / 2) * (nb % 2) * sgitg;
+    const int first_row = (r0 * nsg) * nr + sgitg + nsg * (tiisg / (N_SIMDWIDTH/nr));
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    const uint offset1 = r1*ne10 + im*ne00*ne1 + ix * (nl * 16) + (il/(n_shift/8))*16*(n_shift/8) + (il%(n_shift/8)) * 8;
+
+    device const block_q_type * x = (device const block_q_type *) src0 + offset0;
+    device const float        * y = (device const float        *) src1 + offset1;
+    threadgroup uint * x_st = shared_memory + blocks_size_aligned * nr * sgitg \
+                            + blocks_size_aligned * (tiisg / (N_SIMDWIDTH / nr));
+    threadgroup uint16_t * x_ld = ((threadgroup uint16_t *)(shared_memory + blocks_size_aligned * nr * sgitg)) \
+                                + need_align_fix + ix * sizeof(block_q_type) / 2;
+
+    float4x4 yl; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    quant_dri dequan_worker;
+    dequan_worker.init(il);
+
+    // each thread in a SIMD group deals with 16 dequantized weights.
+    for (int ib = ix; ib < (nb + (N_SIMDWIDTH / nl) - 1)/(N_SIMDWIDTH / nl)*(N_SIMDWIDTH / nl); ib += N_SIMDWIDTH / nl) {
+        #pragma unroll(MIN(blocks_size_aligned / (N_SIMDWIDTH / nr), 16))
+        for (int i = tiisg % (N_SIMDWIDTH / nr); i < blocks_size_aligned; i += N_SIMDWIDTH / nr) {
+            *(x_st + i) = *((device const uint *)x + i);
+        }
+        yl[0] = *((device const float4 *)y);
+        yl[1] = *((device const float4 *)y + 1);
+        yl[2] = *((device const float4 *)y + n_shift/4);
+        yl[3] = *((device const float4 *)y + n_shift/4 + 1);
+
+        dequan_worker.inner_product_pre(il, yl);
+        simdgroup_barrier(mem_flags::mem_threadgroup);
+        #pragma unroll(nr)
+        for (int row = 0; row < nr; row++) {
+            float sum_temp = 0.f;
+            simdgroup_barrier(mem_flags::mem_none);
+            dequan_worker.inner_product((threadgroup block_q_type *) \
+                (x_ld + 2 * blocks_size_aligned * row), il, yl, sum_temp);
+            sumf[row] += ib < nb ? sum_temp : 0.f;
+        }
+        x += N_SIMDWIDTH / nl;
+        y += N_SIMDWIDTH * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + nsg * row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + nsg * row] = tot;
+        }
+    }
+}
+
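The offset1 expression is the only place n_shift enters the addressing. A quick C++ illustration of the formula for two hypothetical n_shift values shows how each thread's 16 weights split into two 8-float halves n_shift floats apart:

    // offset of a thread's first 8-float half, per the kernel's offset1 term:
    // (il/(n_shift/8))*16*(n_shift/8) + (il%(n_shift/8))*8
    #include <cstdio>

    static int first_half(int il, int n_shift) {
        const int g = n_shift / 8;
        return (il / g) * 16 * g + (il % g) * 8;
    }

    int main() {
        // n_shift = 8: halves are adjacent, thread il covers floats [16*il, 16*il+16)
        printf("il=1 n_shift=8  -> %d and %d\n", first_half(1, 8),  first_half(1, 8)  + 8);  // 16 24
        // n_shift = 32: the second half sits 32 floats past the first
        printf("il=1 n_shift=32 -> %d and %d\n", first_half(1, 32), first_half(1, 32) + 32); // 8 40
        return 0;
    }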
+template <typename block_q_type, short nl, short nr, short nsg, short n_shift, template <typename> class quant_dri>
+kernel void kernel_mat_mv_no_tg_mem(device const void * src0,
+                                    device const float * src1,
+                                    device       float * dst,
+                                    constant   int64_t & ne00,
+                                    constant   int64_t & ne01,
+                                    constant   int64_t & ne02,
+                                    constant   int64_t & ne10,
+                                    constant   int64_t & ne12,
+                                    constant   int64_t & ne0,
+                                    constant   int64_t & ne1,
+                                    constant      uint & gqa,
+                                    threadgroup uint   * shared_memory[[threadgroup(0)]],
+                                    uint3 tgpig[[threadgroup_position_in_grid]],
+                                    uint  tiisg[[thread_index_in_simdgroup]],
+                                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nb = ne00/(nl * 16);
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int ix = tiisg / nl;
+    const int il = tiisg % nl;
+    const int first_row = (r0 * nsg) * nr + sgitg;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0) + ix;
+    const uint offset1 = r1*ne10 + im*ne00*ne1 + ix * (nl * 16) + (il/(n_shift/8))*16*(n_shift/8) + (il%(n_shift/8)) * 8;
+
+    device const block_q_type * x = (device const block_q_type *) src0 + offset0;
+    device const float        * y = (device const float        *) src1 + offset1;
+
+    float4x4 yl; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    quant_dri dequan_worker;
+    dequan_worker.init(il);
+
+    // each thread in a SIMD group deals with 16 dequantized weights.
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH / nl) {
+        yl[0] = *((device const float4 *)y);
+        yl[1] = *((device const float4 *)y + 1);
+        yl[2] = *((device const float4 *)y + n_shift/4);
+        yl[3] = *((device const float4 *)y + n_shift/4 + 1);
+
+        dequan_worker.inner_product_pre(il, yl);
+        #pragma unroll(nr)
+        for (int row = 0; row < nr; row++) {
+            float sum_temp = 0.f;
+            dequan_worker.inner_product(x + 2 * nb * row, il, yl, sum_temp);
+            sumf[row] += sum_temp;
+        }
+        x += N_SIMDWIDTH / nl;
+        y += N_SIMDWIDTH * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + nsg * row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + nsg * row] = tot;
+        }
+    }
+}
+
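Output-row bookkeeping in this variant follows first_row = (r0*nsg)*nr + sgitg, with a simdgroup's nr accumulator slots stepped by nsg, so with the N_SIMDGROUP=2 and N_DST=4 defaults used in the specializations below a threadgroup covers 8 consecutive rows, interleaved across its two simdgroups. A C++ check of that mapping:

    // dst row written by simdgroup sgitg, accumulator slot `row`, threadgroup r0:
    // first_row = (r0*nsg)*nr + sgitg; dst row = first_row + nsg*row
    #include <cstdio>

    int main() {
        const int nsg = 2, nr = 4, r0 = 0;  // N_SIMDGROUP, N_DST, first threadgroup
        for (int sgitg = 0; sgitg < nsg; ++sgitg)
            for (int row = 0; row < nr; ++row)
                printf("sg %d slot %d -> row %d\n",
                       sgitg, row, (r0 * nsg) * nr + sgitg + nsg * row);
        return 0;  // sg 0 -> rows 0,2,4,6; sg 1 -> rows 1,3,5,7
    }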
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
 #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
 #define BLOCK_SIZE_K 32
@@ -1921,7 +1411,7 @@ kernel void kernel_get_rows(
 #define SG_MAT_ROW 8
 
 // each block_q contains 16*nl weights
-template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+template <typename block_q_type, short nl, template <typename> class quant_dri>
 kernel void kernel_mul_mm(device const uchar * src0,
                           device const float * src1,
                           device       float * dst,
@@ -1960,14 +1450,18 @@ kernel void kernel_mul_mm(device const uchar * src0,
     short il = (tiitg % THREAD_PER_ROW);
     uint   offset0 = im/gqa*nb02;
     ushort offset1 = il/nl;
-    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const block_q_type * x = (device const block_q_type *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
     device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
                            + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
+
+
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
         //load data and store to threadgroup memory
         half4x4 temp_a;
-        dequantize_func(x, il, temp_a);
+        quant_dri dequan_worker;
+        dequan_worker.init(il);
+        dequan_worker.dequantize(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
@@ -2042,26 +1536,46 @@ kernel void kernel_mul_mm(device const uchar * src0,
 
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
-template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
-template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows;
 
 typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
                         constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
                         constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
 
-template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+
+typedef void (mat_mv_t)(device const void *, device const float *, device float *, constant int64_t &,\
+                        constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
+                        constant int64_t &, constant int64_t &, constant uint &, threadgroup uint *, uint3, uint, uint);
+
+#define N_DST 4
+#define N_SIMDGROUP 2
+template [[host_name("kernel_mul_mv_f16_f32" )]] kernel mat_mv_t kernel_mat_mv_no_tg_mem;
+template [[host_name("kernel_mul_mv_q4_0_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem;
+template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem;
+template [[host_name("kernel_mul_mv_q8_0_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem;
+template [[host_name("kernel_mul_mv_q2_K_f32")]] kernel mat_mv_t kernel_mat_mv;
+template [[host_name("kernel_mul_mv_q3_K_f32")]] kernel mat_mv_t kernel_mat_mv;
+template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv;
+template [[host_name("kernel_mul_mv_q5_K_f32")]] kernel mat_mv_t kernel_mat_mv;
+#if QK_K == 256
+template [[host_name("kernel_mul_mv_q6_K_f32")]] kernel mat_mv_t kernel_mat_mv;
+#else
+template [[host_name("kernel_mul_mv_q6_K_f32")]] kernel mat_mv_t kernel_mat_mv;
+#endif
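Because every specialization above shares the mat_mv_t signature, the host only ever deals with uniformly-typed pipelines looked up by their [[host_name]]. As a rough C++ analogue (illustrative only, not the actual ggml-metal.m host code), a shared signature is what makes a per-quant-type dispatch table possible:

    // Rough analogue: a uniform kernel signature permits one dispatch table.
    #include <cstdio>

    typedef void (*mat_mv_fn)(const void *src0, const float *src1, float *dst);

    static void mv_q4_0(const void *, const float *, float *) { puts("q4_0 path"); }
    static void mv_q6_K(const void *, const float *, float *) { puts("q6_K path"); }

    enum sketch_type { TYPE_Q4_0, TYPE_Q6_K, TYPE_COUNT };  // illustrative subset

    int main() {
        mat_mv_fn table[TYPE_COUNT] = { mv_q4_0, mv_q6_K };
        table[TYPE_Q6_K](nullptr, nullptr, nullptr);  // pick pipeline by quant type
        return 0;
    }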