From aa2d2afea2d1416d01f658db2406a046fc7d9cd0 Mon Sep 17 00:00:00 2001
From: Egor Krivov
Date: Mon, 12 Feb 2024 18:15:32 +0100
Subject: [PATCH] fixed nvidia llm

---
 dl_bench/llm.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/dl_bench/llm.py b/dl_bench/llm.py
index 1f01256..0ed79bf 100644
--- a/dl_bench/llm.py
+++ b/dl_bench/llm.py
@@ -31,12 +31,15 @@ def __init__(self, params) -> None:
             "num_beams": 4,
         }
 
-    def generate(self, prompt):
+    def generate(self, prompt, backend):
         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+        backend.sync()
         start = time.perf_counter()
+        input_ids = backend.to_device(input_ids)
         gen_tokens = self.model.generate(
             input_ids, **self.gen_kwargs, pad_token_id=self.tokenizer.eos_token_id
         )
+        backend.sync()
         total_time = time.perf_counter() - start
 
         # text = self.tokenizer.batch_decode(gen_tokens)[0]
@@ -54,7 +57,7 @@ def inference(self, backend):
         print("Warmup started")
         with torch.inference_mode(), tm.timeit("warmup_s"):
             self.model.eval()
-            self.generate(self.warmup_prompt)
+            self.generate(self.warmup_prompt, backend)
         print("Warmup done")
 
         self.model.eval()
@@ -62,7 +65,7 @@ def inference(self, backend):
         with torch.inference_mode(), torch.autocast(
             enabled=enabled, device_type=backend.device_name
         ), tm.timeit("duration_s"):
-            tokens, total_time = self.generate(self.prompt)
+            tokens, total_time = self.generate(self.prompt, backend)
         outputs = [tokens]
 
         results = tm.get_results()
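
Note on the change (commentary, not part of the patch): generate() now takes the
benchmark's backend object and calls backend.sync() before starting the timer and
again after model.generate(). The Backend class itself is not in this diff, so the
sketch below is only an assumption of what sync() and to_device() might look like
for a CUDA backend; it illustrates why the synchronization matters: CUDA kernels
are launched asynchronously, so without the final sync the timer would stop before
generation has actually finished on the GPU.

    # Hypothetical sketch of the Backend methods used above (assumption only,
    # not the project's actual implementation).
    import torch

    class Backend:
        def __init__(self, device_name: str = "cuda"):
            self.device_name = device_name
            self.device = torch.device(device_name)

        def to_device(self, tensor: torch.Tensor) -> torch.Tensor:
            # Move the tokenized prompt onto the accelerator; in the patch this
            # happens after the timer starts, so the transfer is included.
            return tensor.to(self.device)

        def sync(self):
            # CUDA work is asynchronous: wait for all queued kernels so
            # time.perf_counter() measures completed work, not just launches.
            if self.device.type == "cuda":
                torch.cuda.synchronize(self.device)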