@@ -16,31 +16,31 @@
 
 
 def get_llm(name, dtype):
-    if name == "gptj":
-        model_name = "EleutherAI/gpt-j-6B"
-
-        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
-        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
-    elif name == "llama2-13b":
-        kwargs = {}
-        if "HF_TOKEN" in os.environ:
-            kwargs["token"] = os.environ.get("HF_TOKEN")
-
-        model_name = "meta-llama/Llama-2-13b-hf"
-        model = LlamaForCausalLM.from_pretrained(
-            model_name, torch_dtype=dtype, **kwargs
-        )
-        tokenizer = LlamaTokenizer.from_pretrained(model_name, **kwargs)
-    else:
+    name2params = {
+        "gptj": ("EleutherAI/gpt-j-6B", AutoModelForCausalLM, AutoTokenizer),
+        "llama2-7b": ("meta-llama/Llama-2-7b-hf", LlamaForCausalLM, LlamaTokenizer),
+        "llama2-13b": ("meta-llama/Llama-2-13b-hf", LlamaForCausalLM, LlamaTokenizer),
+    }
+
+    if name not in name2params:
         raise ValueError("Unsupported model name")
+
+    kwargs = {}
+    if name.startswith("llama2") and "HF_TOKEN" in os.environ:
+        kwargs["token"] = os.environ.get("HF_TOKEN")
+
+    model_name, M, T = name2params[name]
+
+    model = M.from_pretrained(model_name, torch_dtype=dtype, **kwargs)
+    tokenizer = T.from_pretrained(model_name, **kwargs)
     return tokenizer, model
 
 
 class LlmBenchmark(Benchmark):
     def __init__(self, params) -> None:
         name = params.get("name", "gptj")
         dtype = params.get("dtype")
-        self.batch_size = params.get("batch_size", 1)
+        self.batch_size = int(params.get("batch_size", 1))
         self.n_iter = params.get("n_iter", 5)
         self.warmup_batches = params.get("warmup", 2)
 
@@ -90,12 +90,13 @@ def inference(self, backend):
             with torch.inference_mode(), cast:
                 tokens, total_time = self.generate(backend)
 
+            print(f"Fw time: {total_time:.1f}")
+
             if i < self.warmup_batches:
                 # We restart timer because that was just a warmup
                 start = get_time()
                 continue
 
-            print(f"Fw time: {total_time:.1f}")
             fw_times.append(total_time)
             n_items += math.prod(tokens.shape)
             outputs.append(tokens)
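For reference, a minimal usage sketch of the refactored get_llm helper (not part of the commit). The prompt text, dtype, and max_new_tokens value are illustrative, and get_llm is assumed to be importable from the benchmark module; the llama2 entries additionally require the HF_TOKEN environment variable to be exported for the gated meta-llama repositories.

    import torch

    # get_llm returns (tokenizer, model) for one of the supported names.
    tokenizer, model = get_llm("gptj", torch.float32)

    # Run a single short generation with the loaded pair.
    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    with torch.inference_mode():
        generated = model.generate(**inputs, max_new_tokens=8)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))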