From b49792b044316dc2751b9504612aca018e268129 Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Sat, 30 Sep 2023 18:35:35 +0200 Subject: [PATCH 01/16] CUDA: added support for ggml_clamp (see also: https://github.com/ggerganov/ggml/issues/545) --- ggml-cuda.cu | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 989c419cd0ea4..0c873375d0b75 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -414,6 +414,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_SCALE_BLOCK_SIZE 256 +#define CUDA_CLAMP_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256 #define CUDA_ALIBI_BLOCK_SIZE 32 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 @@ -4555,6 +4556,16 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale dst[i] = scale * x[i]; } +static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); +} + static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; add_f32<<>>(x, y, dst, kx, ky); @@ -5436,6 +5447,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons scale_f32<<>>(x, dst, scale, k); } +static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE; + clamp_f32<<>>(x, dst, min, max, k); +} + template static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { @@ -6353,6 +6369,24 @@ inline void ggml_cuda_op_scale( (void) src1_dd; } +inline void ggml_cuda_op_clamp( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const float min = ((float *) dst->op_params)[0]; + const float max = ((float *) dst->op_params)[1]; + + clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream); + CUDA_CHECK(cudaGetLastError()); + + (void) src1; + (void) dst; + (void) src1_dd; +} + static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) { const int64_t nrows0 = ggml_nrows(src0); @@ -6906,6 +6940,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } +static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp); +} + static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -7330,6 +7368,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ } func = ggml_cuda_scale; break; + case GGML_OP_CLAMP: + if (!any_on_device) { + return 
false; + } + func = ggml_cuda_clamp; + break; case GGML_OP_CPY: if (!any_on_device) { return false; From 15236e855b2e167dad9ebf528ebae4117554c534 Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Sat, 30 Sep 2023 18:49:22 +0200 Subject: [PATCH 02/16] mpt : added an implementation based (mostly) on falcon integration, modified with deltas from ggml/examples/mpt --- convert-mpt-hf-to-gguf.py | 263 ++++++++++++++++++++++ gguf-py/gguf/gguf.py | 14 ++ llama.cpp | 449 +++++++++++++++++++++++++++++++++++++- 3 files changed, 724 insertions(+), 2 deletions(-) create mode 100755 convert-mpt-hf-to-gguf.py diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py new file mode 100755 index 0000000000000..b5888fd576b1a --- /dev/null +++ b/convert-mpt-hf-to-gguf.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 +# HF gptneox--> gguf conversion + +from __future__ import annotations + +import argparse +import json +import os +import struct +import sys +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from transformers import AutoTokenizer # type: ignore[import] + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf')) +import gguf + +# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py + + +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
+ """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + return dict(zip(bs, (chr(n) for n in cs))) + + +def count_model_parts(dir_model: Path) -> int: + num_parts = 0 + for filename in os.listdir(dir_model): + if filename.startswith("pytorch_model-"): + num_parts += 1 + + if num_parts > 0: + print("gguf: found " + str(num_parts) + " model parts") + return num_parts + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Convert an MPT model to a GGML compatible file") + parser.add_argument( + "--vocab-only", action="store_true", + help="extract only the vocab", + ) + parser.add_argument( + "--outfile", type=Path, + help="path to write to; default: based on input", + ) + parser.add_argument( + "model", type=Path, + help="directory containing model file, or model file itself (*.bin)", + ) + parser.add_argument( + "ftype", type=int, choices=[0, 1], default=1, nargs='?', + help="output format - use 0 for float32, 1 for float16", + ) + return parser.parse_args() + +args = parse_args() + +dir_model = args.model +ftype = args.ftype +if not dir_model.is_dir(): + print(f'Error: {args.model} is not a directory', file = sys.stderr) + sys.exit(1) + +# possible tensor data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 + +# map from ftype to string +ftype_str = ["f32", "f16"] + +if args.outfile is not None: + fname_out = args.outfile +else: + # output in the same directory as the model by default + fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf' + +print("gguf: loading model "+dir_model.name) + +with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + +if hparams["architectures"][0] != "MPTForCausalLM": + print("Model architecture not supported: " + hparams["architectures"][0]) + + sys.exit() + +# get number of model parts +num_parts = count_model_parts(dir_model) + +ARCH=gguf.MODEL_ARCH.MPT +gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) + +print("gguf: get model metadata") + +block_count = hparams["n_layers"] + +gguf_writer.add_name(dir_model.name) +gguf_writer.add_context_length(hparams["max_seq_len"]) +gguf_writer.add_embedding_length(hparams["d_model"]) +gguf_writer.add_block_count(block_count) +gguf_writer.add_feed_forward_length(4 * hparams["d_model"]) +gguf_writer.add_head_count(hparams["n_heads"]) +gguf_writer.add_layer_norm_eps(1e-05) +gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"]) +gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"]) + +# TOKENIZATION + +print("gguf: get tokenizer metadata") + +tokens: list[bytearray] = [] + +tokenizer_json_file = dir_model / 'tokenizer.json' +if not tokenizer_json_file.is_file(): + print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr) + sys.exit(1) + +# gpt2 tokenizer +gguf_writer.add_tokenizer_model("gpt2") + +with open(tokenizer_json_file, "r", encoding="utf-8") as f: + tokenizer_json = json.load(f) + +print("gguf: get gpt2 tokenizer vocab") + +# MPT token embedding tensors have dimension 50432, but there are only 50254 +# tokens in the vocab, presumably to accomodate some "reserved" tokens; +# this is causing problems down the line in llama.cpp, so we extend the vocab_size: + +vocab_size = len(tokenizer_json["model"]["vocab"]) + 178 + +# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py +tokenizer = 
AutoTokenizer.from_pretrained(dir_model) + +reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()} +byte_encoder = bytes_to_unicode() +byte_decoder = {v: k for k, v in byte_encoder.items()} + +for i in range(vocab_size): + if i in reverse_vocab: + try: + text = bytearray([byte_decoder[c] for c in reverse_vocab[i]]) + except KeyError: + text = bytearray() + for c in reverse_vocab[i]: + if ord(c) < 256: # single byte character + try: + text.append(byte_decoder[c]) + except KeyError: + text.extend(c.encode('utf-8')) + else: # multibyte special token character + text.extend(c.encode('utf-8')) + else: + print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token. (It's normal for MPT.)") + pad_token = f"[PAD{i}]".encode("utf8") + text = bytearray(pad_token) + + tokens.append(text) + +gguf_writer.add_token_list(tokens) + +special_vocab = gguf.SpecialVocab(dir_model, load_merges = True) +special_vocab.add_to_gguf(gguf_writer) + +# TENSORS + +tensor_map = gguf.get_tensor_name_map(ARCH,block_count) + +# tensor info +print("gguf: get tensor metadata") + +if num_parts == 0: + part_names = iter(("pytorch_model.bin",)) +else: + part_names = ( + f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) + ) + +for part_name in part_names: + if args.vocab_only: + break + print("gguf: loading model part '" + part_name + "'") + model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu") + + for name in model_part.keys(): + data = model_part[name] + + old_dtype = data.dtype + + # convert any unsupported data types to float32 + if data.dtype != torch.float16 and data.dtype != torch.float32: + data = data.to(torch.float32) + + data = data.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias")) + if new_name is None: + print("Cannot map tensor '" + name + "'") + continue # for the sake of compatibility with some old published models, don't quit + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? 
There should be not reason to store float16 as float32 + if ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype)) + + +# if new_name == "wte.weight" and data.shape[0] == 50432 and vocab_size == 50254: +# data = data[0:vocab_size,:] + + gguf_writer.add_tensor(new_name, data) + + # note: MPT output is tied to (same as) wte in original model; + # for easier implementation in llama.cpp it's duplicated in GGUF, though :/ + if new_name == "wte.weight": + gguf_writer.add_tensor("output.weight", data) + +print("gguf: write header") +gguf_writer.write_header_to_file() +print("gguf: write metadata") +gguf_writer.write_kv_data_to_file() +if not args.vocab_only: + print("gguf: write tensors") + gguf_writer.write_tensors_to_file() + +gguf_writer.close() + +print(f"gguf: model successfully exported to '{fname_out}'") +print("") diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index 598cf8e594aa8..9c49d0ada29a3 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -185,6 +185,19 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", }, + MODEL_ARCH.MPT: { + MODEL_TENSOR.TOKEN_EMBD: "wte", + MODEL_TENSOR.OUTPUT_NORM: "norm_f", + # note: MPT output is tied to (same as) wte in original model; + # for easier implementation in llama.cpp it's duplicated in GGUF, though :/ + MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.norm_1", + MODEL_TENSOR.FFN_NORM: "blk.{bid}.norm_2", + MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn.Wqkv", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn.out_proj", + MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn.down_proj", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn.up_proj", + }, MODEL_ARCH.GPT2: { # TODO }, @@ -231,6 +244,7 @@ class TensorNameMap: MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox "transformer.ln_f", # gpt2 falcon + "transformer.norm_f", # mpt "model.norm", # llama-hf baichuan "norm", # llama-pth ), diff --git a/llama.cpp b/llama.cpp index bff17135b985f..108ae5eb4725f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -377,7 +377,15 @@ static std::map> LLM_TENSOR_NAMES = { LLM_ARCH_MPT, { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD, "wte" }, + { LLM_TENSOR_OUTPUT_NORM, "norm_f" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.norm_1" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.norm_2" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn.Wqkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn.out_proj" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn.down_proj" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn.up_proj" }, }, }, { @@ -947,6 +955,9 @@ struct llama_hparams { float rope_freq_base_train; float rope_freq_scale_train; + float f_clamp_kqv; + float f_max_alibi_bias; + bool operator!=(const llama_hparams & other) const { return static_cast(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT } @@ -1912,6 +1923,18 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_MPT: + { + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, 
kv(LLM_KV_ATTENTION_CLAMP_KQV)); + GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS)); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 48: model.type = e_model::MODEL_30B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -2436,6 +2459,73 @@ static void llm_load_tensors( } } } break; + case LLM_ARCH_MPT: + { + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + ggml_backend backend_norm; + ggml_backend backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 + + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend backend_split = int(i) < i_gpu_start ? 
GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + + ggml_nbytes(layer.wqkv) + + ggml_nbytes(layer.wo) + + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.w2) + + ggml_nbytes(layer.w3); + } + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -3955,6 +4045,356 @@ static struct ggml_cgraph * llm_build_starcoder( return gf; } +static struct ggml_cgraph * llm_build_mpt( + llama_context & lctx, + const llama_batch & batch) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + const float norm_eps = hparams.f_norm_eps; + const float clamp_kqv = hparams.f_clamp_kqv; + const float max_alibi_bias = hparams.f_max_alibi_bias; + + const int n_gpu_layers = model.n_gpu_layers; + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? 
n_ctx - n_tokens : kv_self.head; + + const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; + + //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n", + // kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift); + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ false, + }; + + params.no_alloc = true; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + //int warmup = 0; + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); + //warmup = ((uint32_t*) inp_tokens->data)[0] == 0; + } + + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inpL); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); + } + } + + const int i_gpu_start = n_layer - n_gpu_layers; + (void) i_gpu_start; + + // offload functions set the tensor output backend to GPU + // tensors are GPU-accelerated if any input or the output has been offloaded + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + } + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + // shift the entire K-cache if needed + // TODO: Do we need to handle it? 
(MPT uses alibi instead of rope) +/* if (do_rope_shift) { + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); + ggml_set_name(K_shift, "K_shift"); + ggml_allocr_alloc(lctx.alloc, K_shift); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * tmp = + ggml_rope_custom_inplace(ctx0, + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_head_kv, n_ctx, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), + K_shift, n_embd_head, 2, 0, freq_base, freq_scale); + offload_func_kq(tmp); + ggml_build_forward_expand(gf, tmp); + } + }*/ + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * attn_norm; + + offload_func_t offload_func = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (il >= i_gpu_start) { + offload_func = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + // self-attention + // TODO: refactor into common function (shared with LLaMA) + { + attn_norm = ggml_norm(ctx0, inpL, norm_eps); + offload_func(attn_norm); + + attn_norm = ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm); + offload_func(attn_norm); + + if (1) { + cur = attn_norm; + } + + // compute QKV + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + offload_func_kq(cur); + + if (clamp_kqv > 0.0f) { + cur = ggml_clamp(ctx0, cur, -clamp_kqv, clamp_kqv); + offload_func_kq(cur); + } + + const size_t wsize = ggml_type_size(cur->type); + + struct ggml_tensor * Qcur = ggml_view_3d( + ctx0, cur, n_embd_head, n_head, n_tokens, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + 0); + offload_func_kq(Qcur); + + struct ggml_tensor * Kcur = ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, n_tokens, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + wsize * n_embd_head * n_head); + offload_func_kq(Kcur); + + struct ggml_tensor * tmpv = ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, n_tokens, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + wsize * n_embd_head * (n_head + n_head_kv)); + offload_func_kq(Kcur); + + ggml_set_name(Qcur, "Qcur"); + ggml_set_name(Kcur, "Kcur"); + + { + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); + offload_func_v(Vcur); + offload_func_v(Vcur->src[0]->src[0]); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); + offload_func_kq(k); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + offload_func_v(v); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + offload_func_kq(Q); + ggml_set_name(Q, "Q"); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_kv, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + offload_func_kq(K); + ggml_set_name(K, "K"); + + 
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + offload_func_kq(KQ); + ggml_set_name(KQ, "KQ"); + + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); + offload_func_kq(KQ_scaled); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + // TODO: replace with ggml_add() + struct ggml_tensor * KQ_scaled_alibi = + ggml_alibi(ctx0, KQ_scaled, std::max(kv_head, n_kv - n_tokens), n_head, max_alibi_bias); + offload_func_kq(KQ_scaled_alibi); + ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); + + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); + offload_func_kq(KQ_masked); + ggml_set_name(KQ_masked, "KQ_masked"); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + offload_func_v(KQ_soft_max); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + offload_func_v(V); + ggml_set_name(V, "V"); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + offload_func_v(KQV); + ggml_set_name(KQV, "KQV"); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + offload_func_v(KQV_merged); + ggml_set_name(KQV_merged, "KQV_merged"); + + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); + offload_func_v(cur); + ggml_set_name(cur, "KQV_merged_contiguous"); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + offload_func(cur); + ggml_set_name(cur, "result_wo"); + } + + // Add the input + cur = ggml_add(ctx0, cur, inpL); + offload_func(cur); + + struct ggml_tensor * attn_out = cur; + + // feed forward + { + // Norm + { + cur = ggml_norm(ctx0, attn_out, norm_eps); + offload_func(cur); + + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); + offload_func(cur); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur); + offload_func(cur); + + cur = ggml_gelu(ctx0, cur); + offload_func(cur); + cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + offload_func(cur); + } + + cur = ggml_add(ctx0, cur, attn_out); + offload_func(cur); + // input for next layer + inpL = cur; + } + + cur = inpL; + + // norm + { + cur = ggml_norm(ctx0, cur, norm_eps); + offload_func_nr(cur); + + cur = ggml_mul(ctx0, cur, model.output_norm); + ggml_set_name(cur, "result_norm"); + } + + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + static struct ggml_cgraph * llama_build_graph( llama_context & lctx, const llama_batch & batch) { @@ -3979,6 +4419,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm_build_starcoder(lctx, batch); } break; + case LLM_ARCH_MPT: + { + result = llm_build_mpt(lctx, batch); + } break; default: GGML_ASSERT(false); } @@ -4112,7 +4556,8 @@ static int llama_decode_internal( // If all tensors can be run on the GPU then using more than 1 thread is detrimental. 
const bool full_offload_supported = model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_BAICHUAN || - model.arch == LLM_ARCH_FALCON; + model.arch == LLM_ARCH_FALCON || + model.arch == LLM_ARCH_MPT; const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3; if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) { n_threads = 1; From 84e30e891c2e211f10a3fa21363095ca0ed97b3c Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Sun, 1 Oct 2023 00:32:33 +0200 Subject: [PATCH 03/16] mpt : protect against "clip_qkv": null in mpt-7b --- convert-mpt-hf-to-gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py index b5888fd576b1a..e0a8b166dd9e2 100755 --- a/convert-mpt-hf-to-gguf.py +++ b/convert-mpt-hf-to-gguf.py @@ -122,7 +122,7 @@ def parse_args() -> argparse.Namespace: gguf_writer.add_feed_forward_length(4 * hparams["d_model"]) gguf_writer.add_head_count(hparams["n_heads"]) gguf_writer.add_layer_norm_eps(1e-05) -gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"]) +gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"] or 0.0) gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"]) # TOKENIZATION From 00e8c5c5f6253347c61e3b74965df3577cccfee6 Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Sun, 1 Oct 2023 00:49:13 +0200 Subject: [PATCH 04/16] mpt : quick fix to avoid "Strange model" warning when quantizing MPT models --- llama.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index 108ae5eb4725f..a2bdd9d3dfeb2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6511,10 +6511,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const std::string name = ggml_get_name(meta); // TODO: avoid hardcoded tensor names - use the TN_* constants - if (name.find("attn_v.weight") != std::string::npos) { + if (name.find("attn_v.weight") != std::string::npos || + name.find("attn.Wqkv.weight") != std::string::npos) { ++n_attention_wv; } - else if (name.find("ffn_down.weight") != std::string::npos) { + else if (name.find("ffn_down.weight") != std::string::npos || + name.find("ffn.down_proj.weight") != std::string::npos) { ++n_feed_forward_w2; } } From 1be89c4002a247449d220824e76c6d26fd548eae Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Sun, 1 Oct 2023 01:14:07 +0200 Subject: [PATCH 05/16] mpt : addendum to changeset:84e30e8 - leave parameter clamp_kqv out from metadata rather than use 0.0 to indicate "no clamping" (more compliant with the current GGUF spec?) 
--- convert-mpt-hf-to-gguf.py | 3 ++- llama.cpp | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py index e0a8b166dd9e2..057cb34f683f5 100755 --- a/convert-mpt-hf-to-gguf.py +++ b/convert-mpt-hf-to-gguf.py @@ -122,7 +122,8 @@ def parse_args() -> argparse.Namespace: gguf_writer.add_feed_forward_length(4 * hparams["d_model"]) gguf_writer.add_head_count(hparams["n_heads"]) gguf_writer.add_layer_norm_eps(1e-05) -gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"] or 0.0) +if hparams["attn_config"]["clip_qkv"] is not None: + gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"]) gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"]) # TOKENIZATION diff --git a/llama.cpp b/llama.cpp index a2bdd9d3dfeb2..2bec27b8b41f2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1926,7 +1926,11 @@ static void llm_load_hparams( case LLM_ARCH_MPT: { GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); - GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_CLAMP_KQV)); + if (gguf_find_key(ctx, kv(LLM_KV_ATTENTION_CLAMP_KQV).c_str()) >= 0) { + GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_CLAMP_KQV)); + } else { + hparams.f_clamp_kqv = 0.0f; + } GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS)); switch (hparams.n_layer) { From 26c253eda29bac3d76f36fd37a6861e32961013a Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Sun, 1 Oct 2023 01:43:39 +0200 Subject: [PATCH 06/16] mpt : standardized all tensor names to follow GGUF spec --- convert-mpt-hf-to-gguf.py | 2 +- gguf-py/gguf/gguf.py | 16 ++++++++-------- llama.cpp | 21 ++++++++++----------- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py index 057cb34f683f5..a0cee0cf9e547 100755 --- a/convert-mpt-hf-to-gguf.py +++ b/convert-mpt-hf-to-gguf.py @@ -247,7 +247,7 @@ def parse_args() -> argparse.Namespace: # note: MPT output is tied to (same as) wte in original model; # for easier implementation in llama.cpp it's duplicated in GGUF, though :/ - if new_name == "wte.weight": + if new_name == "token_embd.weight": gguf_writer.add_tensor("output.weight", data) print("gguf: write header") diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index 9c49d0ada29a3..afd16e2126d70 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -186,17 +186,17 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", }, MODEL_ARCH.MPT: { - MODEL_TENSOR.TOKEN_EMBD: "wte", - MODEL_TENSOR.OUTPUT_NORM: "norm_f", + MODEL_TENSOR.TOKEN_EMBD: "token_embd", + MODEL_TENSOR.OUTPUT_NORM: "output_norm", # note: MPT output is tied to (same as) wte in original model; # for easier implementation in llama.cpp it's duplicated in GGUF, though :/ MODEL_TENSOR.OUTPUT: "output", - MODEL_TENSOR.ATTN_NORM: "blk.{bid}.norm_1", - MODEL_TENSOR.FFN_NORM: "blk.{bid}.norm_2", - MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn.Wqkv", - MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn.out_proj", - MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn.down_proj", - MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn.up_proj", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", + MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", + MODEL_TENSOR.FFN_DOWN: 
"blk.{bid}.ffn_down", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", }, MODEL_ARCH.GPT2: { # TODO diff --git a/llama.cpp b/llama.cpp index 2bec27b8b41f2..7ea6dbe72758b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -377,15 +377,15 @@ static std::map> LLM_TENSOR_NAMES = { LLM_ARCH_MPT, { - { LLM_TENSOR_TOKEN_EMBD, "wte" }, - { LLM_TENSOR_OUTPUT_NORM, "norm_f" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.norm_1" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.norm_2" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn.Wqkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn.out_proj" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn.down_proj" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn.up_proj" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, { @@ -6516,11 +6516,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // TODO: avoid hardcoded tensor names - use the TN_* constants if (name.find("attn_v.weight") != std::string::npos || - name.find("attn.Wqkv.weight") != std::string::npos) { + name.find("attn_qkv.weight") != std::string::npos) { ++n_attention_wv; } - else if (name.find("ffn_down.weight") != std::string::npos || - name.find("ffn.down_proj.weight") != std::string::npos) { + else if (name.find("ffn_down.weight") != std::string::npos) { ++n_feed_forward_w2; } } From df072d2d999bc5924bebbe1ee37cdd810ed343fe Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Sun, 1 Oct 2023 01:48:47 +0200 Subject: [PATCH 07/16] mpt : addendum to changeset:1be89c40 - use "req" parameter of GGUF_GET_KEY macro instead of duplicate code --- llama.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 7ea6dbe72758b..81a014d0a29c9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1925,12 +1925,10 @@ static void llm_load_hparams( } break; case LLM_ARCH_MPT: { + hparams.f_clamp_kqv = 0.0f; + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); - if (gguf_find_key(ctx, kv(LLM_KV_ATTENTION_CLAMP_KQV).c_str()) >= 0) { - GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_CLAMP_KQV)); - } else { - hparams.f_clamp_kqv = 0.0f; - } + GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_CLAMP_KQV)); GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS)); switch (hparams.n_layer) { From 90e7d6de28c05506e59a1422456bb6b2ab6a1f8b Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Mon, 2 Oct 2023 19:55:59 +0200 Subject: [PATCH 08/16] mpt : fixed comment s/gptneox/mpt/ --- convert-mpt-hf-to-gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py index a0cee0cf9e547..60bceb0fa98b4 100755 --- a/convert-mpt-hf-to-gguf.py +++ b/convert-mpt-hf-to-gguf.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# HF gptneox--> gguf conversion +# HF mpt--> gguf conversion from __future__ import annotations From 470801292df7d0d99e2782479c2909df191ab4e0 Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Mon, 2 Oct 2023 21:55:22 +0200 Subject: [PATCH 09/16] mpt : remove tabs, trailing whitespace 
--- convert-mpt-hf-to-gguf.py | 2 +- llama.cpp | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py index 60bceb0fa98b4..cbe4a9f110498 100755 --- a/convert-mpt-hf-to-gguf.py +++ b/convert-mpt-hf-to-gguf.py @@ -121,7 +121,7 @@ def parse_args() -> argparse.Namespace: gguf_writer.add_block_count(block_count) gguf_writer.add_feed_forward_length(4 * hparams["d_model"]) gguf_writer.add_head_count(hparams["n_heads"]) -gguf_writer.add_layer_norm_eps(1e-05) +gguf_writer.add_layer_norm_eps(1e-05) if hparams["attn_config"]["clip_qkv"] is not None: gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"]) gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"]) diff --git a/llama.cpp b/llama.cpp index 81a014d0a29c9..ede95f60763a4 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4304,11 +4304,11 @@ static struct ggml_cgraph * llm_build_mpt( offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); - // TODO: replace with ggml_add() - struct ggml_tensor * KQ_scaled_alibi = - ggml_alibi(ctx0, KQ_scaled, std::max(kv_head, n_kv - n_tokens), n_head, max_alibi_bias); + // TODO: replace with ggml_add() + struct ggml_tensor * KQ_scaled_alibi = + ggml_alibi(ctx0, KQ_scaled, std::max(kv_head, n_kv - n_tokens), n_head, max_alibi_bias); offload_func_kq(KQ_scaled_alibi); - ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); + ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); offload_func_kq(KQ_masked); From 1364bcd712ef4d9c5ca905f4901d16b1dba695c4 Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Tue, 3 Oct 2023 21:53:31 +0200 Subject: [PATCH 10/16] mpt : removed ne01 + n_past == ne00 assertion from alibi (cuda/f32) and rope_shift from build_mpt --- ggml-cuda.cu | 4 ++-- ggml.c | 4 ++-- llama.cpp | 32 +------------------------------- 3 files changed, 5 insertions(+), 35 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0c873375d0b75..967fb5dca32e4 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6295,12 +6295,12 @@ inline void ggml_cuda_op_alibi( const int64_t ne02 = src0->ne[2]; const int64_t nrows = ggml_nrows(src0); - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - GGML_ASSERT(ne01 + n_past == ne00); + //GGML_ASSERT(ne01 + n_past == ne00); GGML_ASSERT(n_head == ne02); const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); diff --git a/ggml.c b/ggml.c index 820fe2e74b0ae..323b84974cafa 100644 --- a/ggml.c +++ b/ggml.c @@ -12889,7 +12889,7 @@ static void ggml_compute_forward_alibi_f32( return; } - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); @@ -12910,7 +12910,7 @@ static void ggml_compute_forward_alibi_f32( //const int nb3 = src0->nb[3]; GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(ne1 + n_past == ne0); + //GGML_ASSERT(ne1 + n_past == ne0); GGML_ASSERT(n_head == ne2); // add alibi to src0 (KQ_scaled) diff --git a/llama.cpp b/llama.cpp index ede95f60763a4..b1b1b801de89f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4076,8 +4076,6 @@ static struct ggml_cgraph * llm_build_mpt( const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? 
n_ctx : kv_self.n; const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; - const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; - //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n", // kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift); @@ -4176,34 +4174,6 @@ static struct ggml_cgraph * llm_build_mpt( } } - // shift the entire K-cache if needed - // TODO: Do we need to handle it? (MPT uses alibi instead of rope) -/* if (do_rope_shift) { - struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); - offload_func_kq(K_shift); - ggml_set_name(K_shift, "K_shift"); - ggml_allocr_alloc(lctx.alloc, K_shift); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) K_shift->data; - for (int i = 0; i < n_ctx; ++i) { - data[i] = kv_self.cells[i].delta; - } - } - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * tmp = - ggml_rope_custom_inplace(ctx0, - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_head_kv, n_ctx, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), - K_shift, n_embd_head, 2, 0, freq_base, freq_scale); - offload_func_kq(tmp); - ggml_build_forward_expand(gf, tmp); - } - }*/ - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * attn_norm; @@ -4306,7 +4276,7 @@ static struct ggml_cgraph * llm_build_mpt( // TODO: replace with ggml_add() struct ggml_tensor * KQ_scaled_alibi = - ggml_alibi(ctx0, KQ_scaled, std::max(kv_head, n_kv - n_tokens), n_head, max_alibi_bias); + ggml_alibi(ctx0, KQ_scaled, 0, n_head, max_alibi_bias); offload_func_kq(KQ_scaled_alibi); ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); From 7d6a24aad4d2eae524bd3472290fbfb3efab5510 Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Fri, 6 Oct 2023 13:53:32 +0200 Subject: [PATCH 11/16] mpt : updated convert-mpt-hf-to-gguf.py to reflect changes made to convert-gptneox-hf-to-gguf.py in pr:3252 --- convert-mpt-hf-to-gguf.py | 55 +++++---------------------------------- 1 file changed, 7 insertions(+), 48 deletions(-) diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py index cbe4a9f110498..a6a049bc922e2 100755 --- a/convert-mpt-hf-to-gguf.py +++ b/convert-mpt-hf-to-gguf.py @@ -19,29 +19,6 @@ sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf')) import gguf -# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py - - -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a significant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. 
- """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8+n) - n += 1 - return dict(zip(bs, (chr(n) for n in cs))) - def count_model_parts(dir_model: Path) -> int: num_parts = 0 @@ -131,6 +108,8 @@ def parse_args() -> argparse.Namespace: print("gguf: get tokenizer metadata") tokens: list[bytearray] = [] +scores: list[float] = [] +toktypes: list[int] = [] tokenizer_json_file = dir_model / 'tokenizer.json' if not tokenizer_json_file.is_file(): @@ -155,31 +134,15 @@ def parse_args() -> argparse.Namespace: tokenizer = AutoTokenizer.from_pretrained(dir_model) reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()} -byte_encoder = bytes_to_unicode() -byte_decoder = {v: k for k, v in byte_encoder.items()} for i in range(vocab_size): - if i in reverse_vocab: - try: - text = bytearray([byte_decoder[c] for c in reverse_vocab[i]]) - except KeyError: - text = bytearray() - for c in reverse_vocab[i]: - if ord(c) < 256: # single byte character - try: - text.append(byte_decoder[c]) - except KeyError: - text.extend(c.encode('utf-8')) - else: # multibyte special token character - text.extend(c.encode('utf-8')) - else: - print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token. (It's normal for MPT.)") - pad_token = f"[PAD{i}]".encode("utf8") - text = bytearray(pad_token) - - tokens.append(text) + tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]") + scores.append(0.0) # dummy + toktypes.append(gguf.TokenType.NORMAL) gguf_writer.add_token_list(tokens) +gguf_writer.add_token_scores(scores) +gguf_writer.add_token_types(toktypes) special_vocab = gguf.SpecialVocab(dir_model, load_merges = True) special_vocab.add_to_gguf(gguf_writer) @@ -239,10 +202,6 @@ def parse_args() -> argparse.Namespace: print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype)) - -# if new_name == "wte.weight" and data.shape[0] == 50432 and vocab_size == 50254: -# data = data[0:vocab_size,:] - gguf_writer.add_tensor(new_name, data) # note: MPT output is tied to (same as) wte in original model; From ad3c2f3b23164420c93694684be7e906811631d1 Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Mon, 9 Oct 2023 10:16:24 -0400 Subject: [PATCH 12/16] comment out n_past instead of marking it unused --- ggml-metal.m | 2 +- ggml.c | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 5a23144d0c891..87fa172161405 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1299,7 +1299,7 @@ void ggml_metal_graph_compute( const int nth = MIN(1024, ne00); - const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); diff --git a/ggml.c b/ggml.c index 5bb1da31ba624..8cdab8854f756 100644 --- a/ggml.c +++ b/ggml.c @@ -13059,7 +13059,7 @@ static void ggml_compute_forward_alibi_f32( return; } - const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); From 1a454eb561d19b7dba5d377feb48144c17128f5b Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Mon, 9 Oct 2023 16:48:01 +0200 Subject: 
[PATCH 13/16] mpt : removed hardcoded +178 from convert script in favor of utilizing hparams["vocab_size"] --- convert-mpt-hf-to-gguf.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py index a6a049bc922e2..6d318dbb63e28 100755 --- a/convert-mpt-hf-to-gguf.py +++ b/convert-mpt-hf-to-gguf.py @@ -124,11 +124,13 @@ def parse_args() -> argparse.Namespace: print("gguf: get gpt2 tokenizer vocab") -# MPT token embedding tensors have dimension 50432, but there are only 50254 +# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]), +# but there are only 50254 (len(tokenizer_json["model"]["vocab"])) # tokens in the vocab, presumably to accomodate some "reserved" tokens; -# this is causing problems down the line in llama.cpp, so we extend the vocab_size: +# this is causing problems down the line in llama.cpp, so we pad the vocab +# with dummy tokens: -vocab_size = len(tokenizer_json["model"]["vocab"]) + 178 +vocab_size = hparams["vocab_size"] # ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py tokenizer = AutoTokenizer.from_pretrained(dir_model) From 32172f12f57e02b1a233822c6999be592b8a739e Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Mon, 9 Oct 2023 11:39:53 -0400 Subject: [PATCH 14/16] mpt : remove unused tokenizer_json in convert script --- convert-mpt-hf-to-gguf.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py index 6d318dbb63e28..73a4932f7c831 100755 --- a/convert-mpt-hf-to-gguf.py +++ b/convert-mpt-hf-to-gguf.py @@ -111,24 +111,15 @@ def parse_args() -> argparse.Namespace: scores: list[float] = [] toktypes: list[int] = [] -tokenizer_json_file = dir_model / 'tokenizer.json' -if not tokenizer_json_file.is_file(): - print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr) - sys.exit(1) - # gpt2 tokenizer gguf_writer.add_tokenizer_model("gpt2") -with open(tokenizer_json_file, "r", encoding="utf-8") as f: - tokenizer_json = json.load(f) - print("gguf: get gpt2 tokenizer vocab") -# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]), -# but there are only 50254 (len(tokenizer_json["model"]["vocab"])) -# tokens in the vocab, presumably to accomodate some "reserved" tokens; -# this is causing problems down the line in llama.cpp, so we pad the vocab -# with dummy tokens: +# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]), but +# there are only 50254 (len(tokenizer.vocab)) tokens in the vocab, presumably to +# accomodate some "reserved" tokens; this is causing problems down the line in +# llama.cpp, so we pad the vocab with dummy tokens: vocab_size = hparams["vocab_size"] From 96cf3f5dc3e145a9555df377947ac57ecabaa708 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 10 Oct 2023 10:45:24 +0300 Subject: [PATCH 15/16] ggml : remove obsolete n_past assert in ggml_alibi --- ggml.c | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml.c b/ggml.c index 8cdab8854f756..1f5598fa6af8f 100644 --- a/ggml.c +++ b/ggml.c @@ -13064,8 +13064,6 @@ static void ggml_compute_forward_alibi_f32( float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - assert(n_past >= 0); - const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int64_t ne1 = src0->ne[1]; // seq_len_without_past const int64_t ne2 = src0->ne[2]; // n_head -> this is k From 9b66378cacfb0a4e0401e7682ec66f6ab89348ed Mon Sep 17 00:00:00 2001 From: 
Georgi Gerganov Date: Tue, 10 Oct 2023 10:49:39 +0300 Subject: [PATCH 16/16] llama : print clam_kqv and max_alibi_bias hparams --- llama.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index eaa008dbfaa00..3b63b64010b0f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2229,6 +2229,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); @@ -5013,8 +5015,8 @@ static struct ggml_cgraph * llm_build_mpt( const int64_t n_embd_head = hparams.n_embd_head(); const int64_t n_embd_gqa = hparams.n_embd_gqa(); - const float norm_eps = hparams.f_norm_eps; - const float clamp_kqv = hparams.f_clamp_kqv; + const float norm_eps = hparams.f_norm_eps; + const float clamp_kqv = hparams.f_clamp_kqv; const float max_alibi_bias = hparams.f_max_alibi_bias; const int n_gpu_layers = model.n_gpu_layers;
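
For reference, the element-wise clamp kernel introduced in PATCH 01 can be spot-checked outside of ggml with a small standalone harness such as the sketch below. This is not part of the patch series: it assumes only the CUDA runtime and nvcc, reuses the kernel body and the grid-size computation from the patch verbatim (launched on the default stream instead of a ggml-provided one), and the host-side names (main, k, h_x, d_x) are illustrative only. The clamp range mimics an MPT-style clip_qkv setting.

// Standalone sketch mirroring clamp_f32 and clamp_f32_cuda from PATCH 01.
// Illustrative only -- not part of the patch series; the host-side harness is made up.
#include <cstdio>
#include <cuda_runtime.h>

#define CUDA_CLAMP_BLOCK_SIZE 256

static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}

int main() {
    const int   k   = 1024;   // number of elements
    const float min = -8.0f;  // e.g. an MPT-style clip_qkv value
    const float max =  8.0f;

    float h_x[k], h_dst[k];
    for (int i = 0; i < k; ++i) {
        h_x[i] = (float) (i - k/2); // values well outside [-8, 8]
    }

    float *d_x, *d_dst;
    cudaMalloc(&d_x,   k*sizeof(float));
    cudaMalloc(&d_dst, k*sizeof(float));
    cudaMemcpy(d_x, h_x, k*sizeof(float), cudaMemcpyHostToDevice);

    // same grid-size computation as clamp_f32_cuda in the patch
    const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
    clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE>>>(d_x, d_dst, min, max, k);

    cudaMemcpy(h_dst, d_dst, k*sizeof(float), cudaMemcpyDeviceToHost);
    printf("dst[0] = %f, dst[k-1] = %f\n", h_dst[0], h_dst[k-1]); // expect -8.0 and 8.0

    cudaFree(d_x);
    cudaFree(d_dst);
    return 0;
}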