This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit 0f01f6b

Fixed llm code for nvidia (#85)

1 parent 09f178e

File tree

1 file changed: +7 -4 lines

dl_bench/llm.py

Lines changed: 7 additions & 4 deletions
@@ -32,12 +32,15 @@ def __init__(self, params) -> None:
             "num_beams": 4,
         }

-    def generate(self, prompt):
+    def generate(self, prompt, backend):
         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+        backend.sync()
         start = time.perf_counter()
+        input_ids = backend.to_device(input_ids)
         gen_tokens = self.model.generate(
             input_ids, **self.gen_kwargs, pad_token_id=self.tokenizer.eos_token_id
         )
+        backend.sync()
         total_time = time.perf_counter() - start

         # text = self.tokenizer.batch_decode(gen_tokens)[0]
@@ -54,16 +57,16 @@ def inference(self, backend):

         print("Warmup started")
         with torch.inference_mode(), tm.timeit("warmup_s"):
-            # self.model.eval()
-            self.generate(self.warmup_prompt)
+            self.model.eval()
+            self.generate(self.warmup_prompt, backend)
         print("Warmup done")

         # self.model.eval()
         enabled = backend.dtype != torch.float32
         with torch.inference_mode(), torch.autocast(
             enabled=enabled, device_type=backend.device_name
         ), tm.timeit("duration_s"):
-            tokens, total_time = self.generate(self.prompt)
+            tokens, total_time = self.generate(self.prompt, backend)
         outputs = [tokens]

         results = tm.get_results()
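The change threads the benchmark's backend object through generate(): input tokens are moved to the device and the timing window is bracketed by backend.sync() calls. On NVIDIA GPUs this matters because CUDA kernels launch asynchronously, so without a synchronization point time.perf_counter() can be read before generation has actually finished. The backend class itself is not part of this diff; the sketch below is only an assumption of what to_device() and sync() might look like for a CUDA backend (the name CudaBackend and its fields are hypothetical).

import torch


class CudaBackend:
    """Hypothetical stand-in for the backend object used by dl_bench/llm.py."""

    def __init__(self, dtype=torch.float16):
        self.device_name = "cuda"  # consumed by torch.autocast(device_type=...)
        self.dtype = dtype         # float32 disables autocast in inference()

    def to_device(self, tensor):
        # Move benchmark inputs onto the GPU so generate() does not feed
        # CPU input_ids to a CUDA-resident model.
        return tensor.to(self.device_name)

    def sync(self):
        # CUDA kernels launch asynchronously; synchronizing before reading
        # time.perf_counter() makes the measurement cover actual execution
        # rather than just kernel launch overhead.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

With a backend like this, the two sync() calls bracket the model.generate() call exactly, so the reported total_time (and the duration_s timer) reflects end-to-end generation on the device.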
