Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : grouped-query attention + LLaMAv2 70B support #2276

Merged
merged 6 commits into from
Jul 23, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 43 additions & 19 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
@dataclass
class Params:
n_vocab: int
n_embd: int
n_mult: int
n_head: int
n_embd: int
n_mult: int
n_head: int
n_layer: int

@staticmethod
Expand All @@ -167,40 +167,65 @@ def guessed(model: 'LazyModel') -> 'Params':
n_head=n_embd // 128 # guessed

return Params(
n_vocab=n_vocab,
n_embd=n_embd,
n_mult=256,
n_head=n_head,
n_layer=n_layer,
n_vocab = n_vocab,
n_embd = n_embd,
n_mult = 256,
n_head = n_head,
n_layer = n_layer,
)

@staticmethod
def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
config = json.load(open(config_path))

n_vocab = config["vocab_size"];
n_embd = config["hidden_size"];
n_head = config["num_attention_heads"];
n_embd = config["hidden_size"];
n_head = config["num_attention_heads"];
n_layer = config["num_hidden_layers"];
n_ff = config["intermediate_size"];
n_ff = config["intermediate_size"];

n_mult = find_n_mult(n_ff, n_embd);

return Params(
n_vocab=n_vocab,
n_embd=n_embd,
n_mult=n_mult,
n_head=n_head,
n_layer=n_layer,
n_vocab = n_vocab,
n_embd = n_embd,
n_mult = n_mult,
n_head = n_head,
n_layer = n_layer,
)

# LLaMA v2 70B params.json
# {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1
@staticmethod
def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
config = json.load(open(config_path))

n_vocab = config["vocab_size"];
n_embd = config["dim"];
n_head = config["n_heads"];
n_layer = config["n_layers"];
n_mult = config["multiple_of"];

if n_vocab == -1:
n_vocab = model["tok_embeddings.weight"].shape[0]

return Params(
n_vocab = n_vocab,
n_embd = n_embd,
n_mult = n_mult,
n_head = n_head,
n_layer = n_layer,
)

@staticmethod
def load(model_plus: 'ModelPlus') -> 'Params':
orig_config_path = model_plus.paths[0].parent / "params.json"
hf_transformer_config_path = model_plus.paths[0].parent / "config.json"
orig_config_path = model_plus.paths[0].parent / "params.json"

if hf_transformer_config_path.exists():
params = Params.loadHFTransformerJson(model_plus.model, hf_transformer_config_path)
if orig_config_path.exists():
params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
else:
params = Params.guessed(model_plus.model)

Expand Down Expand Up @@ -1036,8 +1061,7 @@ def write_vocab(self, vocab: Vocab) -> None:
@staticmethod
def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
of = OutputFile(fname_out)
params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0,
n_head=1, n_layer=0)
params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
of = OutputFile(fname_out)
of.write_file_header(params, file_type=GGMLFileType.AllF32)
of.write_vocab(vocab)
Expand Down
12 changes: 10 additions & 2 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_ctx = std::stoi(argv[i]);
} else if (arg == "-gqa" || arg == "--gqa") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_gqa = std::stoi(argv[i]);
} else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -485,6 +491,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " prompt file to start generation.\n");
fprintf(stderr, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 7 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
ggerganov marked this conversation as resolved.
Show resolved Hide resolved
fprintf(stderr, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
fprintf(stderr, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
Expand All @@ -505,15 +514,13 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --cfg-negative-prompt PROMPT \n");
fprintf(stderr, " negative prompt to use for guidance. (default: empty)\n");
fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n");
fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " --perplexity compute perplexity over each ctx window of the prompt\n");
fprintf(stderr, " --perplexity-lines compute perplexity over each line of the prompt\n");
fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
Expand Down Expand Up @@ -580,6 +587,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param

lparams.n_ctx = params.n_ctx;
lparams.n_batch = params.n_batch;
lparams.n_gqa = params.n_gqa;
lparams.n_gpu_layers = params.n_gpu_layers;
lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split;
Expand Down
3 changes: 2 additions & 1 deletion examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct gpt_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
Expand All @@ -47,7 +48,7 @@ struct gpt_params {
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float frequency_penalty = 0.00f; // 0.0 = disabled
float presence_penalty = 0.00f; // 0.0 = disabled
int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate

Expand Down
4 changes: 2 additions & 2 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ int main(int argc, char ** argv) {
}

if (params.n_ctx > 2048) {
fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
" you are on your own\n", __func__, params.n_ctx);
// TODO: determine the actual max context of the model (e.g. 4096 for LLaMA v2) and use that instead of 2048
fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified)\n", __func__, params.n_ctx);
} else if (params.n_ctx < 8) {
fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
Expand Down
71 changes: 45 additions & 26 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1745,11 +1745,15 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
}
}

static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
static __global__ void mul_mat_p021_f16_f32(
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {

const half * x = (const half *) vx;

const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
const int channel_x = channel / (nchannels_y / nchannels_x);

const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
Expand All @@ -1765,7 +1769,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
}

// x is transposed and permuted
const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
const float xi = __half2float(x[ix]);

const int row_y = col_x;
Expand Down Expand Up @@ -1793,12 +1797,13 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const

static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
const int row_stride_x, const int channel_stride_x) {
const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {

const half * x = (const half *) vx;

const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
const int channel_x = channel / channel_x_divisor;

const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
Expand All @@ -1815,7 +1820,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
break;
}

const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
const float xi = __half2float(x[ix]);

const int row_y = col_x;
Expand Down Expand Up @@ -2324,20 +2329,23 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
}
}

static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
const dim3 block_nums(1, nrows_x, nchannels_x);
static void ggml_mul_mat_p021_f16_f32_cuda(
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
const int nchannels_x, const int nchannels_y, cudaStream_t stream) {

const dim3 block_nums(1, nrows_x, nchannels_y);
const dim3 block_dims(WARP_SIZE, 1, 1);
mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
}

static void ggml_mul_mat_vec_nc_f16_f32_cuda(
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {

const dim3 block_nums(1, nrows_x, nchannels_x);
const dim3 block_nums(1, nrows_x, nchannels_y);
const dim3 block_dims(WARP_SIZE, 1, 1);
mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
(vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
(vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
}

static void ggml_cpy_f32_f32_cuda(
Expand Down Expand Up @@ -3097,6 +3105,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;

GGML_ASSERT(ne03 == ne13);

const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
Expand All @@ -3108,12 +3119,19 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);

// strides for iteration over dims 3 and 2
const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13;
const int64_t num_iters = flatten_rows ? 1 : num_iters_0;
const int64_t stride_mod = flatten_rows ? num_iters_0 : 1;
const int64_t src0_stride = ne00 * ne01 * stride_mod;
const int64_t src1_stride = ne10 * ne11 * stride_mod;
const int64_t dst_stride = ne0 * ne1 * stride_mod;

const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
const int64_t i03_max = flatten_rows ? 1 : ne03;
const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12);
const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02;
GGML_ASSERT(!(flatten_rows && ne02 < ne12));

const size_t src0_ts = ggml_type_size(src0->type);
const size_t src0_bs = ggml_blck_size(src0->type);

Expand All @@ -3130,6 +3148,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);

const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
GGML_ASSERT(!(split && ne02 < ne12));

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);

Expand Down Expand Up @@ -3166,7 +3185,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
} else {
row_low = 0;
row_high = nrows0;
row_high = nrows0*i02_divisor;
}
if (row_low == row_high) {
continue;
Expand Down Expand Up @@ -3214,16 +3233,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
}

const int64_t i03_max = flatten_rows ? 1 : ne03;
const int64_t i02_max = flatten_rows ? 1 : ne02;
const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;

for (int64_t i03 = 0; i03 < i03_max; i03++) {
const int64_t i13 = i03 % ne13;
for (int64_t i02 = 0; i02 < i02_max; i02++) {
const int64_t i12 = i02 % ne12;

const int64_t i0 = i03*ne02 + i02;
const int64_t i0 = i03*i02_max + i02;

// i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
const int64_t i0_offset_low = row_low/rows_per_iter;
Expand Down Expand Up @@ -3257,10 +3272,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
const int64_t i11 = i13*ne12 + i12;

// for split tensors the data begins at i0 == i0_offset_low
char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
char * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs;
float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride;
float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;

// for split tensors the data pointer needs to be rounded down
// to the bin edge for i03, i02 bins beyond the first
Expand Down Expand Up @@ -3299,11 +3314,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
}
}

if (!src0_on_device || !src0_is_contiguous) {
if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
if (src0_is_f32) {
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
} else {
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
}
}

Expand Down Expand Up @@ -3457,6 +3472,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];

const int64_t ne12 = src1->ne[2];

CUDA_CHECK(cudaSetDevice(g_main_device));
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];

Expand All @@ -3469,7 +3486,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];

ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main);
}

void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
Expand All @@ -3483,6 +3500,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];

const int64_t ne12 = src1->ne[2];

const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2];

Expand All @@ -3501,7 +3520,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
const int row_stride_x = nb01 / sizeof(half);
const int channel_stride_x = nb02 / sizeof(half);

ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
}

void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
Expand Down
Loading