
Commit 28b0e5d

Added llama-7b (#92)
1 parent 20f109b commit 28b0e5d

File tree

2 files changed: +20, -19 lines


dl_bench/llm.py

Lines changed: 19 additions & 18 deletions
@@ -15,31 +15,31 @@


 def get_llm(name, dtype):
-    if name == "gptj":
-        model_name = "EleutherAI/gpt-j-6B"
-
-        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
-        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
-    elif name == "llama2-13b":
-        kwargs = {}
-        if "HF_TOKEN" in os.environ:
-            kwargs["token"] = os.environ.get("HF_TOKEN")
-
-        model_name = "meta-llama/Llama-2-13b-hf"
-        model = LlamaForCausalLM.from_pretrained(
-            model_name, torch_dtype=dtype, **kwargs
-        )
-        tokenizer = LlamaTokenizer.from_pretrained(model_name, **kwargs)
-    else:
+    name2params = {
+        "gptj": ("EleutherAI/gpt-j-6B", AutoModelForCausalLM, AutoTokenizer),
+        "llama2-7b": ("meta-llama/Llama-2-7b-hf", LlamaForCausalLM, LlamaTokenizer),
+        "llama2-13b": ("meta-llama/Llama-2-13b-hf", LlamaForCausalLM, LlamaTokenizer),
+    }
+
+    if name not in name2params:
         raise ValueError("Unsupported model name")
+
+    kwargs = {}
+    if name.startswith("llama2") and "HF_TOKEN" in os.environ:
+        kwargs["token"] = os.environ.get("HF_TOKEN")
+
+    model_name, M, T = name2params[name]
+
+    model = M.from_pretrained(model_name, torch_dtype=dtype, **kwargs)
+    tokenizer = T.from_pretrained(model_name, **kwargs)
     return tokenizer, model


 class LlmBenchmark(Benchmark):
     def __init__(self, params) -> None:
         name = params.get("name", "gptj")
         dtype = params.get("dtype")
-        self.batch_size = params.get("batch_size", 1)
+        self.batch_size = int(params.get("batch_size", 1))
         self.n_iter = params.get("n_iter", 5)
         self.warmup_batches = params.get("warmup", 2)
@@ -89,12 +89,13 @@ def inference(self, backend):
         with torch.inference_mode(), cast:
             tokens, total_time = self.generate(backend)

+        print(f"Fw time: {total_time:.1f}")
+
         if i < self.warmup_batches:
             # We restart timer because that was just a warmup
             start = get_time()
             continue

-        print(f"Fw time: {total_time:.1f}")
         fw_times.append(total_time)
         n_items += math.prod(tokens.shape)
         outputs.append(tokens)
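
For reference, a minimal usage sketch of the refactored get_llm. The function name, the (tokenizer, model) return order, and the model identifiers come from the diff above; the prompt and generation settings are illustrative assumptions, and HF_TOKEN must be exported beforehand to access the gated Llama-2 checkpoints.

import torch
from dl_bench.llm import get_llm

# get_llm returns (tokenizer, model); "llama2-7b" is the name added in this commit.
tokenizer, model = get_llm("llama2-7b", dtype=torch.float16)

# Illustrative prompt and generation settings, not taken from the benchmark itself.
inputs = tokenizer("A short benchmark prompt", return_tensors="pt")
with torch.inference_mode():
    tokens = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))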

llm.sh

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ if [[ -z "${DL_BENCH_ARGS}" ]]; then
     exit 1
 fi

-for NAME in llama2-13b gptj
+for NAME in llama2-7b llama2-13b gptj
 do
     for BS in 1 4
     do
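
A hedged invocation sketch for the updated script: the DL_BENCH_ARGS guard and the model/batch-size loops are taken from the script above, but both exported values are placeholders, since the expected contents of DL_BENCH_ARGS are not shown in this diff.

export HF_TOKEN="<your-hf-token>"             # needed for the gated Llama-2 weights
export DL_BENCH_ARGS="<your benchmark args>"  # script exits early if this is unset
bash llm.sh                                   # runs llama2-7b, llama2-13b, gptj at BS 1 and 4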
