diff --git a/kernels/neon/matmul_neon_int8_int4.cc b/kernels/neon/matmul_neon_int8_int4.cc
index 3fa783d..ba463b0 100644
--- a/kernels/neon/matmul_neon_int8_int4.cc
+++ b/kernels/neon/matmul_neon_int8_int4.cc
@@ -232,6 +232,104 @@ static void* matmul_int8_int4_no_offset_over_column(void* args) {
     return NULL;
 }
 
+inline static void* gemv_int8_int4_no_offset_over_column_unroll128(void* args) {
+    struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args;
+    const struct matmul_params* params = mat_args->params;
+    int n = params->C.column, m = params->C.row, k = params->A.column, block_size = params->block_size;
+    const int num_block = k / block_size;
+    assert(m == 1);
+
+    for (int j = mat_args->start_j; j < mat_args->end_j; j++) {
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+        float32x4_t sumv1 = vdupq_n_f32(0.0f);
+        float32x4_t sumv2 = vdupq_n_f32(0.0f);
+        float32x4_t sumv3 = vdupq_n_f32(0.0f);
+        const unsigned char* w_start = &params->B.int4_data_ptr[j * k / 2];
+        const signed char* a_start = &params->A.int8_data_ptr[0];
+        float* s_a = &params->A_scales[0];
+        float* s_w = &params->scales[j * k / 32];
+
+        const uint8x16_t mask_low4bit = vdupq_n_u8(0xf);
+        const int8x16_t offsets = vdupq_n_s8(8);
+        for (int q = 0; q < num_block; q += 4) {
+            int32x4_t int_sum0 = vdupq_n_s32(0);
+            int32x4_t int_sum1 = vdupq_n_s32(0);
+            int32x4_t int_sum2 = vdupq_n_s32(0);
+            int32x4_t int_sum3 = vdupq_n_s32(0);
+            float s_0 = *s_a++ * *s_w++;
+            float s_1 = *s_a++ * *s_w++;
+            float s_2 = *s_a++ * *s_w++;
+            float s_3 = *s_a++ * *s_w++;
+
+            const uint8x16_t w0 = vld1q_u8(w_start);       // 32 4-bit weights
+            const uint8x16_t w1 = vld1q_u8(w_start + 16);  // 32 4-bit weights
+            const uint8x16_t w2 = vld1q_u8(w_start + 32);  // 32 4-bit weights
+            const uint8x16_t w3 = vld1q_u8(w_start + 48);  // 32 4-bit weights
+            w_start += 64;
+
+            // Quantization Method QM_ARM, convert 128 4-bit weights to 128 8-bit
+            // sequential: (0, 1), (2, 3), (4, 5), (6, 7)... : 128 bit
+            // expected layout of inB: (0, 16), (1, 17), (2, 18), (3, 19)...
+            // low:  (0, 0), (1, 0), (2, 0), (3, 0) ...
+            // high: (16, 0), (17, 0), (18, 0), (19, 0) ...
+            int8x16_t w0_low = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit));
+            int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4));
+            int8x16_t w1_low = vreinterpretq_s8_u8(vandq_u8(w1, mask_low4bit));
+            int8x16_t w1_high = vreinterpretq_s8_u8(vshrq_n_u8(w1, 4));
+            int8x16_t w2_low = vreinterpretq_s8_u8(vandq_u8(w2, mask_low4bit));
+            int8x16_t w2_high = vreinterpretq_s8_u8(vshrq_n_u8(w2, 4));
+            int8x16_t w3_low = vreinterpretq_s8_u8(vandq_u8(w3, mask_low4bit));
+            int8x16_t w3_high = vreinterpretq_s8_u8(vshrq_n_u8(w3, 4));
+
+            // apply offset
+            w0_low = vsubq_s8(w0_low, offsets);
+            w0_high = vsubq_s8(w0_high, offsets);
+            w1_low = vsubq_s8(w1_low, offsets);
+            w1_high = vsubq_s8(w1_high, offsets);
+            w2_low = vsubq_s8(w2_low, offsets);
+            w2_high = vsubq_s8(w2_high, offsets);
+            w3_low = vsubq_s8(w3_low, offsets);
+            w3_high = vsubq_s8(w3_high, offsets);
+
+            // load 128 8-bit activations
+            const int8x16_t a0 = vld1q_s8(a_start);
+            const int8x16_t a1 = vld1q_s8(a_start + 16);
+            const int8x16_t a2 = vld1q_s8(a_start + 32);
+            const int8x16_t a3 = vld1q_s8(a_start + 48);
+            const int8x16_t a4 = vld1q_s8(a_start + 64);
+            const int8x16_t a5 = vld1q_s8(a_start + 80);
+            const int8x16_t a6 = vld1q_s8(a_start + 96);
+            const int8x16_t a7 = vld1q_s8(a_start + 112);
+            a_start += 128;
+
+            // dot product into int32x4_t
+            int_sum0 = my_vdotq_s32(int_sum0, w0_low, a0);
+            int_sum0 = my_vdotq_s32(int_sum0, w0_high, a1);
+            int_sum1 = my_vdotq_s32(int_sum1, w1_low, a2);
+            int_sum1 = my_vdotq_s32(int_sum1, w1_high, a3);
+            int_sum2 = my_vdotq_s32(int_sum2, w2_low, a4);
+            int_sum2 = my_vdotq_s32(int_sum2, w2_high, a5);
+            int_sum3 = my_vdotq_s32(int_sum3, w3_low, a6);
+            int_sum3 = my_vdotq_s32(int_sum3, w3_high, a7);
+
+            sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0);
+            sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(int_sum1), s_1);
+            sumv2 = vmlaq_n_f32(sumv2, vcvtq_f32_s32(int_sum2), s_2);
+            sumv3 = vmlaq_n_f32(sumv3, vcvtq_f32_s32(int_sum3), s_3);
+        }
+        if (params->bias.data_ptr) {
+            params->C.data_ptr[j] = params->bias.data_ptr[j] + vaddvq_f32(sumv0) + vaddvq_f32(sumv1) +
+                                    vaddvq_f32(sumv2) + vaddvq_f32(sumv3);
+        }
+        else {
+            params->C.data_ptr[j] =
+                vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + vaddvq_f32(sumv2) + vaddvq_f32(sumv3);
+        }
+    }
+
+    return NULL;
+}
+
 inline static void* matmul_int8_int4_no_offset_over_column_unroll128(void* args) {
     struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args;
     const struct matmul_params* params = mat_args->params;
@@ -594,4 +692,36 @@ void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_
     // for (j = 0; j < num_thread; j++) pthread_join(thread_pool[j], NULL);
     pool_wait(pool);
 };
+
+void MatmulOperator::gemv_accelerator_int8_int4_fast_no_offset(struct matmul_params* params) {
+    int i, j, k;
+    const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
+    const int block_size = params->block_size;
+    float *scale = params->scales, *offset = params->offset;
+    assert(params->block_size % 32 == 0);  // block size must be a multiple of 32
+    assert(A->row == C->row);              // A and C must have the same number of rows
+    assert(A->row == 1);                   // GEMV: a single activation row
+
+    quantize_fp32_to_int8(A->data_ptr, A->int8_data_ptr, params->A_scales, A->row * A->column, block_size);
+
+    const int num_thread = params->opt_params.num_thread;
+    struct a8w4_thread_args threads_args[num_thread];
+    assert(params->block_size == 32);  // support block size 32 for now
+
+    static void *pool = pool_start(gemv_int8_int4_no_offset_over_column_unroll128, num_thread);
+
+    // Thread creation
+    for (j = 0; j < num_thread; j++) {
+        threads_args[j].start_j = j * (params->C.column / num_thread);
+        if (j == num_thread - 1) {
+            threads_args[j].end_j = params->C.column;
+        } else {
+            threads_args[j].end_j = (j + 1) * (params->C.column / num_thread);
+        }
+        threads_args[j].params = params;
+        pool_enqueue(pool, &threads_args[j], '\0');
+    }
+    // Join threads
+    pool_wait(pool);
+};
 } // namespace matmul
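
Note (not part of the patch): below is a minimal scalar sketch of the per-block math the unrolled NEON loop above implements, assuming the QM_ARM layout described in its comments: each 16-byte group packs one 32-weight block, with weight i in the low nibble of byte i, weight i+16 in the high nibble, and both nibbles stored with a +8 offset. The helper name `dot_block_qm_arm_ref` is hypothetical.

```c
#include <stdint.h>

// Hypothetical scalar reference for one 32-weight QM_ARM block (illustration only).
// packed: 16 bytes; byte i holds weight i (low nibble) and weight i+16 (high nibble),
//         each stored with a +8 offset so 0..15 maps back to -8..7.
// act:    the 32 int8 activations of the same block.
// s_w, s_a: per-block weight and activation scales (their product is s_0..s_3 above).
static float dot_block_qm_arm_ref(const uint8_t packed[16], const int8_t act[32],
                                  float s_w, float s_a) {
    int32_t acc = 0;
    for (int i = 0; i < 16; ++i) {
        int32_t w_lo = (int32_t)(packed[i] & 0x0F) - 8;  // weight i
        int32_t w_hi = (int32_t)(packed[i] >> 4) - 8;    // weight i + 16
        acc += w_lo * act[i] + w_hi * act[i + 16];       // integer dot product
    }
    return (float)acc * s_w * s_a;  // dequantize once per block
}
```

The kernel processes four such blocks per iteration (w0..w3 against a0..a7), keeps the accumulation in int32 lanes via `my_vdotq_s32`, and applies the per-block scale once per block with `vmlaq_n_f32`; the final `vaddvq_f32` reductions plus the optional bias give C[j].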