@@ -1877,15 +1877,59 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
     return wtype_stat;
 }

-void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
+static std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
+    std::vector<std::pair<std::string, ggml_type>> result;
+    for (const auto& item : split_string(tensor_type_rules, ',')) {
+        if (item.size() == 0)
+            continue;
+        std::string::size_type pos = item.find('=');
+        if (pos == std::string::npos) {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+            continue;
+        }
+        std::string tensor_pattern = item.substr(0, pos);
+        std::string type_name = item.substr(pos + 1);
+
+        ggml_type tensor_type = GGML_TYPE_COUNT;
+
+        if (type_name == "f32") {
+            tensor_type = GGML_TYPE_F32;
+        } else {
+            for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
+                auto trait = ggml_get_type_traits((ggml_type)i);
+                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
+                    tensor_type = (ggml_type)i;
+                }
+            }
+        }
+
+        if (tensor_type != GGML_TYPE_COUNT) {
+            result.emplace_back(tensor_pattern, tensor_type);
+        } else {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+        }
+    }
+    return result;
+}
+
+void ModelLoader::set_wtype_override(ggml_type wtype, std::string tensor_type_rules) {
+    auto map_rules = parse_tensor_type_rules(tensor_type_rules);
     for (auto& [name, tensor_storage] : tensor_storage_map) {
-        if (!starts_with(name, prefix)) {
+        ggml_type dst_type = wtype;
+        for (const auto& tensor_type_rule : map_rules) {
+            std::regex pattern(tensor_type_rule.first);
+            if (std::regex_search(name, pattern)) {
+                dst_type = tensor_type_rule.second;
+                break;
+            }
+        }
+        if (dst_type == GGML_TYPE_COUNT) {
             continue;
         }
-        if (!tensor_should_be_converted(tensor_storage, wtype)) {
+        if (!tensor_should_be_converted(tensor_storage, dst_type)) {
             continue;
         }
-        tensor_storage.expected_type = wtype;
+        tensor_storage.expected_type = dst_type;
     }
 }

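Note (not part of the diff): the rules string consumed by parse_tensor_type_rules is a comma-separated list of pattern=type entries, where pattern is used as a std::regex searched against each tensor name and type is a ggml type name such as f16 or q8_0 (f32 is special-cased). A minimal usage sketch follows, assuming the loader's tensor_storage_map has already been populated; the file name, rule patterns, and the init_from_file call are illustrative assumptions, not values taken from this commit.

    ModelLoader loader;
    // Assumed setup step: loading a model file fills tensor_storage_map.
    loader.init_from_file("model.safetensors");

    // Default convertible tensors to q8_0, keep tensors whose names match
    // "attn" in f16, and force VAE ("first_stage_model") tensors to f32.
    // The first rule whose regex matches a tensor name wins, because
    // set_wtype_override breaks out of the rule loop on the first match.
    loader.set_wtype_override(GGML_TYPE_Q8_0, "attn=f16,first_stage_model=f32");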
@@ -2226,41 +2270,6 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
     return true;
 }

-std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
-    std::vector<std::pair<std::string, ggml_type>> result;
-    for (const auto& item : split_string(tensor_type_rules, ',')) {
-        if (item.size() == 0)
-            continue;
-        std::string::size_type pos = item.find('=');
-        if (pos == std::string::npos) {
-            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
-            continue;
-        }
-        std::string tensor_pattern = item.substr(0, pos);
-        std::string type_name = item.substr(pos + 1);
-
-        ggml_type tensor_type = GGML_TYPE_COUNT;
-
-        if (type_name == "f32") {
-            tensor_type = GGML_TYPE_F32;
-        } else {
-            for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
-                auto trait = ggml_get_type_traits((ggml_type)i);
-                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
-                    tensor_type = (ggml_type)i;
-                }
-            }
-        }
-
-        if (tensor_type != GGML_TYPE_COUNT) {
-            result.emplace_back(tensor_pattern, tensor_type);
-        } else {
-            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
-        }
-    }
-    return result;
-}
-
 bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
     const std::string& name = tensor_storage.name;
     if (type != GGML_TYPE_COUNT) {