4-bit Integer quantisation #27

Merged · 38 commits · Mar 29, 2023

Commits
8f45628
gq : attempt at n-bit quantization
ggerganov Feb 21, 2023
b0a46fd
gq : add amax based method 3
ggerganov Feb 22, 2023
da2de94
gq : progress on method 2
ggerganov Feb 22, 2023
aa5506c
gq : method 4 (AVX2)
ggerganov Feb 23, 2023
1fc11de
gq : method 4 (ARM)
ggerganov Feb 23, 2023
dae323c
gq : method 4 (AVX2 attempt) + method 5 (no min)
ggerganov Feb 24, 2023
349e917
gq : method 5 (ARM)
ggerganov Feb 24, 2023
ff4c653
gpt-2 : model conversion for Q4_0 quantization
ggerganov Feb 25, 2023
21514b7
ggml : Q4_0 quantization support (ggml_get_rows())
ggerganov Feb 25, 2023
ff54fda
gpt-2 : loading Q4_0 quantized model
ggerganov Feb 25, 2023
2219c11
ggml : q4_0 quantization support
ggerganov Feb 25, 2023
5bdfce2
ggml : q4_1 quantization support (seems to work for bigger models)
ggerganov Feb 25, 2023
b0cab89
gpt-2 : add gpt-2-quantize tool for quantizing f32 GPT-2 models
ggerganov Feb 25, 2023
2f11888
ggml : 4-bit quantization works (only scalar for now)
ggerganov Feb 25, 2023
b82c27f
gq : add method 6 (ARM)
ggerganov Feb 25, 2023
2e75d8f
ggml : vectorized mad q4_0 (ARM)
ggerganov Feb 25, 2023
b0c22a4
ggml : vectorized quantize_row_q4_0 (ARM)
ggerganov Feb 26, 2023
3c757a4
ggml : simplify mad q4_0 (ARM)
ggerganov Feb 26, 2023
e3ad879
ggml : minor indentations
ggerganov Feb 26, 2023
8abcab4
gpt-j : support for 4-bit quantized model inference
ggerganov Feb 26, 2023
c21972c
ggml : GGML_ASSERT() instead of assert() where appropriate
ggerganov Feb 26, 2023
4a56c5b
gpt : avoid ggml_transpose on model tensors (new models!)
ggerganov Feb 26, 2023
904605c
gpt-2 : minor
ggerganov Feb 26, 2023
441a38f
gpt-j : fix conversion for FP16 models (such as GPT-JT-6B)
ggerganov Feb 26, 2023
5336828
ggml : add ggml_compute_forward_rope_f16()
ggerganov Feb 26, 2023
99af48e
gpt : fix memory usage computation
ggerganov Feb 26, 2023
6aae09e
ggml : fix ggml_is_contiguous() to take into account blck size
ggerganov Feb 26, 2023
d0ac5eb
whisper : add whisper-quantize tool
ggerganov Feb 26, 2023
37d427d
whisper : add support for quantized models
ggerganov Feb 26, 2023
e904a58
whisper : mem usage based on model format type
ggerganov Feb 26, 2023
63a8f62
gpt : seems not worth to use FP16 for KV cache
ggerganov Feb 26, 2023
519ce47
gpt : support quantisation of f16 models files
ggerganov Feb 26, 2023
9881c2b
ggml : fixes for rpi4
ggerganov Feb 26, 2023
a85bc0f
whisper : add Q4_1 model sizes
ggerganov Feb 26, 2023
c4f1403
ggml : add WASM SIMD for Q4_0
ggerganov Feb 27, 2023
331a862
utils : print quantization histograms
ggerganov Mar 6, 2023
154fcc3
ggml : sync all changes from llama.cpp and whisper.cpp
ggerganov Mar 29, 2023
724c45d
ggml : finalize the Q4_1 quantization for ARM_NEON
ggerganov Mar 29, 2023
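
The commits above converge on two block-wise 4-bit formats: Q4_0, which stores one scale per block of weights, and Q4_1, which stores a scale and a minimum per block. As a rough reference for the round-trip these formats perform, here is a NumPy sketch; the block size of 32, the rounding, and the unpacked nibble layout are simplifying assumptions, not code taken from this PR (ggml packs two 4-bit values per byte and has vectorized AVX2/NEON/WASM paths).

```python
import numpy as np

QK = 32  # assumed elements per quantization block

def quantize_q4_0(x):
    """Scale-only 4-bit quantization: one float scale per block."""
    x = x.reshape(-1, QK)
    amax = np.abs(x).max(axis=1, keepdims=True)
    d = amax / 7.0                                   # map [-amax, amax] onto [-7, 7]
    inv_d = np.where(d > 0, 1.0 / d, 0.0)
    q = np.clip(np.round(x * inv_d) + 8, 0, 15).astype(np.uint8)  # unsigned nibbles
    return d.astype(np.float32), q

def dequantize_q4_0(d, q):
    return (q.astype(np.float32) - 8) * d

def quantize_q4_1(x):
    """Scale + minimum 4-bit quantization: handles asymmetric value ranges better."""
    x = x.reshape(-1, QK)
    mn = x.min(axis=1, keepdims=True)
    d = (x.max(axis=1, keepdims=True) - mn) / 15.0
    inv_d = np.where(d > 0, 1.0 / d, 0.0)
    q = np.clip(np.round((x - mn) * inv_d), 0, 15).astype(np.uint8)
    return d.astype(np.float32), mn.astype(np.float32), q

def dequantize_q4_1(d, mn, q):
    return q.astype(np.float32) * d + mn

# round-trip error on random weights
w = np.random.randn(4096).astype(np.float32)
d, q = quantize_q4_0(w)
print("q4_0 rmse:", np.sqrt(np.mean((dequantize_q4_0(d, q).ravel() - w) ** 2)))
```

Q4_1 spends an extra float per block on the minimum, which generally helps tensors whose values are not centered around zero.
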
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -31,7 +31,7 @@ option(GGML_NO_ACCELERATE "ggml: disable Accelerate framework" OFF)
# sanitizers

if (GGML_SANITIZE_THREAD)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=thread")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=thread")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread")
endif()

7 changes: 7 additions & 0 deletions examples/gpt-2/CMakeLists.txt
@@ -4,3 +4,10 @@
set(TEST_TARGET gpt-2)
add_executable(${TEST_TARGET} main.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)

#
# gpt-2-quantize

set(TEST_TARGET gpt-2-quantize)
add_executable(${TEST_TARGET} quantize.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)
2 changes: 1 addition & 1 deletion examples/gpt-2/README.md
@@ -94,7 +94,7 @@ Done! Model '117M' saved in 'models/gpt-2-117M/'

Run the convert-ckpt-to-ggml.py script to convert the model to ggml format.

python /Users/john/ggml/examples/gpt-2/convert-ckpt-to-ggml.py models/gpt-2-117M/
python /Users/john/ggml/examples/gpt-2/convert-ckpt-to-ggml.py models/gpt-2-117M/ 1

```

69 changes: 48 additions & 21 deletions examples/gpt-2/convert-ckpt-to-ggml.py
@@ -45,8 +45,18 @@ def bytes_to_unicode():
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))

if len(sys.argv) < 2:
print("Usage: convert-ckpt-to-ggml.py dir-model [use-f32]\n")
# helper method to convert a numpy array to different float types
def convert_to_ftype(data, ftype):
# fp16
if ftype == 1:
return data.astype(np.float16)

assert False, "Invalid ftype: " + str(ftype)

if len(sys.argv) < 3:
print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n")
print(" ftype == 0 -> float32")
print(" ftype == 1 -> float16")
sys.exit(1)

# output in the same directory as the model
@@ -59,11 +69,20 @@ def bytes_to_unicode():
with open(dir_model + "/hparams.json", "r") as f:
hparams = json.load(f)

# use 16-bit or 32-bit floats
use_f16 = True
# possible data types
# ftype == 0 -> float32
# ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
use_f16 = False
fname_out = sys.argv[1] + "/ggml-model-f32.bin"
ftype = int(sys.argv[2])
if ftype < 0 or ftype > 1:
print("Invalid ftype: " + str(ftype))
sys.exit(1)
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"

list_vars = tf.train.list_variables(dir_model)

@@ -75,7 +94,7 @@ def bytes_to_unicode():
fout.write(struct.pack("i", hparams["n_embd"]))
fout.write(struct.pack("i", hparams["n_head"]))
fout.write(struct.pack("i", hparams["n_layer"]))
fout.write(struct.pack("i", use_f16))
fout.write(struct.pack("i", ftype))

byte_encoder = bytes_to_unicode()
byte_decoder = {v:k for k, v in byte_encoder.items()}
@@ -93,34 +112,42 @@ def bytes_to_unicode():
data = tf.train.load_variable(dir_model, name).squeeze()
n_dims = len(data.shape);

# ftype == 0 -> float32, ftype == 1 -> float16
ftype = 0;
if use_f16:
# for efficiency - transpose the projection matrices
# "model/h.*/attn/c_attn/w"
# "model/h.*/attn/c_proj/w"
# "model/h.*/mlp/c_fc/w"
# "model/h.*/mlp/c_proj/w"
if name[-14:] == "/attn/c_attn/w" or \
name[-14:] == "/attn/c_proj/w" or \
name[-11:] == "/mlp/c_fc/w" or \
name[-13:] == "/mlp/c_proj/w":
print(" Transposing")
data = data.transpose()

dshape = data.shape

ftype_cur = 0
if ftype != 0:
# match name:
# "model/wte"
# "model/h.*/attn/c_attn/w"
# "model/h.*/attn/c_proj/w"
# "model/h.*/mlp/c_fc/w"
# "model/h.*/mlp/c_proj/w"
if name == "model/wte" or name[-2:] == "/w":
print(" Converting to float16")
data = data.astype(np.float16)
ftype = 1
print(" Converting to " + ftype_str[ftype])
data = convert_to_ftype(data, ftype)
ftype_cur = ftype
else:
print(" Converting to float32")
data = data.astype(np.float32)
ftype = 0

# for efficiency - transpose the projection matrices
if name[-13:] == "/mlp/c_proj/w":
print(" Transposing")
data = data.transpose()
ftype_cur = 0

# header
str = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(str), ftype))
fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
fout.write(str);

# data
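
With these changes the converter records a per-tensor ftype (0 = f32, 1 = f16) in the "iii" record header instead of deriving everything from a single use_f16 flag. For orientation, a minimal sketch of the matching read side, assuming the file position is already past the hparams/vocab prologue that this hunk does not show; read_tensor is a hypothetical helper written for illustration, not a function from this repository.

```python
import struct
import numpy as np

def read_tensor(f):
    # one tensor record as written by convert-ckpt-to-ggml.py:
    # an "iii" header (n_dims, name_len, ftype), the dims in reverse order,
    # the utf-8 name, then the raw weight data
    header = f.read(12)
    if len(header) < 12:
        return None  # end of file
    n_dims, name_len, ftype = struct.unpack("iii", header)
    dims = struct.unpack("i" * n_dims, f.read(4 * n_dims))[::-1]
    name = f.read(name_len).decode("utf-8")
    dtype = np.float32 if ftype == 0 else np.float16  # ftype: 0 = f32, 1 = f16
    count = int(np.prod(dims))
    data = np.frombuffer(f.read(count * np.dtype(dtype).itemsize), dtype=dtype)
    return name, data.reshape(dims)
```
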
132 changes: 84 additions & 48 deletions examples/gpt-2/main.cpp
@@ -42,7 +42,7 @@ struct gpt2_layer {
struct ggml_tensor * c_mlp_fc_w;
struct ggml_tensor * c_mlp_fc_b;

struct ggml_tensor * c_mlp_proj_w_trans; // transposed for efficiency
struct ggml_tensor * c_mlp_proj_w;
struct ggml_tensor * c_mlp_proj_b;
};

@@ -130,9 +130,23 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
}
}

// for the big tensors, we have the option to store the data in 16-bit floats
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
// in order to save memory and also to speed up the computation
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
ggml_type wtype = GGML_TYPE_COUNT;
switch (model.hparams.f16) {
case 0: wtype = GGML_TYPE_F32; break;
case 1: wtype = GGML_TYPE_F16; break;
case 2: wtype = GGML_TYPE_Q4_0; break;
case 3: wtype = GGML_TYPE_Q4_1; break;
default:
{
fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
__func__, fname.c_str(), model.hparams.f16);
return false;
}
}

const ggml_type wtype2 = GGML_TYPE_F32;

auto & ctx = model.ctx;

@@ -146,32 +160,32 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;

ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b

ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe

ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b

ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b

ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b

ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b

ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b

ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b

ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F32); // memory_v
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v

ctx_size += (6 + 12*n_layer)*256; // object overhead

@@ -219,23 +233,23 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];

layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, 3*n_embd, n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);

layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);

layer.c_mlp_proj_w_trans = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

// map by name
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
@@ -253,7 +267,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;

model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w_trans;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
}
}
@@ -321,17 +335,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
return false;
}

const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
if (0) {
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}

if (nelements*bpe != ggml_nbytes(tensor)) {
size_t bpe = 0;

switch (ftype) {
case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
default:
{
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
return false;
}
};

if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}

fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));

//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
total_size += ggml_nbytes(tensor);
}

@@ -433,7 +463,7 @@ bool gpt2_eval(
// [2304, N]
{
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_attn_attn_w),
model.layers[il].c_attn_attn_w,
cur);

cur = ggml_add(ctx0,
@@ -509,11 +539,13 @@ bool gpt2_eval(
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12]
struct ggml_tensor * V_trans =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3);
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));

// KQV = transpose(V) * KQ_soft_max
// [64, N, 12]
@@ -540,7 +572,7 @@ bool gpt2_eval(
// [768, N]
{
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_attn_proj_w),
model.layers[il].c_attn_proj_w,
cur);

cur = ggml_add(ctx0,
@@ -577,7 +609,7 @@ bool gpt2_eval(
// cur = fc_w*cur + fc_b
// [3072, N]
cur = ggml_mul_mat(ctx0,
ggml_transpose(ctx0, model.layers[il].c_mlp_fc_w),
model.layers[il].c_mlp_fc_w,
cur);

cur = ggml_add(ctx0,
@@ -597,7 +629,7 @@ bool gpt2_eval(
// cur = proj_w*cur + proj_b
// [768, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_proj_w_trans,
model.layers[il].c_mlp_proj_w,
cur);

cur = ggml_add(ctx0,
@@ -714,8 +746,12 @@ int main(int argc, char ** argv) {

params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());

printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
printf("\n");
printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
printf("%s: number of tokens in prompt = %zu, first 8 tokens: ", __func__, embd_inp.size());
for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) {
printf("%d ", embd_inp[i]);
}
printf("\n\n");

// submit the input prompt token-by-token
// this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
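
One arithmetic detail behind the main.cpp changes: the context-size estimate switches from ggml_type_size() to ggml_type_sizef(), and the per-tensor size check divides by ggml_blck_size(), because a quantized type stores a whole block of elements (scale, optional minimum, packed nibbles) as one unit whose size is not a whole number of bytes per element. A small model of that arithmetic, with per-block constants that are illustrative assumptions rather than ggml's exact values:

```python
# (bytes per block, elements per block) for each type; illustrative values only
TYPE_INFO = {
    "f32":  (4, 1),
    "f16":  (2, 1),
    "q4_0": (4 + 16, 32),      # assumed: one f32 scale + 32 packed 4-bit values
    "q4_1": (4 + 4 + 16, 32),  # assumed: scale + min + 32 packed 4-bit values
}

def type_sizef(t: str) -> float:
    """Average bytes per element, analogous to ggml_type_sizef()."""
    block_bytes, block_elems = TYPE_INFO[t]
    return block_bytes / block_elems

def tensor_nbytes(nelements: int, t: str) -> int:
    """Total bytes for a tensor, analogous to the (nelements*bpe)/blck_size check."""
    block_bytes, block_elems = TYPE_INFO[t]
    assert nelements % block_elems == 0, "row length must be a multiple of the block size"
    return (nelements // block_elems) * block_bytes

# a 768x768 projection matrix: q4_0 vs f16
print(tensor_nbytes(768 * 768, "q4_0"), "bytes vs", tensor_nbytes(768 * 768, "f16"), "bytes")
```

This is also why the loader only accepts Q4_0/Q4_1 tensors whose leading dimension divides evenly into blocks, as enforced by the ne[0] assertions above.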