Add Phi-2 architecture support
This change adds support for Phi-2, a new 2.7B model recently announced by
Microsoft, which is claimed to offer quality comparable to much larger models.

This change cherry-picks ggerganov/llama.cpp@b9e74f9. To support it, our
tinyBLAS library needed to be extended to support float32 output arrays.
I've confirmed that tinyBLAS still produces outputs consistent with cuBLAS.
However, our Phi-2 output isn't identical to llama.cpp's, and running Phi-2
on Apple Metal vs. CUDA also produces different output, unlike llama.cpp.
This issue will probably resolve itself the next time a full upstream
synchronization happens. The output we produce, while different, still
appears reasonable, although I haven't measured perplexity scores.

Fixes #145
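
For context, the precision override this cherry-pick relies on is exposed through ggml_mul_mat_set_prec(), added to llama.cpp/ggml.c in the diff below; the CUDA path checks dst->op_params[0] against GGML_PREC_DEFAULT to choose between the fp16 and fp32 cuBLAS routes, which is also why tinyBLAS needed float32 output support. The following is a minimal sketch of how a caller might request float32 precision for a matmul, as the upstream Phi-2 graph does for its attention scores. The helper name and usage are illustrative assumptions, not code from this commit; GGML_PREC_F32 is the non-default value of the ggml_prec enum introduced upstream.

// Minimal sketch (illustrative, not from this commit): request f32 output
// precision for a matmul so the backend keeps the result in float32
// instead of taking the default fp16 path.
#include "ggml.h"

static struct ggml_tensor * mul_mat_f32(struct ggml_context * ctx,
                                        struct ggml_tensor  * a,
                                        struct ggml_tensor  * b) {
    // build a standard ggml matmul node
    struct ggml_tensor * cur = ggml_mul_mat(ctx, a, b);
    // opt out of fp16 accumulation for this node
    ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
    return cur;
}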
jart committed Dec 29, 2023
1 parent 8f73d39 commit 6423228
Showing 7 changed files with 470 additions and 109 deletions.
133 changes: 92 additions & 41 deletions llama.cpp/ggml-cuda.cu
@@ -208,11 +208,6 @@ static enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
}

#define ggml_nbytes_split ggml_nbytes_split_
static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
return (nrows_split*tensor->ne[0]*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type);
}

#define ggml_nelements ggml_nelements_
static int64_t ggml_nelements(const struct ggml_tensor * tensor) {
return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -252,6 +247,11 @@ static size_t ggml_element_size(const struct ggml_tensor * tensor) {
return ggml_type_size(tensor->type);
}

#define ggml_row_size ggml_row_size_
static size_t ggml_row_size(enum ggml_type type, int64_t ne) {
return ggml_type_size(type)*ne/ggml_blck_size(type);
}

static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
}
@@ -5328,7 +5328,16 @@ static __global__ void rope_neox(
const int ib = col / n_dims;
const int ic = col % n_dims;

const int i = row*ncols + ib*n_dims + ic/2;
if (ib > 0) {
const int i = row*ncols + ib*n_dims + ic;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;

float cur_rot = inv_ndims * ic - ib;
@@ -7389,6 +7398,7 @@ inline void ggml_cuda_op_upscale(

(void) src1;
(void) dst;
(void) src1_dd;
}

inline void ggml_cuda_op_pad(
@@ -7405,6 +7415,7 @@ inline void ggml_cuda_op_pad(

(void) src1;
(void) dst;
(void) src1_dd;
}

inline void ggml_cuda_op_rms_norm(
@@ -7709,7 +7720,7 @@ inline void ggml_cuda_op_mul_mat_cublas(

const int compute_capability = g_compute_capabilities[id];

if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
half * src0_as_f16 = nullptr;
size_t src0_as = 0;
@@ -8633,27 +8644,27 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
}

static __global__ void k_compute_batched_ptrs(
const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
const half * src0_as_f16, const half * src1_as_f16, char * dst,
const void ** ptrs_src, void ** ptrs_dst,
int ne12, int ne13,
int ne23,
int nb02, int nb03,
int nb12, int nb13,
int nb2, int nb3,
int r2, int r3) {
int i13 = blockIdx.x * blockDim.x + threadIdx.x;
int i12 = blockIdx.y * blockDim.y + threadIdx.y;
int64_t ne12, int64_t ne13,
int64_t ne23,
size_t nb02, size_t nb03,
size_t nb12, size_t nb13,
size_t nbd2, size_t nbd3,
int64_t r2, int64_t r3) {
int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;

if (i13 >= ne13 || i12 >= ne12) {
return;
}

int i03 = i13 / r3;
int i02 = i12 / r2;
int64_t i03 = i13 / r3;
int64_t i02 = i12 / r2;

ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
}

static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -8710,7 +8721,41 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);

size_t dst_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);

half * dst_f16 = nullptr;
char * dst_t = nullptr;

cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F;

// dst strides
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];

const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;

const void * alpha = &alpha_f16;
const void * beta = &beta_f16;

if (dst->op_params[0] == GGML_PREC_DEFAULT) {
dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
dst_t = (char *) dst_f16;

nbd2 /= sizeof(float) / sizeof(half);
nbd3 /= sizeof(float) / sizeof(half);
} else {
dst_t = (char *) dst_ddf;

cu_compute_type = CUBLAS_COMPUTE_32F;
cu_data_type = CUDA_R_32F;

alpha = &alpha_f32;
beta = &beta_f32;
}

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
@@ -8719,9 +8764,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;

const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

#if 0
// use cublasGemmEx
{
@@ -8731,12 +8773,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
int i02 = i12 / r2;

CUBLAS_CHECK(
cublasGemmEx(CUBLAS_HANDLE(id), CUBLAS_OP_T, CUBLAS_OP_N,
cublasGemmEx(CUBLAS_HANDLE(g_main_device), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
&beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
CUBLAS_COMPUTE_16F,
alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}
@@ -8748,11 +8790,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(CUBLAS_HANDLE(g_main_device), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
&beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
ne12*ne13,
CUBLAS_COMPUTE_16F,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// use cublasGemmBatchedEx
@@ -8769,24 +8811,24 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const

dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
src0_as_f16, src1_as_f16, dst_f16,
src0_as_f16, src1_as_f16, dst_t,
ptrs_src, ptrs_dst,
ne12, ne13,
ne23,
nb02, nb03,
nb12, nb13,
dst->nb[2], dst->nb[3],
nbd2, nbd3,
r2, r3);
CUDA_CHECK(cudaGetLastError());

CUBLAS_CHECK(
cublasGemmBatchedEx(CUBLAS_HANDLE(g_main_device), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
&beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
beta, ( void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
ne23,
CUBLAS_COMPUTE_16F,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

if (ptrs_src_s != 0) {
@@ -8798,11 +8840,14 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
}
#endif

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);

ggml_cuda_pool_free(dst_f16, dst_as);
}

ggml_cuda_pool_free(src1_as_f16, src1_as);
ggml_cuda_pool_free(dst_f16, dst_as);
}

static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9232,6 +9277,12 @@ static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, gg
(void) dst;
}

static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
}

void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
const int64_t nrows = ggml_nrows(tensor);

13 changes: 11 additions & 2 deletions llama.cpp/ggml-metal.metal
@@ -1702,8 +1702,9 @@ kernel void kernel_rope(
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
if (ic < n_dims) {
const int64_t ib = 0;

// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;
Expand All @@ -1722,6 +1723,14 @@ kernel void kernel_rope(

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const int64_t i0 = ic;

device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
52 changes: 40 additions & 12 deletions llama.cpp/ggml.c
@@ -2056,12 +2056,6 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN);
}

size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

return (nrows_split*tensor->ne[0]*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type);
}

int ggml_blck_size(enum ggml_type type) {
return type_traits[type].blck_size;
}
@@ -4146,6 +4140,14 @@ struct ggml_tensor * ggml_mul_mat(
return result;
}

void ggml_mul_mat_set_prec(
struct ggml_tensor * a,
enum ggml_prec prec) {
const int32_t prec_i32 = (int32_t) prec;

ggml_set_op_params_i32(a, 0, prec_i32);
}

// ggml_mul_mat_id

struct ggml_tensor * ggml_mul_mat_id(
@@ -9217,6 +9219,8 @@ static void ggml_compute_forward_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

GGML_ASSERT(eps > 0.0f);

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -9286,6 +9290,8 @@ static void ggml_compute_forward_rms_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

GGML_ASSERT(eps > 0.0f);

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -11453,10 +11459,13 @@ static void ggml_compute_forward_rope_f32(
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t ib = 0;

// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;

@@ -11479,6 +11488,14 @@ static void ggml_compute_forward_rope_f32(

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const int64_t i0 = ic;

const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
@@ -11606,10 +11623,13 @@ static void ggml_compute_forward_rope_f16(
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t ib = 0;

// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;

@@ -11632,6 +11652,14 @@ static void ggml_compute_forward_rope_f16(

dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
} else {
const int64_t i0 = ic;

const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
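
The ggml_compute_forward_rope_f32/f16 and kernel_rope changes above implement partial rotary embeddings for the NeoX path: only the first n_dims elements of each row are rotated, and the remaining ne0 - n_dims elements are copied through unchanged (Phi-2 applies RoPE to 32 of its 80 head dimensions). The sketch below shows that control flow on a single flat row; the theta schedule is a simplified assumption (plain base^(-ic/n_dims) scaling, no YaRN correction) and not the exact ggml indexing.

// Simplified sketch of partial NeoX RoPE on one row of length ne0:
// rotate the first n_dims elements, pass the rest through unchanged.
#include <math.h>

static void rope_neox_partial(float * dst, const float * src,
                              int ne0, int n_dims,
                              float pos, float freq_base) {
    for (int ic = 0; ic < ne0; ic += 2) {
        if (ic < n_dims) {
            // NeoX pairing: element ic/2 rotates with element ic/2 + n_dims/2
            const float theta = pos * powf(freq_base, -(float) ic / n_dims);
            const float c = cosf(theta);
            const float s = sinf(theta);
            const float x0 = src[ic/2];
            const float x1 = src[ic/2 + n_dims/2];
            dst[ic/2]            = x0*c - x1*s;
            dst[ic/2 + n_dims/2] = x0*s + x1*c;
        } else {
            // beyond n_dims: copy the pair through untouched
            dst[ic + 0] = src[ic + 0];
            dst[ic + 1] = src[ic + 1];
        }
    }
}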
(Diffs for the remaining 4 changed files were not loaded.)
