
Commit 8931c4d

cleanup load

1 parent 0a12ccc commit 8931c4d

3 files changed: +7 additions, -113 deletions

models/baichuan.cpp

Lines changed: 7 additions & 38 deletions
@@ -322,6 +322,13 @@ namespace m1
             conv_v.set_id(id);
         }
 
+        void load(const std::string &path, TensorLoader *loader) override
+        {
+            BaseAttn::load(path, loader);
+            conv_k.load(path + "conv_k", loader);
+            conv_v.load(path + "conv_v", loader);
+        }
+
     public:
         FIR2 conv_k;
         FIR2 conv_v;
@@ -404,44 +411,6 @@ namespace m1
             }
         }
 
-        void load(ModelLoader &loader) override
-        {
-            auto transformer = get_typed_transformer<ModelClass>();
-
-#define LOAD_TENSORS() \
-    loader.read_tensor(layer_prefix + "input_layernorm.weight", layer->input_layernorm.weight); \
-    loader.read_tensor(layer_prefix + "mlp.down_proj.weight", layer->mlp.down_proj.weight); \
-    loader.read_tensor(layer_prefix + "mlp.gate_proj.weight", layer->mlp.gate_proj.weight); \
-    loader.read_tensor(layer_prefix + "mlp.up_proj.weight", layer->mlp.up_proj.weight); \
-    loader.read_tensor(layer_prefix + "post_attention_layernorm.weight", layer->post_attention_layernorm.weight); \
-    loader.read_tensor(layer_prefix + "self_attn.k_proj.weight", layer->attention.k_proj.weight); \
-    loader.read_tensor(layer_prefix + "self_attn.o_proj.weight", layer->attention.o_proj.weight); \
-    loader.read_tensor(layer_prefix + "self_attn.q_proj.weight", layer->attention.q_proj.weight); \
-    loader.read_tensor(layer_prefix + "self_attn.v_proj.weight", layer->attention.v_proj.weight); \
-    loader.read_tensor(layer_prefix + "self_attn.conv_k", layer->attention.conv_k.weight); \
-    loader.read_tensor(layer_prefix + "self_attn.conv_v", layer->attention.conv_v.weight);
-
-            loader.read_tensor("model.embed_tokens.weight", transformer->word_embeddings.weight);
-            for (int i = 0; i < config.num_hidden_layers; i++)
-            {
-                std::string layer_prefix = "model.layers." + std::to_string(layer_ids[i]) + '.';
-                if (is_swa_layer(i))
-                {
-                    auto layer = dynamic_cast<BaiChuanSWABlock8k *>(transformer->get_layer(i));
-                    LOAD_TENSORS();
-                }
-                else
-                {
-                    auto layer = dynamic_cast<BaiChuanFullBlock *>(transformer->get_layer(i));
-                    LOAD_TENSORS();
-                }
-            }
-            loader.read_tensor("model.norm.weight", transformer->final_layernorm.weight);
-            loader.read_tensor("lm_head.weight", dynamic_cast<Linear *>(transformer->lm_head)->weight);
-
-#undef LOAD_TENSORS
-        }
-
     private:
         bool is_swa_layer(int layer_index) const
         {
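
The override added in the first hunk shows the pattern this cleanup moves to: each module loads its own tensors under a path prefix, reusing its base class for the common ones, so the model-level load() removed in the second hunk, which spelled out every tensor name by hand, becomes redundant. Below is a minimal, self-contained sketch of that idea; TensorLoader, FIR2, and the other names are illustrative stand-ins, not the actual chatllm.cpp classes.

// Minimal sketch of path-prefixed, per-module tensor loading (assumed design;
// the types and signatures below are illustrative stand-ins, not chatllm.cpp's).
#include <iostream>
#include <string>

struct TensorLoader
{
    // Stand-in: a real loader would look the tensor up by name and copy its data.
    void read(const std::string &name) { std::cout << "read " << name << "\n"; }
};

struct Linear
{
    // Each module only knows the names of its own tensors, relative to `path`.
    void load(const std::string &path, TensorLoader *loader)
    {
        loader->read(path + "weight");
    }
};

struct FIR2
{
    // The diff loads these filters under the bare "conv_k"/"conv_v" names.
    void load(const std::string &path, TensorLoader *loader)
    {
        loader->read(path);
    }
};

struct BaseAttn
{
    Linear q_proj, k_proj, v_proj, o_proj;
    virtual ~BaseAttn() = default;
    virtual void load(const std::string &path, TensorLoader *loader)
    {
        q_proj.load(path + "q_proj.", loader);
        k_proj.load(path + "k_proj.", loader);
        v_proj.load(path + "v_proj.", loader);
        o_proj.load(path + "o_proj.", loader);
    }
};

// A module with extra tensors overrides load(), reuses the base, then adds its
// own: the same shape as the override added in models/baichuan.cpp.
struct ConvAttn : public BaseAttn
{
    FIR2 conv_k, conv_v;
    void load(const std::string &path, TensorLoader *loader) override
    {
        BaseAttn::load(path, loader);
        conv_k.load(path + "conv_k", loader);
        conv_v.load(path + "conv_v", loader);
    }
};

int main()
{
    TensorLoader loader;
    ConvAttn attn;
    attn.load("model.layers.0.self_attn.", &loader);
    return 0;
}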

models/llama.cpp

Lines changed: 0 additions & 42 deletions
@@ -85,48 +85,6 @@ If a question does not make any sense, or is not factually coherent, explain why
                               config.intermediate_size, num_key_value_heads, head_dim, max_length);
         }
 
-        void load(ModelLoader &loader) override
-        {
-            switch (type_class)
-            {
-            case 1:
-                load0<ModelClass>(loader);
-                break;
-            default:
-                load0<ModelClass2>(loader);
-                break;
-            }
-        }
-
-    protected:
-        template <class T> void load0(ModelLoader &loader)
-        {
-            auto transformer = Base::get_typed_transformer<T>();
-            loader.read_tensor("model.embed_tokens.weight", transformer->word_embeddings.weight);
-            for (int i = 0; i < config.num_hidden_layers; i++)
-            {
-                std::string layer_prefix = "model.layers." + std::to_string(Base::layer_ids[i]) + '.';
-                loader.read_tensor(layer_prefix + "input_layernorm.weight", transformer->layers[i].input_layernorm.weight);
-                loader.read_tensor(layer_prefix + "mlp.down_proj.weight", transformer->layers[i].mlp.down_proj.weight);
-                loader.read_tensor(layer_prefix + "mlp.gate_proj.weight", transformer->layers[i].mlp.gate_proj.weight);
-                loader.read_tensor(layer_prefix + "mlp.up_proj.weight", transformer->layers[i].mlp.up_proj.weight);
-                loader.read_tensor(layer_prefix + "post_attention_layernorm.weight", transformer->layers[i].post_attention_layernorm.weight);
-
-                loader.read_tensor(layer_prefix + "self_attn.k_proj.weight", transformer->layers[i].attention.k_proj.weight);
-                loader.read_tensor(layer_prefix + "self_attn.o_proj.weight", transformer->layers[i].attention.o_proj.weight);
-                loader.read_tensor(layer_prefix + "self_attn.q_proj.weight", transformer->layers[i].attention.q_proj.weight);
-                loader.read_tensor(layer_prefix + "self_attn.v_proj.weight", transformer->layers[i].attention.v_proj.weight);
-            }
-            loader.read_tensor("model.norm.weight", transformer->final_layernorm.weight);
-
-            if (transformer->lm_head)
-                loader.read_tensor("lm_head.weight", dynamic_cast<Linear *>(transformer->lm_head)->weight);
-
-            CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())
-                << "corrupted model weights: " << w_ctx_.get_used_mem() / ggml_tensor_overhead() << " vs "
-                << w_ctx_.get_mem_size() / ggml_tensor_overhead();
-        }
-
     public:
         BaseConfig config;
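
The deleted load0<T>() read each tensor by an explicit name built from the prefix "model.layers." + std::to_string(Base::layer_ids[i]) + ".", then used CHATLLM_CHECK to verify that the weight context was fully consumed. That naming is purely mechanical, which is presumably why it can be handled once in shared base-class code instead of per model. The snippet below only regenerates the same name list from a layer count (ignoring the layer_ids indirection) and is an illustration of the convention, not the project's loader.

// Sketch: regenerate the tensor names the removed llama.cpp loader read by hand.
// Illustrative only; this is not part of the project's loader API.
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> llama_tensor_names(int num_hidden_layers, bool has_lm_head)
{
    // Per-layer leaf names, exactly as they appear in the removed read_tensor calls.
    static const char *leaf[] = {
        "input_layernorm.weight",
        "mlp.down_proj.weight", "mlp.gate_proj.weight", "mlp.up_proj.weight",
        "post_attention_layernorm.weight",
        "self_attn.k_proj.weight", "self_attn.o_proj.weight",
        "self_attn.q_proj.weight", "self_attn.v_proj.weight",
    };

    std::vector<std::string> names{"model.embed_tokens.weight"};
    for (int i = 0; i < num_hidden_layers; i++)
    {
        std::string layer_prefix = "model.layers." + std::to_string(i) + '.';
        for (const char *l : leaf)
            names.push_back(layer_prefix + l);
    }
    names.push_back("model.norm.weight");
    if (has_lm_head)
        names.push_back("lm_head.weight");
    return names;
}

int main()
{
    for (const auto &name : llama_tensor_names(2, true))
        std::cout << name << "\n";
    return 0;
}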

models/qwen.cpp

Lines changed: 0 additions & 33 deletions
@@ -266,8 +266,6 @@ namespace v2
     public:
         ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = ModelType::MODEL_TYPE_QWEN2, bool tie_embeddings = false);
 
-        void load(ModelLoader &loader) override;
-
     public:
         Config config;

@@ -308,37 +306,6 @@ namespace v2
                 layer.attention.freq_base = config.rope_theta;
             }
         }
-
-    void ConditionalGeneration::load(ModelLoader &loader)
-    {
-        auto transformer = get_typed_transformer<ModelClass>();
-        loader.read_tensor("model.embed_tokens.weight", transformer->word_embeddings.weight);
-        for (int i = 0; i < config.num_hidden_layers; i++)
-        {
-            std::string layer_prefix = "model.layers." + std::to_string(layer_ids[i]) + '.';
-
-            loader.read_tensor(layer_prefix + "self_attn.k_proj.weight", transformer->layers[i].attention.k_proj.weight);
-            loader.read_tensor(layer_prefix + "self_attn.k_proj.bias", transformer->layers[i].attention.k_proj.bias);
-            loader.read_tensor(layer_prefix + "self_attn.q_proj.weight", transformer->layers[i].attention.q_proj.weight);
-            loader.read_tensor(layer_prefix + "self_attn.q_proj.bias", transformer->layers[i].attention.q_proj.bias);
-            loader.read_tensor(layer_prefix + "self_attn.v_proj.weight", transformer->layers[i].attention.v_proj.weight);
-            loader.read_tensor(layer_prefix + "self_attn.v_proj.bias", transformer->layers[i].attention.v_proj.bias);
-            loader.read_tensor(layer_prefix + "self_attn.o_proj.weight", transformer->layers[i].attention.o_proj.weight);
-
-            loader.read_tensor(layer_prefix + "input_layernorm.weight", transformer->layers[i].input_layernorm.weight);
-            loader.read_tensor(layer_prefix + "post_attention_layernorm.weight", transformer->layers[i].post_attention_layernorm.weight);
-
-            loader.read_tensor(layer_prefix + "mlp.down_proj.weight", transformer->layers[i].mlp.down_proj.weight);
-            loader.read_tensor(layer_prefix + "mlp.up_proj.weight", transformer->layers[i].mlp.up_proj.weight);
-            loader.read_tensor(layer_prefix + "mlp.gate_proj.weight", transformer->layers[i].mlp.gate_proj.weight);
-        }
-        loader.read_tensor("model.norm.weight", transformer->final_layernorm.weight);
-        if (!tie_embeddings)
-            loader.read_tensor("lm_head.weight", dynamic_cast<Linear *>(transformer->lm_head)->weight);
-
-        CHATLLM_CHECK(w_ctx_.get_used_mem() == w_ctx_.get_mem_size())
-            << "corrupted model weights";
-    }
 }
 
 namespace v2_tie
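
Compared with the llama loader, the removed Qwen2 load() additionally read bias tensors for the q/k/v projections. With per-module loading, that difference naturally lives inside the projection module rather than in a model-specific load(). Below is one possible shape for such a module; the has_bias flag, the Projection name, and the loader interface are assumptions made for this sketch, not the real implementation.

// Sketch of a projection module that owns its optional bias (illustrative only).
#include <iostream>
#include <string>

struct TensorLoader
{
    void read(const std::string &name) { std::cout << "read " << name << "\n"; }
};

struct Projection
{
    bool has_bias;
    explicit Projection(bool has_bias_) : has_bias(has_bias_) {}

    // The module, not the per-model loader, knows whether a bias tensor exists.
    void load(const std::string &path, TensorLoader *loader)
    {
        loader->read(path + "weight");
        if (has_bias)
            loader->read(path + "bias");
    }
};

int main()
{
    TensorLoader loader;
    Projection q_proj(/*has_bias=*/true);   // Qwen2's q/k/v projections carry biases
    Projection o_proj(/*has_bias=*/false);  // o_proj does not
    q_proj.load("model.layers.0.self_attn.q_proj.", &loader);
    o_proj.load("model.layers.0.self_attn.o_proj.", &loader);
    return 0;
}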
