Commit ee4fb35

feat: support --tensor-type-rules on generation modes
1 parent 8f6c5c2 commit ee4fb35

5 files changed, +56 -46 lines changed


examples/cli/main.cpp

Lines changed: 1 addition & 4 deletions

@@ -1163,10 +1163,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         exit(1);
     }
 
-    if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) {
-        fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n");
-    }
-
     if (params.mode == VID_GEN && params.video_frames <= 0) {
         fprintf(stderr, "warning: --video-frames must be at least 1\n");
         exit(1);
@@ -1643,6 +1639,7 @@ int main(int argc, const char* argv[]) {
        params.lora_model_dir.c_str(),
        params.embedding_dir.c_str(),
        params.photo_maker_path.c_str(),
+       params.tensor_type_rules.c_str(),
        vae_decode_only,
        true,
        params.n_threads,
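Dropping the warning means --tensor-type-rules now takes effect in generation modes, not only during conversion. A hedged CLI sketch: only --tensor-type-rules is taken from this commit; the other option spellings, the model path, the rule patterns, and the prompt are illustrative and should be checked against the CLI help.

    sd -m model.safetensors --type q4_0 \
       --tensor-type-rules "vae=f16,attn=q8_0" \
       -p "a lighthouse at dusk" -o output.png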

model.cpp

Lines changed: 48 additions & 39 deletions

@@ -1877,15 +1877,59 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
     return wtype_stat;
 }
 
-void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
+static std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
+    std::vector<std::pair<std::string, ggml_type>> result;
+    for (const auto& item : split_string(tensor_type_rules, ',')) {
+        if (item.size() == 0)
+            continue;
+        std::string::size_type pos = item.find('=');
+        if (pos == std::string::npos) {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+            continue;
+        }
+        std::string tensor_pattern = item.substr(0, pos);
+        std::string type_name = item.substr(pos + 1);
+
+        ggml_type tensor_type = GGML_TYPE_COUNT;
+
+        if (type_name == "f32") {
+            tensor_type = GGML_TYPE_F32;
+        } else {
+            for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
+                auto trait = ggml_get_type_traits((ggml_type)i);
+                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
+                    tensor_type = (ggml_type)i;
+                }
+            }
+        }
+
+        if (tensor_type != GGML_TYPE_COUNT) {
+            result.emplace_back(tensor_pattern, tensor_type);
+        } else {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+        }
+    }
+    return result;
+}
+
+void ModelLoader::set_wtype_override(ggml_type wtype, std::string tensor_type_rules) {
+    auto map_rules = parse_tensor_type_rules(tensor_type_rules);
     for (auto& [name, tensor_storage] : tensor_storage_map) {
-        if (!starts_with(name, prefix)) {
+        ggml_type dst_type = wtype;
+        for (const auto& tensor_type_rule : map_rules) {
+            std::regex pattern(tensor_type_rule.first);
+            if (std::regex_search(name, pattern)) {
+                dst_type = tensor_type_rule.second;
+                break;
+            }
+        }
+        if (dst_type == GGML_TYPE_COUNT) {
             continue;
         }
-        if (!tensor_should_be_converted(tensor_storage, wtype)) {
+        if (!tensor_should_be_converted(tensor_storage, dst_type)) {
            continue;
        }
-        tensor_storage.expected_type = wtype;
+        tensor_storage.expected_type = dst_type;
    }
 }
 
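The parser turns a comma-separated list of pattern=type items into (regex, ggml_type) pairs and warns about anything it cannot parse. A minimal standalone sketch of that syntax, assuming nothing from ggml: the rule string below is illustrative, and type names are kept as plain strings instead of being resolved through ggml_get_type_traits.

// Standalone sketch of the "pattern=type,pattern=type" rule syntax (illustrative only).
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main() {
    std::string rules = "vae=f16,attn=q8_0,not_a_rule";  // example input, not from the commit
    std::vector<std::pair<std::string, std::string>> parsed;

    size_t start = 0;
    for (;;) {
        size_t end       = rules.find(',', start);
        std::string item = rules.substr(start, end == std::string::npos ? std::string::npos : end - start);
        size_t eq        = item.find('=');
        if (!item.empty() && eq == std::string::npos) {
            std::fprintf(stderr, "ignoring invalid quant override \"%s\"\n", item.c_str());
        } else if (!item.empty()) {
            parsed.emplace_back(item.substr(0, eq), item.substr(eq + 1));  // pattern -> type name
        }
        if (end == std::string::npos)
            break;
        start = end + 1;
    }

    for (const auto& rule : parsed)
        std::printf("pattern \"%s\" -> type \"%s\"\n", rule.first.c_str(), rule.second.c_str());
    return 0;
}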

@@ -2226,41 +2270,6 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
     return true;
 }
 
-std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
-    std::vector<std::pair<std::string, ggml_type>> result;
-    for (const auto& item : split_string(tensor_type_rules, ',')) {
-        if (item.size() == 0)
-            continue;
-        std::string::size_type pos = item.find('=');
-        if (pos == std::string::npos) {
-            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
-            continue;
-        }
-        std::string tensor_pattern = item.substr(0, pos);
-        std::string type_name = item.substr(pos + 1);
-
-        ggml_type tensor_type = GGML_TYPE_COUNT;
-
-        if (type_name == "f32") {
-            tensor_type = GGML_TYPE_F32;
-        } else {
-            for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
-                auto trait = ggml_get_type_traits((ggml_type)i);
-                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
-                    tensor_type = (ggml_type)i;
-                }
-            }
-        }
-
-        if (tensor_type != GGML_TYPE_COUNT) {
-            result.emplace_back(tensor_pattern, tensor_type);
-        } else {
-            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
-        }
-    }
-    return result;
-}
-
 bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
     const std::string& name = tensor_storage.name;
     if (type != GGML_TYPE_COUNT) {
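Note the semantic shift in this commit: the old set_wtype_override filtered tensors by a name prefix, while the moved parse_tensor_type_rules output is now used as std::regex patterns searched anywhere in the tensor name. A small sketch of that matching, with made-up tensor names:

// Sketch: a rule's pattern half is regex-searched against each tensor name.
// The names and the pattern below are illustrative, not taken from a real model.
#include <cstdio>
#include <regex>
#include <string>

int main() {
    const std::string names[] = {
        "first_stage_model.decoder.conv_in.weight",       // would be caught by a "first_stage_model=f16" rule
        "model.diffusion_model.input_blocks.0.0.weight",  // would keep the global wtype (or stay untouched)
    };
    std::regex rule_pattern("first_stage_model");  // pattern half of "first_stage_model=f16"
    for (const auto& name : names) {
        bool hit = std::regex_search(name, rule_pattern);
        std::printf("%-48s %s\n", name.c_str(), hit ? "match" : "no match");
    }
    return 0;
}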

model.h

Lines changed: 1 addition & 1 deletion

@@ -281,7 +281,7 @@ class ModelLoader {
     std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
     std::map<ggml_type, uint32_t> get_vae_wtype_stat();
     String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
-    void set_wtype_override(ggml_type wtype, std::string prefix = "");
+    void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
     bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
     bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
                       std::set<std::string> ignore_tensors = {},

stable-diffusion.cpp

Lines changed: 5 additions & 2 deletions

@@ -286,8 +286,9 @@ class StableDiffusionGGML {
         ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
                               ? (ggml_type)sd_ctx_params->wtype
                               : GGML_TYPE_COUNT;
-        if (wtype != GGML_TYPE_COUNT) {
-            model_loader.set_wtype_override(wtype);
+        std::string tensor_type_rules = SAFE_STR(sd_ctx_params->tensor_type_rules);
+        if (wtype != GGML_TYPE_COUNT || tensor_type_rules.size() > 0) {
+            model_loader.set_wtype_override(wtype, tensor_type_rules);
         }
 
         std::map<ggml_type, uint32_t> wtype_stat = model_loader.get_wtype_stat();
@@ -1893,6 +1894,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
            "lora_model_dir: %s\n"
            "embedding_dir: %s\n"
            "photo_maker_path: %s\n"
+           "tensor_type_rules: %s\n"
            "vae_decode_only: %s\n"
            "free_params_immediately: %s\n"
            "n_threads: %d\n"
@@ -1922,6 +1924,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
            SAFE_STR(sd_ctx_params->lora_model_dir),
            SAFE_STR(sd_ctx_params->embedding_dir),
            SAFE_STR(sd_ctx_params->photo_maker_path),
+           SAFE_STR(sd_ctx_params->tensor_type_rules),
            BOOL_STR(sd_ctx_params->vae_decode_only),
            BOOL_STR(sd_ctx_params->free_params_immediately),
            sd_ctx_params->n_threads,

stable-diffusion.h

Lines changed: 1 addition & 0 deletions

@@ -151,6 +151,7 @@ typedef struct {
     const char* lora_model_dir;
     const char* embedding_dir;
     const char* photo_maker_path;
+    const char* tensor_type_rules;
     bool vae_decode_only;
     bool free_params_immediately;
     int n_threads;
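With the field exposed in the public context parameters, a caller can pass rules without going through the CLI. A hedged usage sketch: only tensor_type_rules comes from this commit; new_sd_ctx, free_sd_ctx, the SD_TYPE_Q8_0 value, and the rule string are assumptions about the surrounding API and are illustrative only.

// Assumed usage of the public header; verify helper and field names against stable-diffusion.h.
#include "stable-diffusion.h"

int main() {
    sd_ctx_params_t params = {};
    // ... model path, thread count, and the other fields shown above are omitted here ...
    params.wtype             = SD_TYPE_Q8_0;         // assumed global weight-type override
    params.tensor_type_rules = "vae=f16,attn=q8_0";  // illustrative "regex=type" rules, applied per tensor

    sd_ctx_t* ctx = new_sd_ctx(&params);             // assumed context-creation helper
    if (ctx == NULL) {
        return 1;
    }
    // ... run generation as usual ...
    free_sd_ctx(ctx);                                // assumed matching cleanup helper
    return 0;
}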
