quantize: improve pattern matching for allowed tensors #13033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
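For context: the `--tensor-type` option takes `NAME=TYPE` pairs (for example `--tensor-type attn_v=q8_0`) to pin tensors whose names match a pattern to a specific quantization type. This PR removes the parse-time whitelist from `quantize.cpp` and moves both the allowed-tensor list and the check into `llama-quant.h`/`llama-quant.cpp`, so validation happens next to the pattern matching at quantization time.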
82 changes: 7 additions & 75 deletions examples/quantize/quantize.cpp
@@ -1,5 +1,6 @@
 #include "common.h"
 #include "llama.h"
+#include "llama-quant.h"

 #include <cstdio>
 #include <cstring>
@@ -248,52 +249,6 @@ static ggml_type parse_ggml_type(const char * arg) {
     return GGML_TYPE_COUNT;
 }

-// Allowed tensors for arbitrary quantization with --tensor-type option
-static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
-    "attn_k",
-    "attn_kv_a_mqa",
-    "attn_kv_b",
-    "attn_o",
-    "attn_output",
-    "attn_q",
-    "attn_q_a",
-    "attn_q_b",
-    "attn_qkv",
-    "attn_v",
-    "channel_mix_key",
-    "channel_mix_receptance",
-    "channel_mix_value",
-    "cls",
-    "cls.output",
-    "cross_attn_k",
-    "cross_attn_o",
-    "cross_attn_q",
-    "cross_attn_v",
-    "ffn_act",
-    "ffn_down",
-    "ffn_down_exps",
-    "ffn_down_shexp",
-    "ffn_gate",
-    "ffn_gate_exps",
-    "ffn_gate_shexp",
-    "ffn_up",
-    "ffn_up_exps",
-    "ffn_up_shexp",
-    "ssm_in",
-    "ssm_out",
-    "time_mix_gate",
-    "time_mix_key",
-    "time_mix_output",
-    "time_mix_receptance",
-    "time_mix_value",
-};
-
-// changes to this struct must be replicated in llama-quant.cpp
-struct tensor_quantization {
-    std::string name;
-    ggml_type quant = GGML_TYPE_COUNT;
-};
-
 static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
     const char * sep = strchr(data, '=');
     if (sep == nullptr) {
@@ -306,7 +261,6 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
         printf("\n%s: missing tensor name\n\n", __func__);
         return false;
     }
-
     if (const size_t qt_len = strlen(sep); qt_len == 1) {
         printf("\n%s: missing quantization type\n\n", __func__);
         return false;
@@ -315,37 +269,15 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
     std::string tn(data, tn_len);
     std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
     sep++;
-    const std::string qt(sep);
-
-    bool found = false;
-    for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
-        std::string tensor;
-        tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn;
-        // handle special case of cls.output
-        std::string cls_output = "cls.output";
-        if (tn.find(cls_output) != std::string::npos) {
-            tensor = "cls.output";
-        }
-        // check if an allowed tensor exists and it's at the end of the kv string
-        if (tensor == allowed) {
-            found = true;
-            break;
-        }
-    }
-    if (!found) {
-        printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
-        return false;
-    }
-
-    if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) {
-        printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str());
-        return false;
-    }

     tensor_quantization tqz;
     tqz.name = tn;
-    tqz.quant = parse_ggml_type(qt.c_str());
+    tqz.quant = parse_ggml_type(sep);
     tensor_type.emplace_back(std::move(tqz));
+    if (tqz.quant == GGML_TYPE_COUNT) {
+        printf("\n%s: invalid quantization type '%s'\n\n", __func__, sep);
+        return false;
+    }

     return true;
 }

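With the whitelist gone from the parser, `parse_tensor_type()` now only splits `NAME=TYPE` and rejects unknown ggml types; any string is accepted as a name pattern. A minimal standalone sketch of the resulting parse path (the `ggml_type` enum and `parse_type()` below are toy stand-ins for the real ggml definitions, not the library API):

```cpp
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

// Toy stand-ins for ggml; illustrative only.
enum ggml_type { GGML_TYPE_Q4_K, GGML_TYPE_Q8_0, GGML_TYPE_COUNT };

struct tensor_quantization {
    std::string name;
    ggml_type quant = GGML_TYPE_COUNT;
};

// Toy replacement for parse_ggml_type(); the real one knows many more names.
static ggml_type parse_type(const char * s) {
    if (strcmp(s, "q4_k") == 0) return GGML_TYPE_Q4_K;
    if (strcmp(s, "q8_0") == 0) return GGML_TYPE_Q8_0;
    return GGML_TYPE_COUNT;
}

// Mirrors the post-PR parse: split on '=', lowercase the pattern, and only
// reject an unknown quantization type -- the tensor name is no longer
// validated against a whitelist at parse time.
static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & out) {
    const char * sep = strchr(data, '=');
    if (sep == nullptr || sep == data || sep[1] == '\0') {
        printf("expected <name>=<type>\n");
        return false;
    }
    std::string tn(data, sep - data);
    std::transform(tn.begin(), tn.end(), tn.begin(),
                   [](unsigned char c) { return std::tolower(c); });
    tensor_quantization tqz;
    tqz.name  = tn;
    tqz.quant = parse_type(sep + 1);
    if (tqz.quant == GGML_TYPE_COUNT) {
        printf("invalid quantization type '%s'\n", sep + 1);
        return false;
    }
    out.emplace_back(std::move(tqz));
    return true;
}
```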
24 changes: 13 additions & 11 deletions src/llama-quant.cpp
@@ -48,12 +48,6 @@ struct quantize_state_impl {
     {}
 };

-// changes to this struct must be replicated in quantize.cpp
-struct tensor_quantization {
-    std::string name;
-    ggml_type quant = GGML_TYPE_COUNT;
-};
-
 static void llama_tensor_dequantize_impl(
     ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
     const size_t nelements, const int nthread
@@ -796,17 +790,25 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // unless the user specifies a type
         if (params->tensor_types) {
             const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
+            const std::string tensor_name(tensor->name);
             for (const auto & [tname, qtype] : tensor_types) {
-                if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
-                    if (qtype != new_type) {
-                        LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
+                if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
+                    for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
+                        if (tensor_name.find(allowed) != std::string::npos) {
+                            if (qtype != new_type) {
+                                LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
+                                new_type = qtype;
+                                break;
+                            }
+                        }
                     }
-                    new_type = qtype;
-                    break;
+                    goto loop_exit; // if two or more types are specified for the tensor, first match wins
                 }
             }
+        loop_exit:;
         }

         if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
             new_type = params->token_embedding_type;
         }
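The override logic above applies the first rule whose regex matches the tensor name, and only overrides if the name also contains one of the allowed stems; a regex match on a non-allowed tensor still ends the search (that is what the `goto loop_exit` implements). A self-contained sketch of that selection order, with a trimmed allowed list and plain strings standing in for `ggml_type` values (both assumptions):

```cpp
#include <cstdio>
#include <regex>
#include <string>
#include <utility>
#include <vector>

// Trimmed stand-in for ALLOWED_TENSOR_TYPE in llama-quant.h.
static const std::vector<std::string> ALLOWED = { "attn_q", "attn_v", "ffn_down" };

// Returns the override type for tensor_name, or fallback if no rule applies.
// First regex that matches wins; once a pattern matches, later rules are
// never consulted, even when the allowed-stem check then fails.
static std::string pick_type(const std::string & tensor_name,
                             const std::vector<std::pair<std::string, std::string>> & rules,
                             const std::string & fallback) {
    for (const auto & [pattern, qtype] : rules) {
        if (std::regex_search(tensor_name, std::regex(pattern))) {
            for (const auto & allowed : ALLOWED) {
                if (tensor_name.find(allowed) != std::string::npos) {
                    return qtype;  // override applies
                }
            }
            return fallback;       // matched, but not an allowed tensor
        }
    }
    return fallback;               // no rule matched
}

int main() {
    const std::vector<std::pair<std::string, std::string>> rules = {
        { "blk\\.[0-3]\\.ffn_down", "q6_k" }, // layers 0-3 only
        { "ffn_down",               "q5_k" }, // any other ffn_down
    };
    printf("%s\n", pick_type("blk.2.ffn_down.weight",  rules, "q4_k").c_str()); // q6_k
    printf("%s\n", pick_type("blk.10.ffn_down.weight", rules, "q4_k").c_str()); // q5_k
    printf("%s\n", pick_type("output_norm.weight",     rules, "q4_k").c_str()); // q4_k
}
```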
65 changes: 65 additions & 0 deletions src/llama-quant.h
@@ -1 +1,66 @@
 #pragma once
+
+#include <string>
+#include <vector>
+
+#include "ggml.h"
+
+// Allowed tensors for arbitrary quantization with --tensor-type option
+static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
+    "attn_k",
+    "attn_k_b",
+    "attn_kv_a_mqa",
+    "attn_kv_b",
+    "attn_o",
+    "attn_output",
+    "attn_q",
+    "attn_q_a",
+    "attn_q_b",
+    "attn_qkv",
+    "attn_rel_b",
+    "attn_v",
+    "attn_v_b",
+    "channel_mix_key",
+    "channel_mix_receptance",
+    "channel_mix_value",
+    "cls",
+    "cls.output",
+    "conv1",
+    "conv1d",
+    "conv2",
+    "cross_attn_k",
+    "cross_attn_o",
+    "cross_attn_q",
+    "cross_attn_rel_b",
+    "cross_attn_v",
+    "dw",
+    "ffn_down",
+    "ffn_down_exps",
+    "ffn_down_shexp",
+    "ffn_gate",
+    "ffn_gate_exps",
+    "ffn_gate_shexp",
+    "ffn_up",
+    "ffn_up_exps",
+    "ffn_up_shexp",
+    "pw1",
+    "pw2",
+    "ssm_a",
+    "ssm_conv1d",
+    "ssm_dt",
+    "ssm_in",
+    "ssm_out",
+    "ssm_x",
+    "time_mix_gate",
+    "time_mix_key",
+    "time_mix_output",
+    "time_mix_receptance",
+    "time_mix_value",
+    "token_types"
+};
+
+// Quantization types
+struct tensor_quantization {
+    std::string name;
+    ggml_type quant = GGML_TYPE_COUNT;
+};
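For reference, a caller assembles these rules as a `std::vector<tensor_quantization>` and hands the vector to the quantizer through the `tensor_types` pointer in the quantize params, which `llama-quant.cpp` casts back to this vector type (see the `static_cast` in the diff above). A hedged sketch, assuming an in-tree build where `ggml.h` and this header are on the include path:

```cpp
#include <string>
#include <vector>

#include "ggml.h"
#include "llama-quant.h"

// Build an override list: each entry pairs a name pattern (treated as a
// regex at quantization time) with a target ggml type.
std::vector<tensor_quantization> make_overrides() {
    std::vector<tensor_quantization> rules;
    rules.push_back({ "ffn_down", GGML_TYPE_Q6_K });
    rules.push_back({ "attn_v",   GGML_TYPE_Q8_0 });
    return rules;
    // a caller would then point the quantize params at this vector,
    // e.g. params.tensor_types = &rules; (field name from llama.h)
}
```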