Skip to content

Commit

Permalink
Fix bugs for ARM platforms (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
RaymondWang0 authored Apr 19, 2024
1 parent d46a858 commit 9d14ae7
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 38 deletions.
80 changes: 46 additions & 34 deletions kernels/neon/matmul_neon_fp32.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
#include <stdio.h>
#include <cmath>
#include <cstdlib>
#include <Accelerate/Accelerate.h>
// #include <omp.h>
#include <arm_neon.h>

#ifdef USE_ACCELERATE
#include <Accelerate/Accelerate.h>
#endif

#include "common.h"
#include "../matmul.h"
#include "pthread_pool.h"
Expand Down Expand Up @@ -38,6 +41,29 @@ void fp32_ref_matmul(const struct matmul_params *params) {
}
}

void fp32_ref_matmul_bias(const struct matmul_params *params) {
const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
float *bias = params->bias.data_ptr;
float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr;

assert(A->column == B->row);
assert(C->row == A->row);
assert(C->column == B->column);
int m = A->row, n = B->column, k = A->column;

for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
float acc = 0;
for (int kk = 0; kk < k; kk++) {
acc += data_A[i * k + kk] * data_B[j * k + kk];
}
acc = acc + bias[j];
data_C[i * n + j] = acc;
}
}
}

#ifdef USE_ACCELERATE
inline void fp32_matmul_transposed_cblas_gemm(const struct matmul_params *params) {
const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr;
Expand All @@ -55,11 +81,6 @@ inline void fp32_matmul_transposed_cblas_gemm(const struct matmul_params *params
0.0f, data_C, n);
}

void MatmulOperator::mat_mul_accelerator_transposed_fastover_column(const struct matmul_params *params) {
// fp32_ref_matmul(params);
fp32_matmul_transposed_cblas_gemm(params);
}

inline void fp32_matmul_untransposed_cblas_gemm(const struct matmul_params *params) {
const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr;
Expand All @@ -76,32 +97,6 @@ inline void fp32_matmul_untransposed_cblas_gemm(const struct matmul_params *para
0.0f, data_C, n);
}

void MatmulOperator::mat_mul_accelerator_untransposed_fastover_column(const struct matmul_params *params) {
fp32_matmul_untransposed_cblas_gemm(params);
}

// Reference (scalar) matmul with bias: C = A * B^T + bias.
// A is m x k, B is stored transposed as n x k, C is m x n; bias holds one
// float per output column (length n).
void fp32_ref_matmul_bias(const struct matmul_params *params) {
    const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
    float *bias = params->bias.data_ptr;
    float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr;

    // Shape checks: B->row is the shared inner dimension because B is
    // stored transposed.
    assert(A->column == B->row);
    assert(C->row == A->row);
    assert(C->column == B->column);
    int m = A->row, n = B->column, k = A->column;

    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            float acc = 0;
            // Dot product of row i of A with row j of (transposed) B.
            for (int kk = 0; kk < k; kk++) {
                acc += data_A[i * k + kk] * data_B[j * k + kk];
            }
            // Bias is added once per output element, after accumulation.
            acc = acc + bias[j];
            data_C[i * n + j] = acc;
        }
    }
}

void fp32_matmul_bias_cblas_gemm(const struct matmul_params *params) {
// struct fp32_thread_args* mat_args = (struct fp32_thread_args*)args;
const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
Expand All @@ -123,6 +118,21 @@ void fp32_matmul_bias_cblas_gemm(const struct matmul_params *params) {
vDSP_vadd(bias, 1, data_C + i * n, 1, data_C + i * n, 1, n);
}
}
#endif

// Transposed-B fp32 matmul dispatcher: route to Apple Accelerate's cblas
// kernel when built with USE_ACCELERATE, otherwise run the portable
// scalar reference implementation.
void MatmulOperator::mat_mul_accelerator_transposed_fastover_column(const struct matmul_params *params) {
#ifndef USE_ACCELERATE
    fp32_ref_matmul(params);
#else
    fp32_matmul_transposed_cblas_gemm(params);
#endif
}

void MatmulOperator::mat_mul_accelerator_untransposed_fastover_column(const struct matmul_params *params) {
#ifdef USE_ACCELERATE
fp32_matmul_untransposed_cblas_gemm(params);
#endif
}

inline static void* fp32_matmul_bias_optimized_gemm(void* args) {
struct fp32_thread_args* mat_args = (struct fp32_thread_args*)args;
Expand Down Expand Up @@ -251,9 +261,11 @@ inline static void* fp32_matmul_bias_optimized_gemm(void* args) {
}

void MatmulOperator::mat_mul_accelerator_transposed_fastover_column_bias(const struct matmul_params *params) {
// fp32_ref_matmul_bias(params);

#ifdef USE_ACCELERATE
fp32_matmul_bias_cblas_gemm(params);
#else
fp32_ref_matmul_bias(params);
#endif

// int i, j, k;
// const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
Expand Down
9 changes: 8 additions & 1 deletion kernels/neon/matmul_neon_int8_int4.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
#include <stdio.h>
#include <cmath>
#include <cstdlib>
#include <Accelerate/Accelerate.h>
#include <arm_neon.h>

#ifdef USE_ACCELERATE
#include <Accelerate/Accelerate.h>
#endif

#include "../matmul.h"
#include "common.h"
#include "pthread_pool.h"
Expand Down Expand Up @@ -1265,6 +1268,7 @@ static void* matmul_int8_int4_no_offset_over_column_packed(void* args) {
return NULL;
}

#ifdef USE_ACCELERATE
inline static void* fp32_matmul_transposed_cblas_gemm(void* args) {
struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args;
const struct matmul_params* params = mat_args->params;
Expand All @@ -1286,6 +1290,7 @@ inline static void* fp32_matmul_transposed_cblas_gemm(void* args) {

return NULL;
}
#endif

namespace matmul {
void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_params* params) {
Expand Down Expand Up @@ -1433,6 +1438,7 @@ void MatmulOperator::gemm_accelerator_int8_int4_fast_no_offset_v2(struct matmul_
pool_wait(pool);
};

#ifdef USE_ACCELERATE
void MatmulOperator::cblas_gemm_accelerator_no_offset(struct matmul_params* params) {
int i, j, k;
const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
Expand Down Expand Up @@ -1470,5 +1476,6 @@ void MatmulOperator::cblas_gemm_accelerator_no_offset(struct matmul_params* para
// Join threads
pool_wait(pool);
};
#endif

} // namespace matmul
4 changes: 2 additions & 2 deletions llm/src/ops/BMM_F32T.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void BMM_F32T::forward(const Matrix3D<float> &a, const Matrix3D<float> &weight,
// op.mat_mul_transposed_fastover_column((const struct matmul_params
// *)&params);
// else
#ifdef QM_ARM
#ifdef USE_ACCELERATE
op.mat_mul_accelerator_transposed_fastover_column(&params);
#else
op.mat_mul_transposed(&params); // TODO: optimize this
Expand Down Expand Up @@ -87,7 +87,7 @@ void BMM_F32T::forward_weight_untransposed(const Matrix3D<float> &a, const Matri
for (int i = 0; i < m * n * a.m_dim_x; i++) {
params.C.data_ptr[i] = 0;
}
#ifdef QM_ARM
#ifdef USE_ACCELERATE
for (int bz = 0; bz < a.m_dim_x; bz++) {
op.mat_mul_accelerator_untransposed_fastover_column(&params);
params.A.data_ptr += m * k;
Expand Down
2 changes: 1 addition & 1 deletion llm/src/ops/linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void linear(Matrix3D<T> &a, Matrix3D<T> &b, Matrix3D<T> &c) {
}
}

#ifdef QM_ARM
#ifdef USE_ACCELERATE
#define MAX_WEIGHT_BUFFER 32000 * 4096
static float *w_fp32;
void Linear_FP_int4::initialize_weight_memory() {
Expand Down

0 comments on commit 9d14ae7

Please sign in to comment.