This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit a7a8988

Merge branch 'main' into egor/margin

2 parents: 06d1709 + 28b0e5d

2 files changed: +20 −19 lines

dl_bench/llm.py

Lines changed: 19 additions & 18 deletions
@@ -16,31 +16,31 @@
 
 
 def get_llm(name, dtype):
-    if name == "gptj":
-        model_name = "EleutherAI/gpt-j-6B"
-
-        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
-        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
-    elif name == "llama2-13b":
-        kwargs = {}
-        if "HF_TOKEN" in os.environ:
-            kwargs["token"] = os.environ.get("HF_TOKEN")
-
-        model_name = "meta-llama/Llama-2-13b-hf"
-        model = LlamaForCausalLM.from_pretrained(
-            model_name, torch_dtype=dtype, **kwargs
-        )
-        tokenizer = LlamaTokenizer.from_pretrained(model_name, **kwargs)
-    else:
+    name2params = {
+        "gptj": ("EleutherAI/gpt-j-6B", AutoModelForCausalLM, AutoTokenizer),
+        "llama2-7b": ("meta-llama/Llama-2-7b-hf", LlamaForCausalLM, LlamaTokenizer),
+        "llama2-13b": ("meta-llama/Llama-2-13b-hf", LlamaForCausalLM, LlamaTokenizer),
+    }
+
+    if name not in name2params:
         raise ValueError("Unsupported model name")
+
+    kwargs = {}
+    if name.startswith("llama2") and "HF_TOKEN" in os.environ:
+        kwargs = {"HF_TOKEN": os.environ.get("HF_TOKEN")}
+
+    model_name, M, T = name2params[name]
+
+    model = M.from_pretrained(model_name, torch_dtype=dtype, **kwargs)
+    tokenizer = T.from_pretrained(model_name)
     return tokenizer, model
 
 
 class LlmBenchmark(Benchmark):
     def __init__(self, params) -> None:
         name = params.get("name", "gptj")
         dtype = params.get("dtype")
-        self.batch_size = params.get("batch_size", 1)
+        self.batch_size = int(params.get("batch_size", 1))
         self.n_iter = params.get("n_iter", 5)
         self.warmup_batches = params.get("warmup", 2)
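This hunk replaces the per-model if/elif chain with a lookup table mapping each short name to a (checkpoint, model class, tokenizer class) triple, and adds llama2-7b to the supported set. Below is a self-contained sketch of that pattern; load_llm and NAME2PARAMS are illustrative renamings, not names from the repo. One deliberate deviation is flagged in the comments: the sketch passes the auth token as token=..., the keyword the removed code used and one that from_pretrained accepts, whereas the committed kwargs = {"HF_TOKEN": ...} forwards the value under a key from_pretrained does not recognize as an auth parameter, so gated Llama 2 downloads would likely still require a cached login.

import os

from transformers import AutoModelForCausalLM, AutoTokenizer

# Name -> (checkpoint, model class, tokenizer class), mirroring the new
# name2params table. The Auto* classes stand in for the Llama-specific
# classes so the sketch needs no extra imports; they resolve the same
# checkpoints.
NAME2PARAMS = {
    "gptj": ("EleutherAI/gpt-j-6B", AutoModelForCausalLM, AutoTokenizer),
    "llama2-7b": ("meta-llama/Llama-2-7b-hf", AutoModelForCausalLM, AutoTokenizer),
    "llama2-13b": ("meta-llama/Llama-2-13b-hf", AutoModelForCausalLM, AutoTokenizer),
}


def load_llm(name, dtype):
    if name not in NAME2PARAMS:
        raise ValueError("Unsupported model name")

    kwargs = {}
    if name.startswith("llama2") and "HF_TOKEN" in os.environ:
        # Gated checkpoints need authentication; from_pretrained takes the
        # token as token=... (use_auth_token=... on older transformers).
        kwargs["token"] = os.environ["HF_TOKEN"]

    model_name, model_cls, tokenizer_cls = NAME2PARAMS[name]
    model = model_cls.from_pretrained(model_name, torch_dtype=dtype, **kwargs)
    tokenizer = tokenizer_cls.from_pretrained(model_name, **kwargs)
    return tokenizer, model

Calling load_llm("gptj", None) returns the tokenizer/model pair; passing torch_dtype=None keeps the checkpoint's default dtype.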

@@ -90,12 +90,13 @@ def inference(self, backend):
             with torch.inference_mode(), cast:
                 tokens, total_time = self.generate(backend)
 
+            print(f"Fw time: {total_time:.1f}")
+
             if i < self.warmup_batches:
                 # We restart timer because that was just a warmup
                 start = get_time()
                 continue
 
-            print(f"Fw time: {total_time:.1f}")
             fw_times.append(total_time)
             n_items += math.prod(tokens.shape)
             outputs.append(tokens)
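This hunk moves the Fw time print above the warmup check, so the forward time of warmup batches gets logged too, while fw_times and the item count still only accumulate after warmup. A minimal sketch of that measurement pattern, assuming an injected generate callable and using time.perf_counter in place of the repo's get_time helper, which the diff does not show:

import math
import time


def timed_run(generate, n_iter, warmup_batches):
    # Warmup-aware timing loop: every iteration's forward time is printed,
    # but only post-warmup iterations feed the statistics.
    fw_times, outputs, n_items = [], [], 0
    start = time.perf_counter()
    for i in range(n_iter):
        tokens, total_time = generate()
        print(f"Fw time: {total_time:.1f}")  # now logged for warmup iterations too
        if i < warmup_batches:
            start = time.perf_counter()  # restart the timer; that was just a warmup
            continue
        fw_times.append(total_time)
        n_items += math.prod(tokens.shape)
        outputs.append(tokens)
    return outputs, fw_times, n_items, time.perf_counter() - start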

llm.sh

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ if [[ -z "${DL_BENCH_ARGS}" ]]; then
     exit 1
 fi
 
-for NAME in llama2-13b gptj
+for NAME in llama2-7b llama2-13b gptj
 do
     for BS in 1 4
     do
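For clarity, the same sweep sketched in Python: it only assembles the params dict that LlmBenchmark.__init__ (shown above) reads. How each run is actually launched depends on ${DL_BENCH_ARGS}, which this diff does not show, so that step stays a placeholder.

from itertools import product

# llama2-7b is the entry this commit adds to the sweep.
names = ["llama2-7b", "llama2-13b", "gptj"]
batch_sizes = [1, 4]

for name, bs in product(names, batch_sizes):
    params = {"name": name, "dtype": None, "batch_size": bs, "n_iter": 5, "warmup": 2}
    print("would run:", params)  # placeholder for the benchmark invocation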

0 commit comments