Manage share_decoder_embeddings in convert_HF, misc fixes and improvements (#121)
francoishernandez authored Oct 2, 2024
1 parent a8cfb0d commit 43593eb
Showing 6 changed files with 68 additions and 21 deletions.
3 changes: 3 additions & 0 deletions eole/__init__.py
@@ -1 +1,4 @@
import os

__version__ = "0.0.1"
ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
25 changes: 18 additions & 7 deletions eole/bin/convert/convert_HF.py
@@ -589,6 +589,8 @@ def run(cls, args):
optional_eos = []
mapped_tokens = []
gpt2_pretok = False
share_decoder_embeddings = False
generator_bias = False

# ALL THESE IF SHOULD BE HANDLED IN MAPPINGS
if arch == "PhiForCausalLM":
@@ -699,13 +701,18 @@ def get_weight(checkpoint, tensor_name):
w = get_weight(checkpoint, source)
if w is not None:
eole_safetensor[target] = w
elif target == "generator.weight":
# lm_head is not in HF safetensors -> share from embeddings matrix
share_decoder_embeddings = True

# not sure why we're doing this if generator.bias not in key_map
if target == "generator.weight" and w is not None:
eole_safetensor["generator.bias"] = torch.zeros(
eole_safetensor["generator.weight"].size(0),
dtype=TORCH_DTYPES[compute_dtype],
)
if target == "generator.bias":
generator_bias = True

if torch.equal(
eole_safetensor.get("generator.weight", None),
eole_safetensor["tgt_emb.embeddings.weight"],
):
share_decoder_embeddings = True

if wmap_path:
weightmap = wmap["weight_map"]
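The two paths above mirror how Hugging Face checkpoints handle tied weights: either `lm_head` is omitted from the safetensors entirely, or it is stored but bit-identical to the embedding matrix. A minimal sketch of the same detection, with hypothetical checkpoint keys:

```python
import torch

# Sketch of the two detection paths (checkpoint keys are illustrative):
# 1) lm_head missing from the HF safetensors -> weights must be tied;
# 2) lm_head present but bit-identical to the embedding matrix.
emb = torch.randn(32000, 4096)
checkpoint = {"model.embed_tokens.weight": emb}  # no "lm_head.weight"

lm_head = checkpoint.get("lm_head.weight")
share_decoder_embeddings = lm_head is None or torch.equal(lm_head, emb)
print(share_decoder_embeddings)  # True
```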
Expand Down Expand Up @@ -987,7 +994,9 @@ def get_weight(checkpoint, tensor_name):

if generation_config_json is not None:
with open(generation_config_json, encoding="utf-8") as f:
data = json.load(f)
data = json.loads(
f.read().replace(",\n}", "\n}")
) # dirty patch to remove trailing comma...
generation_config_dict = {}
# we probably need a better mapping at some point
keys = ["top_k", "top_p", "temperature", "max_length"]
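The `replace(",\n}", "\n}")` pre-pass only covers one exact layout of the trailing comma. A slightly more general lenient loader could normalize any trailing comma before a closing brace or bracket; a sketch (not part of this commit, and naive about commas inside string values):

```python
import json
import re

def loads_lenient(text: str) -> dict:
    # Drop trailing commas before } or ], a common quirk in
    # hand-edited generation_config.json files.
    return json.loads(re.sub(r",\s*([}\]])", r"\1", text))

print(loads_lenient('{"top_k": 50, "temperature": 0.7,\n}'))
# {'top_k': 50, 'temperature': 0.7}
```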
@@ -1122,6 +1131,8 @@ def get_weight(checkpoint, tensor_name):
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
left_pad=left_pad,
share_decoder_embeddings=share_decoder_embeddings,
generator_bias=generator_bias,
),
training=TrainingConfig(
compute_dtype=compute_dtype,
42 changes: 29 additions & 13 deletions eole/bin/tools/run_mmlu.py
@@ -15,6 +15,8 @@
from eole.config.cli import add_model
from eole.config import get_non_default_values
from eole.config.run import PredictConfig
from eole import ROOT_DIR


TASKS = [
"abstract_algebra",
@@ -79,9 +81,7 @@
choices = ["A", "B", "C", "D"]


def compute_metric(output_filename):
with open(output_filename, "r") as f:
run_results = json.load(f)
def compute_metric(run_results):
total_acc = 0
total_num = 0
for task in run_results:
@@ -94,7 +94,11 @@ def compute_metric(output_filename):
print("ACC-%s: %.4f" % (task, acc / len(gold_answers)))
total_acc += acc
total_num += len(gold_answers)
print("ACC-all: %.4f" % (total_acc / total_num))
run_results[task]["metrics"] = {"acc": acc}
acc_all = total_acc / total_num
print("ACC-all: %.4f" % (acc_all))
run_results["metrics"] = {"acc": acc_all}
return run_results


def format_subject(subject):
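`compute_metric` now takes the in-memory results dict and enriches it in place: each task entry gains a `metrics` sub-dict holding the raw count of correct answers, while the top-level `metrics` entry holds the micro-averaged ratio. Roughly the shape it returns (task and answer keys are illustrative):

```python
# Hypothetical shape of the enriched dict returned by compute_metric:
run_results = {
    "abstract_algebra": {
        "pred_answers": ["A", "C", "B"],
        "gold_answers": ["A", "B", "B"],
        "metrics": {"acc": 2},      # raw count of correct answers
    },
    "metrics": {"acc": 2 / 3},      # micro-averaged over all tasks
}
```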
@@ -132,22 +136,21 @@ def gen_prompt(train_df, subject, k=-1):
# return input_ids[-len(stop_ids)]


def evaluate(opt):
def evaluate(args, data_dir):
import pandas as pd

logger = init_logger(opt.log_file)
set_random_seed(opt.seed, use_gpu(opt))
logger = init_logger(args.log_file)
set_random_seed(args.seed, use_gpu(args))

run_results = {}
dir_name = os.path.dirname(opt.model_path[0])
dir_name = args.model_path[0]

# Save results in the model dir
output_filename = os.path.join(dir_name, "mmlu_results.json")

# Build the translator (along with the model)
engine = InferenceEnginePY(opt)
engine = InferenceEnginePY(args)

# this considers that we are always in the recipes/mmlu folder, should be improved
data_dir = "recipes/mmlu/data/"
ntrain = 5 # nshots from dev

start_time = time.time()
@@ -188,10 +191,11 @@ def evaluate(opt):

engine.terminate()

run_results = compute_metric(run_results)

with open(output_filename, "w") as f:
json.dump(run_results, f, ensure_ascii=False, indent=2)

compute_metric(output_filename)
end_time = time.time()
logger.info("total run time %.2f" % (end_time - start_time))

@@ -207,6 +211,15 @@ def add_args(cls, parser):
required=False,
help="Path of main YAML config file.",
)
# TODO: we might want to retrieve transparently from HF at some point
parser.add_argument(
"-data_dir",
"--data_dir",
"-d",
required=False,
help="Path to the MMLU data root.",
default=os.path.join(os.path.dirname(ROOT_DIR), "recipes", "mmlu", "data"),
)
add_model(parser, PredictConfig)

def run(args):
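With the new `ROOT_DIR` export, the `data_dir` default resolves relative to the installed `eole` package instead of assuming the process runs from `recipes/mmlu`. A quick check of what the default expands to (output path is illustrative, and assumes a repository checkout where `recipes/` sits next to the package):

```python
import os
from eole import ROOT_DIR

# ROOT_DIR is the eole package directory; the default MMLU data dir
# lives one level up, inside the repository's recipes folder.
default_data_dir = os.path.join(os.path.dirname(ROOT_DIR), "recipes", "mmlu", "data")
print(default_data_dir)  # e.g. /path/to/eole/recipes/mmlu/data
```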
@@ -230,6 +243,9 @@ def run(args):
if "config" in config.keys():
config.pop("config")

# not supported in PredictConfig schema
data_dir = config.pop("data_dir")

config = PredictConfig(**config)

evaluate(config)
evaluate(config, data_dir)
2 changes: 1 addition & 1 deletion eole/config/common.py
@@ -78,7 +78,7 @@ class QuantizeConfig(Config):
default=[], description="List of layers to be compressed in 4/8bit."
) # validate against list of layers names ?
quant_type: Literal[
"", "bnb_9bit", "bnb_FP4", "bnb_NF4", "awq_gemm", "awq_gemv"
"", "bnb_8bit", "bnb_FP4", "bnb_NF4", "awq_gemm", "awq_gemv"
] = Field(default="", description="Type of compression.")
w_bit: int = Field(
default=4, description="W_bit quantization"
4 changes: 4 additions & 0 deletions eole/config/models.py
@@ -433,6 +433,10 @@ class BaseModelConfig(Config):
description="Which function to use for generating probabilities "
"over the target vocabulary.",
)
generator_bias: bool = Field(
default=True,
description="Control whether or not the generator Linear module has bias weights.",
)
add_estimator: bool = Field(default=False, description="Add estimator layer")

left_pad: bool = Field(
13 changes: 13 additions & 0 deletions eole/models/model.py
@@ -113,6 +113,7 @@ def __init__(self, **kwargs):
self.tgt_emb = kwargs.get("tgt_emb", None)
self.add_estimator = kwargs.get("add_estimator", False)
self.hidden_size = kwargs.get("hidden_size", None)
self.share_decoder_embeddings = False
if self.encoder is not None and self.src_emb is None:
raise ValueError("An Encoder needs source Embeddings")
if self.decoder is not None and self.tgt_emb is None:
@@ -224,6 +225,7 @@ def build_generator(self, model_config, vocabs):
nn.Linear,
in_features=model_config.decoder.hidden_size,
out_features=len(vocabs["tgt"]),
bias=model_config.generator_bias,
)
if model_config.share_decoder_embeddings:
generator.weight = self.tgt_emb.embeddings.weight
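On the model side, sharing is plain PyTorch parameter tying: assigning the embedding's `Parameter` to the `Linear` leaves a single tensor referenced by both modules, which is also why the new `generator_bias` flag matters (a tied generator from an HF checkpoint has no bias to load). A standalone sketch with illustrative sizes:

```python
import torch.nn as nn

# Weight tying in miniature (vocab/hidden sizes are illustrative).
vocab_size, hidden_size = 32000, 4096
tgt_emb = nn.Embedding(vocab_size, hidden_size)
generator = nn.Linear(hidden_size, vocab_size, bias=False)

# Both modules now reference the same Parameter: one copy in the
# checkpoint, gradients accumulate into a single tensor.
generator.weight = tgt_emb.weight
assert generator.weight.data_ptr() == tgt_emb.weight.data_ptr()
```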
@@ -438,6 +440,8 @@ def from_config(
model = cls.build_blocks(
model_config, vocabs, running_config=running_config
) # corresponds to build_task_specific_model
# TODO: handle this better at some point?
model.share_decoder_embeddings = model_config.share_decoder_embeddings
# generator -> shall it be called within build_blocks?
if model_config.decoder is not None:
model.build_generator(model_config, vocabs)
@@ -706,6 +710,15 @@ def load_safe_state_dict(
+ "."
+ param_name
)
if (
f"{name}.{param_name}"
in ["generator.weight", "generator.bias"]
and self.share_decoder_embeddings
):
logger.info(
"└─> Sharing from embeddings matrix since "
"`share_decoder_embeddings` flag is enabled."
)
if getattr(running_config, "compute_dtype", None) == torch.int8:
torch.quantization.quantize_dynamic(module, inplace=True)
else:
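Skipping `generator.weight`/`generator.bias` at load time is safe because tying survives an in-place load: `load_state_dict` copies new values into the existing `Parameter`, so the tied generator sees them without having its own checkpoint entry. A small illustration (sizes arbitrary):

```python
import torch
import torch.nn as nn

# A tied generator needs no checkpoint entry of its own.
emb = nn.Embedding(8, 4)
gen = nn.Linear(4, 8, bias=False)
gen.weight = emb.weight

state = {"weight": torch.randn(8, 4)}
emb.load_state_dict(state)  # copies in place, preserving the tie
assert torch.equal(gen.weight, state["weight"])
```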
