From b49792b044316dc2751b9504612aca018e268129 Mon Sep 17 00:00:00 2001
From: Jan Ploski <jpl@plosquare.com>
Date: Sat, 30 Sep 2023 18:35:35 +0200
Subject: [PATCH 01/16] CUDA: added support for ggml_clamp (see also:
 https://github.com/ggerganov/ggml/issues/545)

---
 ggml-cuda.cu | 44 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 989c419cd0ea4..0c873375d0b75 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -414,6 +414,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
+#define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
@@ -4555,6 +4556,16 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
     dst[i] = scale * x[i];
 }
 
+static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
+}
+
 static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
     add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
@@ -5436,6 +5447,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }
 
+static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
+    clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
+}
+
 template<typename T>
 static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                           const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
@@ -6353,6 +6369,24 @@ inline void ggml_cuda_op_scale(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_clamp(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const float min = ((float *) dst->op_params)[0];
+    const float max = ((float *) dst->op_params)[1];
+
+    clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) {
     const int64_t nrows0 = ggml_nrows(src0);
 
@@ -6906,6 +6940,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }
 
+static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
+}
+
 static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -7330,6 +7368,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_scale;
             break;
+        case GGML_OP_CLAMP:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_clamp;
+            break;
         case GGML_OP_CPY:
             if (!any_on_device) {
                 return false;

From 15236e855b2e167dad9ebf528ebae4117554c534 Mon Sep 17 00:00:00 2001
From: Jan Ploski <jpl@plosquare.com>
Date: Sat, 30 Sep 2023 18:49:22 +0200
Subject: [PATCH 02/16] mpt : added an implementation based (mostly) on falcon
 integration, modified with deltas from ggml/examples/mpt

---
 convert-mpt-hf-to-gguf.py | 263 ++++++++++++++++++++++
 gguf-py/gguf/gguf.py      |  14 ++
 llama.cpp                 | 449 +++++++++++++++++++++++++++++++++++++-
 3 files changed, 724 insertions(+), 2 deletions(-)
 create mode 100755 convert-mpt-hf-to-gguf.py

diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py
new file mode 100755
index 0000000000000..b5888fd576b1a
--- /dev/null
+++ b/convert-mpt-hf-to-gguf.py
@@ -0,0 +1,263 @@
+#!/usr/bin/env python3
+# HF gptneox--> gguf conversion
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import struct
+import sys
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer  # type: ignore[import]
+
+if 'NO_LOCAL_GGUF' not in os.environ:
+    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
+import gguf
+
+# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
+
+
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    return dict(zip(bs, (chr(n) for n in cs)))
+
+
+def count_model_parts(dir_model: Path) -> int:
+    num_parts = 0
+    for filename in os.listdir(dir_model):
+        if filename.startswith("pytorch_model-"):
+            num_parts += 1
+
+    if num_parts > 0:
+        print("gguf: found " + str(num_parts) + " model parts")
+    return num_parts
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Convert an MPT model to a GGML compatible file")
+    parser.add_argument(
+        "--vocab-only", action="store_true",
+        help="extract only the vocab",
+    )
+    parser.add_argument(
+        "--outfile", type=Path,
+        help="path to write to; default: based on input",
+    )
+    parser.add_argument(
+        "model", type=Path,
+        help="directory containing model file, or model file itself (*.bin)",
+    )
+    parser.add_argument(
+        "ftype", type=int, choices=[0, 1], default=1, nargs='?',
+        help="output format - use 0 for float32, 1 for float16",
+    )
+    return parser.parse_args()
+
+args = parse_args()
+
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+    print(f'Error: {args.model} is not a directory', file = sys.stderr)
+    sys.exit(1)
+
+# possible tensor data types
+#   ftype == 0 -> float32
+#   ftype == 1 -> float16
+
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+if args.outfile is not None:
+    fname_out = args.outfile
+else:
+    # output in the same directory as the model by default
+    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
+
+print("gguf: loading model "+dir_model.name)
+
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+    hparams = json.load(f)
+
+if hparams["architectures"][0] != "MPTForCausalLM":
+    print("Model architecture not supported: " + hparams["architectures"][0])
+
+    sys.exit()
+
+# get number of model parts
+num_parts = count_model_parts(dir_model)
+
+ARCH=gguf.MODEL_ARCH.MPT
+gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
+
+print("gguf: get model metadata")
+
+block_count = hparams["n_layers"]
+
+gguf_writer.add_name(dir_model.name)
+gguf_writer.add_context_length(hparams["max_seq_len"])
+gguf_writer.add_embedding_length(hparams["d_model"])
+gguf_writer.add_block_count(block_count)
+gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
+gguf_writer.add_head_count(hparams["n_heads"])
+gguf_writer.add_layer_norm_eps(1e-05) 
+gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"]) 
+gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"])
+
+# TOKENIZATION
+
+print("gguf: get tokenizer metadata")
+
+tokens: list[bytearray] = []
+
+tokenizer_json_file = dir_model / 'tokenizer.json'
+if not tokenizer_json_file.is_file():
+    print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
+    sys.exit(1)
+
+# gpt2 tokenizer
+gguf_writer.add_tokenizer_model("gpt2")
+
+with open(tokenizer_json_file, "r", encoding="utf-8") as f:
+    tokenizer_json = json.load(f)
+
+print("gguf: get gpt2 tokenizer vocab")
+
+# MPT token embedding tensors have dimension 50432, but there are only 50254
+# tokens in the vocab, presumably to accomodate some "reserved" tokens;
+# this is causing problems down the line in llama.cpp, so we extend the vocab_size:
+
+vocab_size = len(tokenizer_json["model"]["vocab"]) + 178
+
+# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
+tokenizer = AutoTokenizer.from_pretrained(dir_model)
+
+reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v: k for k, v in byte_encoder.items()}
+
+for i in range(vocab_size):
+    if i in reverse_vocab:
+        try:
+            text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
+        except KeyError:
+            text = bytearray()
+            for c in reverse_vocab[i]:
+                if ord(c) < 256:  # single byte character
+                    try:
+                        text.append(byte_decoder[c])
+                    except KeyError:
+                        text.extend(c.encode('utf-8'))
+                else:  # multibyte special token character
+                    text.extend(c.encode('utf-8'))
+    else:
+        print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token. (It's normal for MPT.)")
+        pad_token = f"[PAD{i}]".encode("utf8")
+        text = bytearray(pad_token)
+
+    tokens.append(text)
+
+gguf_writer.add_token_list(tokens)
+
+special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
+special_vocab.add_to_gguf(gguf_writer)
+
+# TENSORS
+
+tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
+
+# tensor info
+print("gguf: get tensor metadata")
+
+if num_parts == 0:
+    part_names = iter(("pytorch_model.bin",))
+else:
+    part_names = (
+        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
+    )
+
+for part_name in part_names:
+    if args.vocab_only:
+        break
+    print("gguf: loading model part '" + part_name + "'")
+    model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
+
+    for name in model_part.keys():
+        data = model_part[name]
+
+        old_dtype = data.dtype
+
+        # convert any unsupported data types to float32
+        if data.dtype != torch.float16 and data.dtype != torch.float32:
+            data = data.to(torch.float32)
+
+        data = data.squeeze().numpy()
+
+        # map tensor names
+        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+        if new_name is None:
+            print("Cannot map tensor '" + name + "'")
+            continue # for the sake of compatibility with some old published models, don't quit
+            sys.exit()
+
+        n_dims = len(data.shape)
+        data_dtype = data.dtype
+
+        # if f32 desired, convert any float16 to float32
+        if ftype == 0 and data_dtype == np.float16:
+            data = data.astype(np.float32)
+
+        # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+            data = data.astype(np.float32)
+
+        # if f16 desired, convert any float32 2-dim weight tensors to float16
+        if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+            data = data.astype(np.float16)
+
+        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+
+
+#        if new_name == "wte.weight" and data.shape[0] == 50432 and vocab_size == 50254:
+#            data = data[0:vocab_size,:]
+
+        gguf_writer.add_tensor(new_name, data)
+
+        # note: MPT output is tied to (same as) wte in original model;
+        # for easier implementation in llama.cpp it's duplicated in GGUF, though :/
+        if new_name == "wte.weight":
+            gguf_writer.add_tensor("output.weight", data)
+
+print("gguf: write header")
+gguf_writer.write_header_to_file()
+print("gguf: write metadata")
+gguf_writer.write_kv_data_to_file()
+if not args.vocab_only:
+    print("gguf: write tensors")
+    gguf_writer.write_tensors_to_file()
+
+gguf_writer.close()
+
+print(f"gguf: model successfully exported to '{fname_out}'")
+print("")
diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py
index 598cf8e594aa8..9c49d0ada29a3 100644
--- a/gguf-py/gguf/gguf.py
+++ b/gguf-py/gguf/gguf.py
@@ -185,6 +185,19 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN:      "blk.{bid}.ffn_down",
         MODEL_TENSOR.FFN_UP:        "blk.{bid}.ffn_up",
     },
+    MODEL_ARCH.MPT: {
+        MODEL_TENSOR.TOKEN_EMBD:    "wte",
+        MODEL_TENSOR.OUTPUT_NORM:   "norm_f",
+        # note: MPT output is tied to (same as) wte in original model;
+        # for easier implementation in llama.cpp it's duplicated in GGUF, though :/
+        MODEL_TENSOR.OUTPUT:        "output",
+        MODEL_TENSOR.ATTN_NORM:     "blk.{bid}.norm_1",
+        MODEL_TENSOR.FFN_NORM:      "blk.{bid}.norm_2",
+        MODEL_TENSOR.ATTN_QKV:      "blk.{bid}.attn.Wqkv",
+        MODEL_TENSOR.ATTN_OUT:      "blk.{bid}.attn.out_proj",
+        MODEL_TENSOR.FFN_DOWN:      "blk.{bid}.ffn.down_proj",
+        MODEL_TENSOR.FFN_UP:        "blk.{bid}.ffn.up_proj",
+    },
     MODEL_ARCH.GPT2: {
         # TODO
     },
@@ -231,6 +244,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm", # gptneox
             "transformer.ln_f",          # gpt2 falcon
+            "transformer.norm_f",        # mpt
             "model.norm",                # llama-hf baichuan
             "norm",                      # llama-pth
         ),
diff --git a/llama.cpp b/llama.cpp
index bff17135b985f..108ae5eb4725f 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -377,7 +377,15 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
     {
         LLM_ARCH_MPT,
         {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD,      "wte" },
+            { LLM_TENSOR_OUTPUT_NORM,     "norm_f" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.norm_1" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.norm_2" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn.Wqkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn.out_proj" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn.down_proj" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn.up_proj" },
         },
     },
     {
@@ -947,6 +955,9 @@ struct llama_hparams {
     float rope_freq_base_train;
     float rope_freq_scale_train;
 
+    float f_clamp_kqv;
+    float f_max_alibi_bias;
+
     bool operator!=(const llama_hparams & other) const {
         return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
     }
@@ -1912,6 +1923,18 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_MPT:
+            {
+                GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
+                GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_CLAMP_KQV));
+                GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS));
+
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_7B; break;
+                    case 48: model.type = e_model::MODEL_30B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -2436,6 +2459,73 @@ static void llm_load_tensors(
                         }
                     }
                 } break;
+            case LLM_ARCH_MPT:
+                {
+                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+
+                    // output
+                    {
+                        ggml_backend backend_norm;
+                        ggml_backend backend_output;
+
+                        if (n_gpu_layers > int(n_layer)) {
+                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
+                            // on Windows however this is detrimental unless everything is on the GPU
+#ifndef _WIN32
+                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+#else
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#endif // _WIN32
+
+                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        } else {
+                            backend_norm   = GGML_BACKEND_CPU;
+                            backend_output = GGML_BACKEND_CPU;
+                        }
+
+                        model.output_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd},          backend_norm);
+                        model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
+
+                        if (backend_norm == GGML_BACKEND_GPU) {
+                            vram_weights += ggml_nbytes(model.output_norm);
+                        }
+                        if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+                            vram_weights += ggml_nbytes(model.output);
+                        }
+                    }
+
+                    const uint32_t n_ff = hparams.n_ff;
+
+                    const int i_gpu_start = n_layer - n_gpu_layers;
+
+                    model.layers.resize(n_layer);
+
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+                        const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
+                        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
+                        layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},     backend_split);
+
+                        layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
+
+                        layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
+                        layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+
+                        if (backend == GGML_BACKEND_GPU) {
+                            vram_weights +=
+                                ggml_nbytes(layer.attn_norm) +
+                                ggml_nbytes(layer.wqkv)      +
+                                ggml_nbytes(layer.wo)        +
+                                ggml_nbytes(layer.ffn_norm)  +
+                                ggml_nbytes(layer.w2)        +
+                                ggml_nbytes(layer.w3);
+                        }
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -3955,6 +4045,356 @@ static struct ggml_cgraph * llm_build_starcoder(
     return gf;
 }
 
+static struct ggml_cgraph * llm_build_mpt(
+         llama_context & lctx,
+     const llama_batch & batch) {
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    const auto & kv_self = lctx.kv_self;
+
+    GGML_ASSERT(!!kv_self.ctx);
+
+    const int64_t n_embd      = hparams.n_embd;
+    const int64_t n_layer     = hparams.n_layer;
+    const int64_t n_ctx       = cparams.n_ctx;
+    const int64_t n_head      = hparams.n_head;
+    const int64_t n_head_kv   = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA
+    const int64_t n_embd_head = hparams.n_embd_head();
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+
+    const float norm_eps = hparams.f_norm_eps;
+    const float clamp_kqv = hparams.f_clamp_kqv;
+    const float max_alibi_bias = hparams.f_max_alibi_bias;
+
+    const int n_gpu_layers = model.n_gpu_layers;
+
+    const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx            : kv_self.n;
+    const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
+
+    const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
+
+    //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n",
+    //        kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift);
+
+    auto & buf_compute = lctx.buf_compute;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_compute.size,
+        /*.mem_buffer =*/ buf_compute.data,
+        /*.no_alloc   =*/ false,
+    };
+
+    params.no_alloc = true;
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * cur;
+    struct ggml_tensor * inpL;
+
+    //int warmup = 0;
+    if (batch.token) {
+        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+
+        ggml_allocr_alloc(lctx.alloc, inp_tokens);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
+            //warmup = ((uint32_t*) inp_tokens->data)[0] == 0;
+        }
+
+        ggml_set_name(inp_tokens, "inp_tokens");
+
+        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
+    } else {
+#ifdef GGML_USE_MPI
+        GGML_ASSERT(false && "not implemented");
+#endif
+
+        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
+
+        ggml_allocr_alloc(lctx.alloc, inpL);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
+        }
+    }
+
+    const int i_gpu_start = n_layer - n_gpu_layers;
+    (void) i_gpu_start;
+
+    // offload functions set the tensor output backend to GPU
+    // tensors are GPU-accelerated if any input or the output has been offloaded
+    offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+    offload_func_t offload_func_kq = llama_nop;
+    offload_func_t offload_func_v  = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+    if (n_gpu_layers > n_layer) {
+        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
+    }
+    if (n_gpu_layers > n_layer + 1) {
+        offload_func_v  = ggml_cuda_assign_buffers_no_alloc;
+    }
+    if (n_gpu_layers > n_layer + 2) {
+        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
+    }
+#endif // GGML_USE_CUBLAS
+
+    // KQ_scale
+    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
+    ggml_allocr_alloc(lctx.alloc, KQ_scale);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+    }
+
+    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    offload_func_kq(KQ_mask);
+    ggml_set_name(KQ_mask, "KQ_mask");
+    ggml_allocr_alloc(lctx.alloc, KQ_mask);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        float * data = (float *) KQ_mask->data;
+        memset(data, 0, ggml_nbytes(KQ_mask));
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_pos    pos    = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+            }
+        }
+    }
+
+    // shift the entire K-cache if needed
+    // TODO: Do we need to handle it? (MPT uses alibi instead of rope)
+/*    if (do_rope_shift) {
+        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
+        offload_func_kq(K_shift);
+        ggml_set_name(K_shift, "K_shift");
+        ggml_allocr_alloc(lctx.alloc, K_shift);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            int * data = (int *) K_shift->data;
+            for (int i = 0; i < n_ctx; ++i) {
+                data[i] = kv_self.cells[i].delta;
+            }
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * tmp =
+                    ggml_rope_custom_inplace(ctx0,
+                        ggml_view_3d(ctx0, kv_self.k,
+                            n_embd_head, n_head_kv, n_ctx,
+                            ggml_element_size(kv_self.k)*n_embd_head,
+                            ggml_element_size(kv_self.k)*n_embd_gqa,
+                            ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
+                        K_shift, n_embd_head, 2, 0, freq_base, freq_scale);
+            offload_func_kq(tmp);
+            ggml_build_forward_expand(gf, tmp);
+        }
+    }*/
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * attn_norm;
+
+        offload_func_t offload_func = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+        if (il >= i_gpu_start) {
+            offload_func = ggml_cuda_assign_buffers_no_alloc;
+        }
+#endif // GGML_USE_CUBLAS
+
+        // self-attention
+        // TODO: refactor into common function (shared with LLaMA)
+        {
+            attn_norm = ggml_norm(ctx0, inpL, norm_eps);
+            offload_func(attn_norm);
+
+            attn_norm = ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm);
+            offload_func(attn_norm);
+
+            if (1) {
+                cur = attn_norm;
+            }
+
+            // compute QKV
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+            offload_func_kq(cur);
+
+            if (clamp_kqv > 0.0f) {
+                cur = ggml_clamp(ctx0, cur, -clamp_kqv, clamp_kqv);
+                offload_func_kq(cur);
+            }
+
+            const size_t wsize = ggml_type_size(cur->type);
+
+            struct ggml_tensor * Qcur = ggml_view_3d(
+                ctx0, cur, n_embd_head, n_head, n_tokens,
+                wsize * n_embd_head,
+                wsize * n_embd_head * (n_head + 2 * n_head_kv),
+                0);
+            offload_func_kq(Qcur);
+
+            struct ggml_tensor * Kcur = ggml_view_3d(
+                ctx0, cur, n_embd_head, n_head_kv, n_tokens,
+                wsize * n_embd_head,
+                wsize * n_embd_head * (n_head + 2 * n_head_kv),
+                wsize * n_embd_head *  n_head);
+            offload_func_kq(Kcur);
+
+            struct ggml_tensor * tmpv = ggml_view_3d(
+                ctx0, cur, n_embd_head, n_head_kv, n_tokens,
+                wsize * n_embd_head,
+                wsize * n_embd_head * (n_head + 2 * n_head_kv),
+                wsize * n_embd_head * (n_head +     n_head_kv));
+            offload_func_kq(Kcur);
+
+            ggml_set_name(Qcur, "Qcur");
+            ggml_set_name(Kcur, "Kcur");
+
+            {
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens));
+                offload_func_v(Vcur);
+                offload_func_v(Vcur->src[0]->src[0]);
+                ggml_set_name(Vcur, "Vcur");
+
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
+                offload_func_kq(k);
+                ggml_set_name(k, "k");
+
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
+                        (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
+                offload_func_v(v);
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+            offload_func_kq(Q);
+            ggml_set_name(Q, "Q");
+
+            struct ggml_tensor * K =
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_embd_head, n_kv, n_head_kv,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+            offload_func_kq(K);
+            ggml_set_name(K, "K");
+
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            offload_func_kq(KQ);
+            ggml_set_name(KQ, "KQ");
+
+            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
+            offload_func_kq(KQ_scaled);
+            ggml_set_name(KQ_scaled, "KQ_scaled");
+
+			// TODO: replace with ggml_add()
+			struct ggml_tensor * KQ_scaled_alibi =
+				ggml_alibi(ctx0, KQ_scaled, std::max(kv_head, n_kv - n_tokens), n_head, max_alibi_bias);
+            offload_func_kq(KQ_scaled_alibi);
+			ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
+
+            struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
+            offload_func_kq(KQ_masked);
+            ggml_set_name(KQ_masked, "KQ_masked");
+
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+            offload_func_v(KQ_soft_max);
+            ggml_set_name(KQ_soft_max, "KQ_soft_max");
+
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_kv, n_embd_head, n_head_kv,
+                        ggml_element_size(kv_self.v)*n_ctx,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+            offload_func_v(V);
+            ggml_set_name(V, "V");
+
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            offload_func_v(KQV);
+            ggml_set_name(KQV, "KQV");
+
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            offload_func_v(KQV_merged);
+            ggml_set_name(KQV_merged, "KQV_merged");
+
+            cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
+            offload_func_v(cur);
+            ggml_set_name(cur, "KQV_merged_contiguous");
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+            offload_func(cur);
+            ggml_set_name(cur, "result_wo");
+        }
+
+        // Add the input
+        cur = ggml_add(ctx0, cur, inpL);
+        offload_func(cur);
+
+        struct ggml_tensor * attn_out = cur;
+
+        // feed forward
+        {
+            // Norm
+            {
+                cur = ggml_norm(ctx0, attn_out, norm_eps);
+                offload_func(cur);
+
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
+                offload_func(cur);
+            }
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
+            offload_func(cur);
+
+            cur = ggml_gelu(ctx0, cur);
+            offload_func(cur);
+            cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
+            offload_func(cur);
+        }
+
+        cur = ggml_add(ctx0, cur, attn_out);
+        offload_func(cur);
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    // norm
+    {
+        cur = ggml_norm(ctx0, cur, norm_eps);
+        offload_func_nr(cur);
+
+        cur = ggml_mul(ctx0, cur, model.output_norm);
+        ggml_set_name(cur, "result_norm");
+    }
+
+    cur = ggml_mul_mat(ctx0, model.output, cur);
+    ggml_set_name(cur, "result_output");
+
+    ggml_build_forward_expand(gf, cur);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
      const llama_batch & batch) {
@@ -3979,6 +4419,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm_build_starcoder(lctx, batch);
             } break;
+        case LLM_ARCH_MPT:
+            {
+                result = llm_build_mpt(lctx, batch);
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -4112,7 +4556,8 @@ static int llama_decode_internal(
     // If all tensors can be run on the GPU then using more than 1 thread is detrimental.
     const bool full_offload_supported = model.arch == LLM_ARCH_LLAMA ||
         model.arch == LLM_ARCH_BAICHUAN ||
-        model.arch == LLM_ARCH_FALCON;
+        model.arch == LLM_ARCH_FALCON ||
+        model.arch == LLM_ARCH_MPT;
     const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
     if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {
         n_threads = 1;

From 84e30e891c2e211f10a3fa21363095ca0ed97b3c Mon Sep 17 00:00:00 2001
From: Jan Ploski <jpl@plosquare.com>
Date: Sun, 1 Oct 2023 00:32:33 +0200
Subject: [PATCH 03/16] mpt : protect against "clip_qkv": null in mpt-7b

---
 convert-mpt-hf-to-gguf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py
index b5888fd576b1a..e0a8b166dd9e2 100755
--- a/convert-mpt-hf-to-gguf.py
+++ b/convert-mpt-hf-to-gguf.py
@@ -122,7 +122,7 @@ def parse_args() -> argparse.Namespace:
 gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
 gguf_writer.add_head_count(hparams["n_heads"])
 gguf_writer.add_layer_norm_eps(1e-05) 
-gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"]) 
+gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"] or 0.0)
 gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"])
 
 # TOKENIZATION

From 00e8c5c5f6253347c61e3b74965df3577cccfee6 Mon Sep 17 00:00:00 2001
From: Jan Ploski <jpl@plosquare.com>
Date: Sun, 1 Oct 2023 00:49:13 +0200
Subject: [PATCH 04/16] mpt : quick fix to avoid "Strange model" warning when
 quantizing MPT models

---
 llama.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 108ae5eb4725f..a2bdd9d3dfeb2 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6511,10 +6511,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         const std::string name = ggml_get_name(meta);
 
         // TODO: avoid hardcoded tensor names - use the TN_* constants
-        if (name.find("attn_v.weight") != std::string::npos) {
+        if (name.find("attn_v.weight") != std::string::npos ||
+            name.find("attn.Wqkv.weight") != std::string::npos) {
             ++n_attention_wv;
         }
-        else if (name.find("ffn_down.weight") != std::string::npos) {
+        else if (name.find("ffn_down.weight") != std::string::npos ||
+                 name.find("ffn.down_proj.weight") != std::string::npos) {
             ++n_feed_forward_w2;
         }
     }

From 1be89c4002a247449d220824e76c6d26fd548eae Mon Sep 17 00:00:00 2001
From: Jan Ploski <jpl@plosquare.com>
Date: Sun, 1 Oct 2023 01:14:07 +0200
Subject: [PATCH 05/16] mpt : addendum to changeset:84e30e8 - leave parameter
 clamp_kqv out from metadata rather than use 0.0 to indicate "no clamping"
 (more compliant with the current GGUF spec?)

---
 convert-mpt-hf-to-gguf.py | 3 ++-
 llama.cpp                 | 6 +++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py
index e0a8b166dd9e2..057cb34f683f5 100755
--- a/convert-mpt-hf-to-gguf.py
+++ b/convert-mpt-hf-to-gguf.py
@@ -122,7 +122,8 @@ def parse_args() -> argparse.Namespace:
 gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
 gguf_writer.add_head_count(hparams["n_heads"])
 gguf_writer.add_layer_norm_eps(1e-05) 
-gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"] or 0.0)
+if hparams["attn_config"]["clip_qkv"] is not None:
+    gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])
 gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"])
 
 # TOKENIZATION
diff --git a/llama.cpp b/llama.cpp
index a2bdd9d3dfeb2..2bec27b8b41f2 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1926,7 +1926,11 @@ static void llm_load_hparams(
         case LLM_ARCH_MPT:
             {
                 GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
-                GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_CLAMP_KQV));
+                if (gguf_find_key(ctx, kv(LLM_KV_ATTENTION_CLAMP_KQV).c_str()) >= 0) {
+                    GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_CLAMP_KQV));
+                } else {
+                    hparams.f_clamp_kqv = 0.0f;
+                }
                 GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS));
 
                 switch (hparams.n_layer) {

From 26c253eda29bac3d76f36fd37a6861e32961013a Mon Sep 17 00:00:00 2001
From: Jan Ploski <jpl@plosquare.com>
Date: Sun, 1 Oct 2023 01:43:39 +0200
Subject: [PATCH 06/16] mpt : standardized all tensor names to follow GGUF spec

---
 convert-mpt-hf-to-gguf.py |  2 +-
 gguf-py/gguf/gguf.py      | 16 ++++++++--------
 llama.cpp                 | 21 ++++++++++-----------
 3 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py
index 057cb34f683f5..a0cee0cf9e547 100755
--- a/convert-mpt-hf-to-gguf.py
+++ b/convert-mpt-hf-to-gguf.py
@@ -247,7 +247,7 @@ def parse_args() -> argparse.Namespace:
 
         # note: MPT output is tied to (same as) wte in original model;
         # for easier implementation in llama.cpp it's duplicated in GGUF, though :/
-        if new_name == "wte.weight":
+        if new_name == "token_embd.weight":
             gguf_writer.add_tensor("output.weight", data)
 
 print("gguf: write header")
diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py
index 9c49d0ada29a3..afd16e2126d70 100644
--- a/gguf-py/gguf/gguf.py
+++ b/gguf-py/gguf/gguf.py
@@ -186,17 +186,17 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP:        "blk.{bid}.ffn_up",
     },
     MODEL_ARCH.MPT: {
-        MODEL_TENSOR.TOKEN_EMBD:    "wte",
-        MODEL_TENSOR.OUTPUT_NORM:   "norm_f",
+        MODEL_TENSOR.TOKEN_EMBD:    "token_embd",
+        MODEL_TENSOR.OUTPUT_NORM:   "output_norm",
         # note: MPT output is tied to (same as) wte in original model;
         # for easier implementation in llama.cpp it's duplicated in GGUF, though :/
         MODEL_TENSOR.OUTPUT:        "output",
-        MODEL_TENSOR.ATTN_NORM:     "blk.{bid}.norm_1",
-        MODEL_TENSOR.FFN_NORM:      "blk.{bid}.norm_2",
-        MODEL_TENSOR.ATTN_QKV:      "blk.{bid}.attn.Wqkv",
-        MODEL_TENSOR.ATTN_OUT:      "blk.{bid}.attn.out_proj",
-        MODEL_TENSOR.FFN_DOWN:      "blk.{bid}.ffn.down_proj",
-        MODEL_TENSOR.FFN_UP:        "blk.{bid}.ffn.up_proj",
+        MODEL_TENSOR.ATTN_NORM:     "blk.{bid}.attn_norm",
+        MODEL_TENSOR.FFN_NORM:      "blk.{bid}.ffn_norm",
+        MODEL_TENSOR.ATTN_QKV:      "blk.{bid}.attn_qkv",
+        MODEL_TENSOR.ATTN_OUT:      "blk.{bid}.attn_output",
+        MODEL_TENSOR.FFN_DOWN:      "blk.{bid}.ffn_down",
+        MODEL_TENSOR.FFN_UP:        "blk.{bid}.ffn_up",
     },
     MODEL_ARCH.GPT2: {
         # TODO
diff --git a/llama.cpp b/llama.cpp
index 2bec27b8b41f2..7ea6dbe72758b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -377,15 +377,15 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
     {
         LLM_ARCH_MPT,
         {
-            { LLM_TENSOR_TOKEN_EMBD,      "wte" },
-            { LLM_TENSOR_OUTPUT_NORM,     "norm_f" },
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
             { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.norm_1" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.norm_2" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn.Wqkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn.out_proj" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn.down_proj" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn.up_proj" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
     {
@@ -6516,11 +6516,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
         // TODO: avoid hardcoded tensor names - use the TN_* constants
         if (name.find("attn_v.weight") != std::string::npos ||
-            name.find("attn.Wqkv.weight") != std::string::npos) {
+            name.find("attn_qkv.weight") != std::string::npos) {
             ++n_attention_wv;
         }
-        else if (name.find("ffn_down.weight") != std::string::npos ||
-                 name.find("ffn.down_proj.weight") != std::string::npos) {
+        else if (name.find("ffn_down.weight") != std::string::npos) {
             ++n_feed_forward_w2;
         }
     }

From df072d2d999bc5924bebbe1ee37cdd810ed343fe Mon Sep 17 00:00:00 2001
From: Jan Ploski <jpl@plosquare.com>
Date: Sun, 1 Oct 2023 01:48:47 +0200
Subject: [PATCH 07/16] mpt : addendum to changeset:1be89c40 - use "req"
 parameter of GGUF_GET_KEY macro instead of duplicate code

---
 llama.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 7ea6dbe72758b..81a014d0a29c9 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1925,12 +1925,10 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_MPT:
             {
+                hparams.f_clamp_kqv = 0.0f;
+
                 GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
-                if (gguf_find_key(ctx, kv(LLM_KV_ATTENTION_CLAMP_KQV).c_str()) >= 0) {
-                    GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_CLAMP_KQV));
-                } else {
-                    hparams.f_clamp_kqv = 0.0f;
-                }
+                GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_CLAMP_KQV));
                 GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS));
 
                 switch (hparams.n_layer) {

From 90e7d6de28c05506e59a1422456bb6b2ab6a1f8b Mon Sep 17 00:00:00 2001
From: Jan Ploski <jpl@plosquare.com>
Date: Mon, 2 Oct 2023 19:55:59 +0200
Subject: [PATCH 08/16] mpt : fixed comment s/gptneox/mpt/

---
 convert-mpt-hf-to-gguf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py
index a0cee0cf9e547..60bceb0fa98b4 100755
--- a/convert-mpt-hf-to-gguf.py
+++ b/convert-mpt-hf-to-gguf.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# HF gptneox--> gguf conversion
+# HF mpt--> gguf conversion
 
 from __future__ import annotations
 

From 470801292df7d0d99e2782479c2909df191ab4e0 Mon Sep 17 00:00:00 2001
From: Jan Ploski <jpl@plosquare.com>
Date: Mon, 2 Oct 2023 21:55:22 +0200
Subject: [PATCH 09/16] mpt : remove tabs, trailing whitespace

---
 convert-mpt-hf-to-gguf.py | 2 +-
 llama.cpp                 | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py
index 60bceb0fa98b4..cbe4a9f110498 100755
--- a/convert-mpt-hf-to-gguf.py
+++ b/convert-mpt-hf-to-gguf.py
@@ -121,7 +121,7 @@ def parse_args() -> argparse.Namespace:
 gguf_writer.add_block_count(block_count)
 gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
 gguf_writer.add_head_count(hparams["n_heads"])
-gguf_writer.add_layer_norm_eps(1e-05) 
+gguf_writer.add_layer_norm_eps(1e-05)
 if hparams["attn_config"]["clip_qkv"] is not None:
     gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])
 gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"])
diff --git a/llama.cpp b/llama.cpp
index 81a014d0a29c9..ede95f60763a4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4304,11 +4304,11 @@ static struct ggml_cgraph * llm_build_mpt(
             offload_func_kq(KQ_scaled);
             ggml_set_name(KQ_scaled, "KQ_scaled");
 
-			// TODO: replace with ggml_add()
-			struct ggml_tensor * KQ_scaled_alibi =
-				ggml_alibi(ctx0, KQ_scaled, std::max(kv_head, n_kv - n_tokens), n_head, max_alibi_bias);
+            // TODO: replace with ggml_add()
+            struct ggml_tensor * KQ_scaled_alibi =
+                ggml_alibi(ctx0, KQ_scaled, std::max(kv_head, n_kv - n_tokens), n_head, max_alibi_bias);
             offload_func_kq(KQ_scaled_alibi);
-			ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
+            ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
 
             struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
             offload_func_kq(KQ_masked);

From 1364bcd712ef4d9c5ca905f4901d16b1dba695c4 Mon Sep 17 00:00:00 2001
From: Jan Ploski <jpl@plosquare.com>
Date: Tue, 3 Oct 2023 21:53:31 +0200
Subject: [PATCH 10/16] mpt : removed ne01 + n_past == ne00 assertion from
 alibi (cuda/f32) and rope_shift from build_mpt

---
 ggml-cuda.cu |  4 ++--
 ggml.c       |  4 ++--
 llama.cpp    | 32 +-------------------------------
 3 files changed, 5 insertions(+), 35 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0c873375d0b75..967fb5dca32e4 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6295,12 +6295,12 @@ inline void ggml_cuda_op_alibi(
     const int64_t ne02 = src0->ne[2];
     const int64_t nrows = ggml_nrows(src0);
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-    GGML_ASSERT(ne01 + n_past == ne00);
+    //GGML_ASSERT(ne01 + n_past == ne00);
     GGML_ASSERT(n_head == ne02);
 
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
diff --git a/ggml.c b/ggml.c
index 820fe2e74b0ae..323b84974cafa 100644
--- a/ggml.c
+++ b/ggml.c
@@ -12889,7 +12889,7 @@ static void ggml_compute_forward_alibi_f32(
         return;
     }
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -12910,7 +12910,7 @@ static void ggml_compute_forward_alibi_f32(
     //const int nb3 = src0->nb[3];
 
     GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(ne1 + n_past == ne0);
+    //GGML_ASSERT(ne1 + n_past == ne0);
     GGML_ASSERT(n_head == ne2);
 
     // add alibi to src0 (KQ_scaled)
diff --git a/llama.cpp b/llama.cpp
index ede95f60763a4..b1b1b801de89f 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4076,8 +4076,6 @@ static struct ggml_cgraph * llm_build_mpt(
     const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx            : kv_self.n;
     const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
 
-    const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
-
     //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n",
     //        kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift);
 
@@ -4176,34 +4174,6 @@ static struct ggml_cgraph * llm_build_mpt(
         }
     }
 
-    // shift the entire K-cache if needed
-    // TODO: Do we need to handle it? (MPT uses alibi instead of rope)
-/*    if (do_rope_shift) {
-        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        offload_func_kq(K_shift);
-        ggml_set_name(K_shift, "K_shift");
-        ggml_allocr_alloc(lctx.alloc, K_shift);
-        if (!ggml_allocr_is_measure(lctx.alloc)) {
-            int * data = (int *) K_shift->data;
-            for (int i = 0; i < n_ctx; ++i) {
-                data[i] = kv_self.cells[i].delta;
-            }
-        }
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * tmp =
-                    ggml_rope_custom_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k,
-                            n_embd_head, n_head_kv, n_ctx,
-                            ggml_element_size(kv_self.k)*n_embd_head,
-                            ggml_element_size(kv_self.k)*n_embd_gqa,
-                            ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
-                        K_shift, n_embd_head, 2, 0, freq_base, freq_scale);
-            offload_func_kq(tmp);
-            ggml_build_forward_expand(gf, tmp);
-        }
-    }*/
-
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * attn_norm;
 
@@ -4306,7 +4276,7 @@ static struct ggml_cgraph * llm_build_mpt(
 
             // TODO: replace with ggml_add()
             struct ggml_tensor * KQ_scaled_alibi =
-                ggml_alibi(ctx0, KQ_scaled, std::max(kv_head, n_kv - n_tokens), n_head, max_alibi_bias);
+                ggml_alibi(ctx0, KQ_scaled, 0, n_head, max_alibi_bias);
             offload_func_kq(KQ_scaled_alibi);
             ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
 

From 7d6a24aad4d2eae524bd3472290fbfb3efab5510 Mon Sep 17 00:00:00 2001
From: Jan Ploski <jpl@plosquare.com>
Date: Fri, 6 Oct 2023 13:53:32 +0200
Subject: [PATCH 11/16] mpt : updated convert-mpt-hf-to-gguf.py to reflect
 changes made to convert-gptneox-hf-to-gguf.py in pr:3252

---
 convert-mpt-hf-to-gguf.py | 55 +++++----------------------------------
 1 file changed, 7 insertions(+), 48 deletions(-)

diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py
index cbe4a9f110498..a6a049bc922e2 100755
--- a/convert-mpt-hf-to-gguf.py
+++ b/convert-mpt-hf-to-gguf.py
@@ -19,29 +19,6 @@
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
 import gguf
 
-# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
-
-
-def bytes_to_unicode():
-    """
-    Returns list of utf-8 byte and a corresponding list of unicode strings.
-    The reversible bpe codes work on unicode strings.
-    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a significant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-    And avoids mapping to whitespace/control characters the bpe code barfs on.
-    """
-    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
-    cs = bs[:]
-    n = 0
-    for b in range(2**8):
-        if b not in bs:
-            bs.append(b)
-            cs.append(2**8+n)
-            n += 1
-    return dict(zip(bs, (chr(n) for n in cs)))
-
 
 def count_model_parts(dir_model: Path) -> int:
     num_parts = 0
@@ -131,6 +108,8 @@ def parse_args() -> argparse.Namespace:
 print("gguf: get tokenizer metadata")
 
 tokens: list[bytearray] = []
+scores: list[float] = []
+toktypes: list[int] = []
 
 tokenizer_json_file = dir_model / 'tokenizer.json'
 if not tokenizer_json_file.is_file():
@@ -155,31 +134,15 @@ def parse_args() -> argparse.Namespace:
 tokenizer = AutoTokenizer.from_pretrained(dir_model)
 
 reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
-byte_encoder = bytes_to_unicode()
-byte_decoder = {v: k for k, v in byte_encoder.items()}
 
 for i in range(vocab_size):
-    if i in reverse_vocab:
-        try:
-            text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
-        except KeyError:
-            text = bytearray()
-            for c in reverse_vocab[i]:
-                if ord(c) < 256:  # single byte character
-                    try:
-                        text.append(byte_decoder[c])
-                    except KeyError:
-                        text.extend(c.encode('utf-8'))
-                else:  # multibyte special token character
-                    text.extend(c.encode('utf-8'))
-    else:
-        print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token. (It's normal for MPT.)")
-        pad_token = f"[PAD{i}]".encode("utf8")
-        text = bytearray(pad_token)
-
-    tokens.append(text)
+    tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
+    scores.append(0.0) # dummy
+    toktypes.append(gguf.TokenType.NORMAL)
 
 gguf_writer.add_token_list(tokens)
+gguf_writer.add_token_scores(scores)
+gguf_writer.add_token_types(toktypes)
 
 special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
 special_vocab.add_to_gguf(gguf_writer)
@@ -239,10 +202,6 @@ def parse_args() -> argparse.Namespace:
 
         print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
 
-
-#        if new_name == "wte.weight" and data.shape[0] == 50432 and vocab_size == 50254:
-#            data = data[0:vocab_size,:]
-
         gguf_writer.add_tensor(new_name, data)
 
         # note: MPT output is tied to (same as) wte in original model;

From ad3c2f3b23164420c93694684be7e906811631d1 Mon Sep 17 00:00:00 2001
From: Cebtenzzre <cebtenzzre@gmail.com>
Date: Mon, 9 Oct 2023 10:16:24 -0400
Subject: [PATCH 12/16] comment out n_past instead of marking it unused

---
 ggml-metal.m | 2 +-
 ggml.c       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index 5a23144d0c891..87fa172161405 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1299,7 +1299,7 @@ void ggml_metal_graph_compute(
 
                             const int nth = MIN(1024, ne00);
 
-                            const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+                            //const int n_past = ((int32_t *) dst->op_params)[0];
                             const int n_head = ((int32_t *) dst->op_params)[1];
                             float max_bias;
                             memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
diff --git a/ggml.c b/ggml.c
index 5bb1da31ba624..8cdab8854f756 100644
--- a/ggml.c
+++ b/ggml.c
@@ -13059,7 +13059,7 @@ static void ggml_compute_forward_alibi_f32(
         return;
     }
 
-    const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

From 1a454eb561d19b7dba5d377feb48144c17128f5b Mon Sep 17 00:00:00 2001
From: Jan Ploski <jpl@plosquare.com>
Date: Mon, 9 Oct 2023 16:48:01 +0200
Subject: [PATCH 13/16] mpt : removed hardcoded +178 from convert script in
 favor of utilizing hparams["vocab_size"]

---
 convert-mpt-hf-to-gguf.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py
index a6a049bc922e2..6d318dbb63e28 100755
--- a/convert-mpt-hf-to-gguf.py
+++ b/convert-mpt-hf-to-gguf.py
@@ -124,11 +124,13 @@ def parse_args() -> argparse.Namespace:
 
 print("gguf: get gpt2 tokenizer vocab")
 
-# MPT token embedding tensors have dimension 50432, but there are only 50254
+# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]),
+# but there are only 50254 (len(tokenizer_json["model"]["vocab"]))
 # tokens in the vocab, presumably to accomodate some "reserved" tokens;
-# this is causing problems down the line in llama.cpp, so we extend the vocab_size:
+# this is causing problems down the line in llama.cpp, so we pad the vocab
+# with dummy tokens:
 
-vocab_size = len(tokenizer_json["model"]["vocab"]) + 178
+vocab_size = hparams["vocab_size"]
 
 # ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
 tokenizer = AutoTokenizer.from_pretrained(dir_model)

From 32172f12f57e02b1a233822c6999be592b8a739e Mon Sep 17 00:00:00 2001
From: Cebtenzzre <cebtenzzre@gmail.com>
Date: Mon, 9 Oct 2023 11:39:53 -0400
Subject: [PATCH 14/16] mpt : remove unused tokenizer_json in convert script

---
 convert-mpt-hf-to-gguf.py | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py
index 6d318dbb63e28..73a4932f7c831 100755
--- a/convert-mpt-hf-to-gguf.py
+++ b/convert-mpt-hf-to-gguf.py
@@ -111,24 +111,15 @@ def parse_args() -> argparse.Namespace:
 scores: list[float] = []
 toktypes: list[int] = []
 
-tokenizer_json_file = dir_model / 'tokenizer.json'
-if not tokenizer_json_file.is_file():
-    print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
-    sys.exit(1)
-
 # gpt2 tokenizer
 gguf_writer.add_tokenizer_model("gpt2")
 
-with open(tokenizer_json_file, "r", encoding="utf-8") as f:
-    tokenizer_json = json.load(f)
-
 print("gguf: get gpt2 tokenizer vocab")
 
-# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]),
-# but there are only 50254 (len(tokenizer_json["model"]["vocab"]))
-# tokens in the vocab, presumably to accomodate some "reserved" tokens;
-# this is causing problems down the line in llama.cpp, so we pad the vocab
-# with dummy tokens:
+# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]), but
+# there are only 50254 (len(tokenizer.vocab)) tokens in the vocab, presumably to
+# accomodate some "reserved" tokens; this is causing problems down the line in
+# llama.cpp, so we pad the vocab with dummy tokens:
 
 vocab_size = hparams["vocab_size"]
 

From 96cf3f5dc3e145a9555df377947ac57ecabaa708 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Tue, 10 Oct 2023 10:45:24 +0300
Subject: [PATCH 15/16] ggml : remove obsolete n_past assert in ggml_alibi

---
 ggml.c | 2 --
 1 file changed, 2 deletions(-)

diff --git a/ggml.c b/ggml.c
index 8cdab8854f756..1f5598fa6af8f 100644
--- a/ggml.c
+++ b/ggml.c
@@ -13064,8 +13064,6 @@ static void ggml_compute_forward_alibi_f32(
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-    assert(n_past >= 0);
-
     const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int64_t ne1 = src0->ne[1]; // seq_len_without_past
     const int64_t ne2 = src0->ne[2]; // n_head -> this is k

From 9b66378cacfb0a4e0401e7682ec66f6ab89348ed Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Tue, 10 Oct 2023 10:49:39 +0300
Subject: [PATCH 16/16] llama : print clam_kqv and max_alibi_bias hparams

---
 llama.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index eaa008dbfaa00..3b63b64010b0f 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2229,6 +2229,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: n_gqa            = %u\n",     __func__, hparams.n_gqa());
     LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
     LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
+    LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
+    LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
     LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
     LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
@@ -5013,8 +5015,8 @@ static struct ggml_cgraph * llm_build_mpt(
     const int64_t n_embd_head = hparams.n_embd_head();
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
 
-    const float norm_eps = hparams.f_norm_eps;
-    const float clamp_kqv = hparams.f_clamp_kqv;
+    const float norm_eps       = hparams.f_norm_eps;
+    const float clamp_kqv      = hparams.f_clamp_kqv;
     const float max_alibi_bias = hparams.f_max_alibi_bias;
 
     const int n_gpu_layers = model.n_gpu_layers;