From cdab12146bdf3632b205c0716a2b3c0213077d84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Thu, 3 Oct 2024 12:41:30 +0200 Subject: [PATCH] deduce share_decoder_embeddings from HF tie_word_embeddings flag (#123) --- eole/bin/convert/convert_HF.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/eole/bin/convert/convert_HF.py b/eole/bin/convert/convert_HF.py index 3939f4d7..fbc7cc7e 100755 --- a/eole/bin/convert/convert_HF.py +++ b/eole/bin/convert/convert_HF.py @@ -577,6 +577,8 @@ def run(cls, args): quant_layers = [] params = ["weight", "bias"] + share_decoder_embeddings = config.get("tie_word_embeddings", False) + add_qkvbias = False add_ffnbias = False shared_layer_norm = False @@ -589,7 +591,6 @@ def run(cls, args): optional_eos = [] mapped_tokens = [] gpt2_pretok = False - share_decoder_embeddings = False generator_bias = False # ALL THESE IF SHOULD BE HANDLED IN MAPPINGS @@ -689,6 +690,8 @@ def get_weight(checkpoint, tensor_name): "encoder.layer_norm.bias", "generator.weight", ] + if share_decoder_embeddings: + targetlist.remove("generator.weight") for target in targetlist: if target in key_maps[arch].keys(): source = key_maps[arch][target] @@ -701,19 +704,10 @@ def get_weight(checkpoint, tensor_name): w = get_weight(checkpoint, source) if w is not None: eole_safetensor[target] = w - elif target == "generator.weight": - # lm_head is not in HF safetensors -> share from embeddings matrix - share_decoder_embeddings = True if target == "generator.bias": generator_bias = True - if torch.equal( - eole_safetensor.get("generator.weight", None), - eole_safetensor["tgt_emb.embeddings.weight"], - ): - share_decoder_embeddings = True - if wmap_path: weightmap = wmap["weight_map"] ckpt_list = []