From af19099ab168b9dd4f16f08d385db0018eca54f0 Mon Sep 17 00:00:00 2001 From: ds5t5 Date: Fri, 29 Sep 2023 01:13:41 -0700 Subject: [PATCH] rebase to the latest --- convert-refact-hf-to-gguf.py | 23 +++++++- gguf-py/gguf/gguf.py | 9 +-- llama.cpp | 110 ++++++++++++++++++++--------------- 3 files changed, 87 insertions(+), 55 deletions(-) diff --git a/convert-refact-hf-to-gguf.py b/convert-refact-hf-to-gguf.py index 5a876c248cb3a..e0cd417dbbbc4 100755 --- a/convert-refact-hf-to-gguf.py +++ b/convert-refact-hf-to-gguf.py @@ -6,10 +6,8 @@ import argparse import json import os -import struct import sys from pathlib import Path -from typing import Any import numpy as np import torch @@ -235,6 +233,27 @@ def parse_args() -> argparse.Namespace: print("gguf: loading model part '" + part_name + "'") model_part = torch.load(dir_model / part_name, map_location="cpu") + for i in range(block_count): + if f"transformer.h.{i}.attn.kv.weight" in model_part: + data = model_part[f"transformer.h.{i}.attn.kv.weight"] + model_part[f"model.layers.{i}.self_attn.k_proj.weight"] = data[ + : n_head_kv * head_dim + ] + model_part[f"model.layers.{i}.self_attn.v_proj.weight"] = data[ + n_head_kv * head_dim : + ] + del model_part[f"transformer.h.{i}.attn.kv.weight"] + if f"transformer.h.{i}.attn.q.weight" in model_part: + model_part[f"model.layers.{i}.self_attn.q_proj.weight"] = model_part[ + f"transformer.h.{i}.attn.q.weight" + ] + del model_part[f"transformer.h.{i}.attn.q.weight"] + if f"transformer.h.{i}.mlp.gate_up_proj.weight" in model_part: + data = model_part[f"transformer.h.{i}.mlp.gate_up_proj.weight"] + model_part[f"model.layers.{i}.mlp.gate_proj.weight"] = data[:ff_dim] + model_part[f"model.layers.{i}.mlp.up_proj.weight"] = data[ff_dim:] + del model_part[f"transformer.h.{i}.mlp.gate_up_proj.weight"] + for name in model_part.keys(): data = model_part[name] diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index b4b52d4add407..f7e96c9833db7 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -286,21 +286,18 @@ class TensorNameMap: # Attention query MODEL_TENSOR.ATTN_Q: ( "model.layers.{bid}.self_attn.q_proj", # llama-hf - "transformer.h.{bid}.attn.q", # refact "layers.{bid}.attention.wq", # llama-pth ), # Attention key MODEL_TENSOR.ATTN_K: ( "model.layers.{bid}.self_attn.k_proj", # llama-hf - "transformer.h.{bid}.attn.k", # refact "layers.{bid}.attention.wk", # llama-pth ), # Attention value MODEL_TENSOR.ATTN_V: ( "model.layers.{bid}.self_attn.v_proj", # llama-hf - "transformer.h.{bid}.attn.v", # refact "layers.{bid}.attention.wv", # llama-pth ), @@ -335,15 +332,13 @@ class TensorNameMap: "transformer.h.{bid}.mlp.c_fc", # gpt2 "transformer.blocks.{bid}.ffn.up_proj", # mpt "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon - "model.layers.{bid}.mlp.up_proj", # llama-hf + "model.layers.{bid}.mlp.up_proj", # llama-hf refact "layers.{bid}.feed_forward.w3", # llama-pth - "transformer.h.{bid}.mlp.linear_3", # refact ), # Feed-forward gate MODEL_TENSOR.FFN_GATE: ( - "model.layers.{bid}.mlp.gate_proj", # llama-hf - "transformer.h.{bid}.mlp.linear_1", # refact + "model.layers.{bid}.mlp.gate_proj", # llama-hf refact "layers.{bid}.feed_forward.w1", # llama-pth ), diff --git a/llama.cpp b/llama.cpp index 71fcec48302c7..dc50a19a75c7b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3369,17 +3369,10 @@ static struct ggml_cgraph * llm_build_baichaun( static struct ggml_cgraph * llm_build_refact( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past) { - - GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT - - const int N = n_tokens; - + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -3387,7 +3380,7 @@ static struct ggml_cgraph * llm_build_refact( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -3397,6 +3390,12 @@ static struct ggml_cgraph * llm_build_refact( const int n_gpu_layers = model.n_gpu_layers; + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + // printf("n_kv = %d\n", n_kv); + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { @@ -3414,12 +3413,12 @@ static struct ggml_cgraph * llm_build_refact( struct ggml_tensor * cur; struct ggml_tensor * inpL; - if (tokens) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(lctx.alloc, inp_tokens); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); } ggml_set_name(inp_tokens, "inp_tokens"); @@ -3429,11 +3428,11 @@ static struct ggml_cgraph * llm_build_refact( GGML_ASSERT(false && "not implemented"); #endif - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); ggml_allocr_alloc(lctx.alloc, inpL); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); } } @@ -3442,9 +3441,6 @@ static struct ggml_cgraph * llm_build_refact( // offload functions set the tensor output backend to GPU // tensors are GPU-accelerated if any input or the output has been offloaded - // - // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal - // in that case ggml_cuda_assign_buffers has no effect offload_func_t offload_func_nr = llama_nop; // nr = non-repeating offload_func_t offload_func_kq = llama_nop; offload_func_t offload_func_v = llama_nop; @@ -3461,12 +3457,36 @@ static struct ggml_cgraph * llm_build_refact( } #endif // GGML_USE_CUBLAS + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); + } + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } } - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); for (int il = 0; il < n_layer; ++il) { ggml_format_name(inpL, "layer_inp_%d", il); @@ -3504,36 +3524,33 @@ static struct ggml_cgraph * llm_build_refact( offload_func_kq(tmpq); ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * Kcur; - struct ggml_tensor * Qcur; - Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N); - Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N); - + struct ggml_tensor * Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens); offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); + struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); // store key and value to memory { - // compute the transposed [N, n_embd] V matrix + // compute the transposed [n_tokens, n_embd] V matrix struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); offload_func_v(tmpv); ggml_set_name(tmpv, "tmpv"); - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); offload_func_v(Vcur); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); offload_func_kq(k); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); offload_func_v(v); ggml_set_name(v, "v"); @@ -3547,7 +3564,7 @@ static struct ggml_cgraph * llm_build_refact( struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_past + N, n_head_kv, + n_embd_head, n_kv, n_head_kv, ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_head, ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); @@ -3560,25 +3577,28 @@ static struct ggml_cgraph * llm_build_refact( ggml_set_name(KQ, "KQ"); // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + // KQ_scaled shape [n_kv, n_tokens, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); - struct ggml_tensor * KQ_masked; - struct ggml_tensor * KQ_scaled_alibi; - - KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8); + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8); ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); - KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); + offload_func_kq(KQ_masked); + ggml_set_name(KQ_masked, "KQ_masked"); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd_head, n_head_kv, + n_kv, n_embd_head, n_head_kv, ggml_element_size(kv_self.v)*n_ctx, ggml_element_size(kv_self.v)*n_ctx*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); @@ -3593,7 +3613,7 @@ static struct ggml_cgraph * llm_build_refact( // make V contiguous in memory to speed up the matmul, however we waste time on the copy // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation // is there a better way? - struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head)); + struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head)); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); #endif @@ -3602,10 +3622,8 @@ static struct ggml_cgraph * llm_build_refact( offload_func_v(KQV_merged); ggml_set_name(KQV_merged, "KQV_merged"); - // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); @@ -4338,7 +4356,7 @@ static struct ggml_cgraph * llama_build_graph( } break; case LLM_ARCH_REFACT: { - result = llm_build_refact(lctx, tokens, embd, n_tokens, n_past); + result = llm_build_refact(lctx, batch); } break; default: GGML_ASSERT(false);