This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit da597ae

Fixed llm code for nvidia (#85)

1 parent bd18c2d

File tree

dl_bench/llm.py

1 file changed: +6, -3 lines

dl_bench/llm.py

Lines changed: 6 additions & 3 deletions
@@ -31,12 +31,15 @@ def __init__(self, params) -> None:
             "num_beams": 4,
         }
 
-    def generate(self, prompt):
+    def generate(self, prompt, backend):
         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+        backend.sync()
         start = time.perf_counter()
+        input_ids = backend.to_device(input_ids)
         gen_tokens = self.model.generate(
             input_ids, **self.gen_kwargs, pad_token_id=self.tokenizer.eos_token_id
         )
+        backend.sync()
         total_time = time.perf_counter() - start
 
         # text = self.tokenizer.batch_decode(gen_tokens)[0]
@@ -54,15 +57,15 @@ def inference(self, backend):
         print("Warmup started")
         with torch.inference_mode(), tm.timeit("warmup_s"):
             self.model.eval()
-            self.generate(self.warmup_prompt)
+            self.generate(self.warmup_prompt, backend)
         print("Warmup done")
 
         self.model.eval()
         enabled = backend.dtype != torch.float32
         with torch.inference_mode(), torch.autocast(
             enabled=enabled, device_type=backend.device_name
         ), tm.timeit("duration_s"):
-            tokens, total_time = self.generate(self.prompt)
+            tokens, total_time = self.generate(self.prompt, backend)
         outputs = [tokens]
 
         results = tm.get_results()
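
The fix hinges on two backend hooks around the timed region: backend.to_device() moves the tokenized inputs onto the accelerator, and backend.sync() brackets the measurement so that asynchronously launched GPU kernels have actually finished before time.perf_counter() is read. A minimal sketch of what such a backend could look like on NVIDIA hardware, assuming a hypothetical CudaBackend class (the real backend implementation lives elsewhere in dl_bench and is not part of this diff):

import torch

class CudaBackend:
    """Hypothetical sketch of the backend used above; the class name and
    constructor arguments are assumptions, not taken from this commit."""

    def __init__(self, device_name="cuda", dtype=torch.float16):
        self.device_name = device_name
        self.dtype = dtype

    def to_device(self, tensor):
        # Move input tensors to the target device so model.generate()
        # runs on the GPU rather than falling back to the CPU.
        return tensor.to(self.device_name)

    def sync(self):
        # CUDA kernels launch asynchronously; synchronizing before reading
        # the timer ensures the measurement covers kernel execution rather
        # than just the launch overhead. On CPU backends this is a no-op.
        if self.device_name == "cuda":
            torch.cuda.synchronize()

The sync() before start = time.perf_counter() also keeps earlier queued work, such as the warmup pass, from leaking into the measured interval.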
