15 | 15 |
16 | 16 |
17 | 17 | def get_llm(name, dtype):
18 | | -     if name == "gptj":
19 | | -         model_name = "EleutherAI/gpt-j-6B"
20 | | -
21 | | -         model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
22 | | -         tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
23 | | -     elif name == "llama2-13b":
24 | | -         kwargs = {}
25 | | -         if "HF_TOKEN" in os.environ:
26 | | -             kwargs["token"] = os.environ.get("HF_TOKEN")
27 | | -
28 | | -         model_name = "meta-llama/Llama-2-13b-hf"
29 | | -         model = LlamaForCausalLM.from_pretrained(
30 | | -             model_name, torch_dtype=dtype, **kwargs
31 | | -         )
32 | | -         tokenizer = LlamaTokenizer.from_pretrained(model_name, **kwargs)
33 | | -     else:
| 18 | +     name2params = {
| 19 | +         "gptj": ("EleutherAI/gpt-j-6B", AutoModelForCausalLM, AutoTokenizer),
| 20 | +         "llama2-7b": ("meta-llama/Llama-2-7b-hf", LlamaForCausalLM, LlamaTokenizer),
| 21 | +         "llama2-13b": ("meta-llama/Llama-2-13b-hf", LlamaForCausalLM, LlamaTokenizer),
| 22 | +     }
| 23 | +
| 24 | +     if name not in name2params:
34 | 25 |         raise ValueError("Unsupported model name")
| 26 | +
| 27 | +     kwargs = {}
| 28 | +     if name.startswith("llama2") and "HF_TOKEN" in os.environ:
| 29 | +         kwargs["token"] = os.environ.get("HF_TOKEN")
| 30 | +
| 31 | +     model_name, M, T = name2params[name]
| 32 | +
| 33 | +     model = M.from_pretrained(model_name, torch_dtype=dtype, **kwargs)
| 34 | +     tokenizer = T.from_pretrained(model_name, **kwargs)
35 | 35 |     return tokenizer, model
36 | 36 |
37 | 37 |
38 | 38 | class LlmBenchmark(Benchmark): |
39 | 39 |     def __init__(self, params) -> None:
40 | 40 |         name = params.get("name", "gptj")
41 | 41 |         dtype = params.get("dtype")
42 | | -         self.batch_size = params.get("batch_size", 1)
| 42 | +         self.batch_size = int(params.get("batch_size", 1))
43 | 43 |         self.n_iter = params.get("n_iter", 5)
44 | 44 |         self.warmup_batches = params.get("warmup", 2)
45 | 45 |
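For context (not part of the diff), here is a hedged sketch of how the refactored `get_llm` lookup table and the `LlmBenchmark` parameters could fit together. The `params` keys mirror the `__init__` above; the prompt, `max_new_tokens`, and the `float32` dtype are purely illustrative choices, and running it would download the GPT-J checkpoint.

```python
# Illustrative usage only: exercise get_llm() with the kind of params dict
# that LlmBenchmark.__init__ reads ("name", "dtype", "batch_size", "n_iter", "warmup").
import torch

params = {
    "name": "gptj",           # one of the keys in name2params
    "dtype": torch.float32,   # float32 keeps this sketch CPU-friendly
    "batch_size": 1,
    "n_iter": 5,
    "warmup": 2,
}

tokenizer, model = get_llm(params["name"], params["dtype"])
model.eval()

inputs = tokenizer("Benchmarking large language models is", return_tensors="pt")
with torch.inference_mode():
    tokens = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))
```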
@@ -89,12 +89,13 @@ def inference(self, backend): |
89 | 89 |             with torch.inference_mode(), cast:
90 | 90 |                 tokens, total_time = self.generate(backend)
91 | 91 |
| 92 | +             print(f"Fw time: {total_time:.1f}")
| 93 | +
92 | 94 |             if i < self.warmup_batches:
93 | 95 |                 # We restart the timer because that was just a warmup
94 | 96 |                 start = get_time()
95 | 97 |                 continue
96 | 98 |
97 | | -             print(f"Fw time: {total_time:.1f}")
98 | 99 |             fw_times.append(total_time)
99 | 100 |             n_items += math.prod(tokens.shape)
100 | 101 |             outputs.append(tokens)
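For reference (also not part of the diff), a self-contained sketch of the warmup-then-measure pattern this hunk adjusts: the per-iteration time is now printed for every iteration, but only post-warmup iterations are recorded. `timed_loop` and `run_once` are illustrative names; the real loop calls `self.generate(backend)` and additionally accumulates token counts.

```python
import time

def timed_loop(run_once, n_iter=5, warmup_batches=2):
    """Run run_once n_iter times; exclude the first warmup_batches runs
    from the reported statistics, mirroring the benchmark's inference loop."""
    fw_times = []
    start = time.perf_counter()
    for i in range(n_iter):
        t0 = time.perf_counter()
        run_once()
        total_time = time.perf_counter() - t0
        print(f"Fw time: {total_time:.1f}")  # printed for warmup runs too

        if i < warmup_batches:
            # Restart the wall-clock timer: warmup runs are not measured.
            start = time.perf_counter()
            continue

        fw_times.append(total_time)
    wall_time = time.perf_counter() - start
    return fw_times, wall_time

# Example with a dummy workload standing in for self.generate(backend).
fw, wall = timed_loop(lambda: sum(range(1_000_000)), n_iter=5, warmup_batches=2)
print(f"measured iters: {len(fw)}, wall time: {wall:.3f}s")
```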