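"""llm_evaluate.py: evaluate a causal LM (optionally quantized) on lm-eval tasks.

Loads either pre-trained weights or a local training checkpoint, optionally
quantizes the model via train_utils.quantize_model, and prints the task
results from lm_eval.simple_evaluate.
"""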
import argparse
import json

import lm_eval
import torch
from lm_eval.models.huggingface import HFLM
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from train_utils import quantize_model
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", default="mini_llamas/Llama-2-470m")
    parser.add_argument("--checkpoint")
    parser.add_argument("--quantize")
    parser.add_argument("--quantize_kwargs", type=json.loads, default=dict())
    parser.add_argument("--quantize_lm_head", action="store_true")
    parser.add_argument("--tasks", nargs="+", required=True)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
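    # Example invocation (illustrative only; valid --quantize names and
    # kwargs depend on what train_utils.quantize_model accepts):
    #   python llm_evaluate.py --tasks hellaswag --quantize int8 --quantize_kwargs '{"group_size": 128}'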
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    kwargs = dict(
        pretrained_model_name_or_path=args.model_id,
        max_position_embeddings=args.seq_len,
        use_cache=False,
        torch_dtype=torch.bfloat16,
    )
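    # these kwargs feed both loading paths below; passing max_position_embeddings
    # overrides the value stored in the model's saved config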
    if args.checkpoint is None:
        # load pre-trained weights
        model = AutoModelForCausalLM.from_pretrained(**kwargs).cuda()
    else:
        # don't load pre-trained weights; initialize from the config and
        # fill in the weights from the checkpoint below
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(**kwargs)).cuda()
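    # quantize the transformer body (model.model); the LM head is quantized
    # separately below only when --quantize_lm_head is set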
    quantize_model(model.model, args.quantize, **args.quantize_kwargs)
    if args.quantize_lm_head:
        quantize_model(model.lm_head, args.quantize, **args.quantize_kwargs)
    if args.checkpoint is not None:
        # load the checkpoint weights only after quantization, since
        # BitNet-style quantization changes the model's structure
        state_dict = torch.load(args.checkpoint, map_location="cpu", mmap=True)
        model.load_state_dict(state_dict["model"])
    result = lm_eval.simple_evaluate(
        model=HFLM(pretrained=model, tokenizer=tokenizer, batch_size=args.batch_size, max_length=args.seq_len),
        tasks=args.tasks,
        limit=10 if args.debug else None,
    )
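    # result["results"] maps each task name to its metric dict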
    print(result["results"])