Skip to content

Commit

Permalink
llama : add phi-2 + fix NeoX rope + ggml_mul_mat_set_prec (#4490)
Browse files Browse the repository at this point in the history
* phi2 implementation

* fix breaking change

* phi-2 : various fixes

* phi-2 : use layer norm eps

* py : whitespaces

* llama : fix meta KV override bug

* convert : phi don't add BOS token

* convert : revert "added_tokens_decoder" change

* phi-2 : scale Q instead of KQ for better precision

* ggml : fix NeoX rope to rotate just first n_dims

* cuda : less diff in the rope_neox kernel

* ggml : add ggml_mul_mat_set_prec

ggml-ci

* Update ggml-cuda.cu

Co-authored-by: slaren <slarengh@gmail.com>

* Update ggml-cuda.cu

Co-authored-by: slaren <slarengh@gmail.com>

* cuda : ggml_cuda_op_mul_mat_cublas support F32 precision

* cuda : remove oboslete comment

---------

Co-authored-by: Ebey Abraham <ebeyabraham@microsoft.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
  • Loading branch information
4 people authored Dec 18, 2023
1 parent 3c04bf6 commit b9e74f9
Show file tree
Hide file tree
Showing 9 changed files with 463 additions and 76 deletions.
22 changes: 22 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def from_model_architecture(model_architecture):
return QwenModel
if model_architecture == "MixtralForCausalLM":
return MixtralModel
if model_architecture == "PhiForCausalLM":
return Phi2Model
return Model

def _is_model_safetensors(self) -> bool:
Expand Down Expand Up @@ -221,6 +223,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.QWEN
if arch == "MixtralForCausalLM":
return gguf.MODEL_ARCH.LLAMA
if arch == "PhiForCausalLM":
return gguf.MODEL_ARCH.PHI2

raise NotImplementedError(f'Architecture "{arch}" not supported!')

Expand Down Expand Up @@ -980,6 +984,24 @@ def write_tensors(self):
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)


class Phi2Model(Model):
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]

self.gguf_writer.add_name("Phi2")
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_add_bos_token(False)


###### CONVERSION LOGIC ######


Expand Down
117 changes: 81 additions & 36 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4998,7 +4998,16 @@ static __global__ void rope_neox(
const int ib = col / n_dims;
const int ic = col % n_dims;

const int i = row*ncols + ib*n_dims + ic/2;
if (ib > 0) {
const int i = row*ncols + ib*n_dims + ic;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;

float cur_rot = inv_ndims * ic - ib;
Expand Down Expand Up @@ -7057,6 +7066,7 @@ inline void ggml_cuda_op_upscale(

(void) src1;
(void) dst;
(void) src1_dd;
}

inline void ggml_cuda_op_pad(
Expand All @@ -7073,6 +7083,7 @@ inline void ggml_cuda_op_pad(

(void) src1;
(void) dst;
(void) src1_dd;
}

inline void ggml_cuda_op_rms_norm(
Expand Down Expand Up @@ -7376,7 +7387,7 @@ inline void ggml_cuda_op_mul_mat_cublas(

const int compute_capability = g_compute_capabilities[id];

if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
half * src0_as_f16 = nullptr;
size_t src0_as = 0;
Expand Down Expand Up @@ -8300,27 +8311,27 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
}

static __global__ void k_compute_batched_ptrs(
const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
const half * src0_as_f16, const half * src1_as_f16, char * dst,
const void ** ptrs_src, void ** ptrs_dst,
int ne12, int ne13,
int ne23,
int nb02, int nb03,
int nb12, int nb13,
int nb2, int nb3,
int r2, int r3) {
int i13 = blockIdx.x * blockDim.x + threadIdx.x;
int i12 = blockIdx.y * blockDim.y + threadIdx.y;
int64_t ne12, int64_t ne13,
int64_t ne23,
size_t nb02, size_t nb03,
size_t nb12, size_t nb13,
size_t nbd2, size_t nbd3,
int64_t r2, int64_t r3) {
int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;

if (i13 >= ne13 || i12 >= ne12) {
return;
}

int i03 = i13 / r3;
int i02 = i12 / r2;
int64_t i03 = i13 / r3;
int64_t i02 = i12 / r2;

ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
}

static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
Expand Down Expand Up @@ -8376,7 +8387,41 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);

size_t dst_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);

half * dst_f16 = nullptr;
char * dst_t = nullptr;

cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F;

// dst strides
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];

const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;

const void * alpha = &alpha_f16;
const void * beta = &beta_f16;

if (dst->op_params[0] == GGML_PREC_DEFAULT) {
dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
dst_t = (char *) dst_f16;

nbd2 /= sizeof(float) / sizeof(half);
nbd3 /= sizeof(float) / sizeof(half);
} else {
dst_t = (char *) dst_ddf;

cu_compute_type = CUBLAS_COMPUTE_32F;
cu_data_type = CUDA_R_32F;

alpha = &alpha_f32;
beta = &beta_f32;
}

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
Expand All @@ -8385,9 +8430,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;

const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

#if 0
// use cublasGemmEx
{
Expand All @@ -8397,12 +8439,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
int i02 = i12 / r2;

CUBLAS_CHECK(
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
&beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
CUBLAS_COMPUTE_16F,
alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}
Expand All @@ -8414,11 +8456,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
&beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
ne12*ne13,
CUBLAS_COMPUTE_16F,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// use cublasGemmBatchedEx
Expand All @@ -8435,24 +8477,24 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const

dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
src0_as_f16, src1_as_f16, dst_f16,
src0_as_f16, src1_as_f16, dst_t,
ptrs_src, ptrs_dst,
ne12, ne13,
ne23,
nb02, nb03,
nb12, nb13,
dst->nb[2], dst->nb[3],
nbd2, nbd3,
r2, r3);
CUDA_CHECK(cudaGetLastError());

CUBLAS_CHECK(
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
&beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
beta, ( void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
ne23,
CUBLAS_COMPUTE_16F,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

if (ptrs_src_s != 0) {
Expand All @@ -8464,11 +8506,14 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
}
#endif

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);

ggml_cuda_pool_free(dst_f16, dst_as);
}

ggml_cuda_pool_free(src1_as_f16, src1_as);
ggml_cuda_pool_free(dst_f16, dst_as);
}

static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
Expand Down
13 changes: 11 additions & 2 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1702,8 +1702,9 @@ kernel void kernel_rope(
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
if (ic < n_dims) {
const int64_t ib = 0;

// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;
Expand All @@ -1722,6 +1723,14 @@ kernel void kernel_rope(

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const int64_t i0 = ic;

device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
Expand Down
46 changes: 40 additions & 6 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -4098,6 +4098,14 @@ struct ggml_tensor * ggml_mul_mat(
return result;
}

void ggml_mul_mat_set_prec(
struct ggml_tensor * a,
enum ggml_prec prec) {
const int32_t prec_i32 = (int32_t) prec;

ggml_set_op_params_i32(a, 0, prec_i32);
}

// ggml_mul_mat_id

struct ggml_tensor * ggml_mul_mat_id(
Expand Down Expand Up @@ -9168,6 +9176,8 @@ static void ggml_compute_forward_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

GGML_ASSERT(eps > 0.0f);

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
Expand Down Expand Up @@ -9237,6 +9247,8 @@ static void ggml_compute_forward_rms_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

GGML_ASSERT(eps > 0.0f);

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
Expand Down Expand Up @@ -11562,10 +11574,13 @@ static void ggml_compute_forward_rope_f32(
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t ib = 0;

// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;

Expand All @@ -11588,6 +11603,14 @@ static void ggml_compute_forward_rope_f32(

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const int64_t i0 = ic;

const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
Expand Down Expand Up @@ -11715,10 +11738,13 @@ static void ggml_compute_forward_rope_f16(
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t ib = 0;

// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;

Expand All @@ -11741,6 +11767,14 @@ static void ggml_compute_forward_rope_f16(

dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
} else {
const int64_t i0 = ic;

const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
Expand Down
Loading

0 comments on commit b9e74f9

Please sign in to comment.