diff --git a/optimum_benchmark/benchmarks/inference/benchmark.py b/optimum_benchmark/benchmarks/inference/benchmark.py
index 5cc4830d..4293f39c 100644
--- a/optimum_benchmark/benchmarks/inference/benchmark.py
+++ b/optimum_benchmark/benchmarks/inference/benchmark.py
@@ -127,11 +127,14 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
         else:
             _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
+        if backend.config.task in TEXT_GENERATION_TASKS:
+            LOGGER.info("\t+ Additional warmup for Text Generation")
+            _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
+        elif backend.config.task in IMAGE_DIFFUSION_TASKS:
+            LOGGER.info("\t+ Additional warmup for Image Diffusion")
+            _ = backend.call(self.call_inputs, self.config.call_kwargs)
+
         if self.config.memory:
-            LOGGER.info("\t+ Creating inference memory tracker")
-            self.memory_tracker = MemoryTracker(
-                backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
-            )
             if backend.config.task in TEXT_GENERATION_TASKS:
                 self.run_text_generation_memory_tracking(backend)
             elif backend.config.task in IMAGE_DIFFUSION_TASKS:
@@ -142,10 +145,11 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
             self.report.log_memory()
 
         if self.config.latency:
-            LOGGER.info("\t+ Creating inference latency tracker")
-            self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)
             if backend.config.task in TEXT_GENERATION_TASKS:
-                self.run_text_generation_latency_tracking(backend)
+                if backend.config.name in PER_TOKEN_BACKENDS:
+                    self.run_fine_grained_text_generation_latency_tracking(backend)
+                else:
+                    self.run_text_generation_latency_tracking(backend)
             elif backend.config.task in IMAGE_DIFFUSION_TASKS:
                 self.run_image_diffusion_latency_tracking(backend)
             else:
@@ -155,8 +159,6 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
             self.report.log_throughput()
 
         if self.config.energy:
-            LOGGER.info("\t+ Creating inference energy tracker")
-            self.energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)
             if backend.config.task in TEXT_GENERATION_TASKS:
                 self.run_text_generation_energy_tracking(backend)
             elif backend.config.task in IMAGE_DIFFUSION_TASKS:
@@ -170,7 +172,10 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
 
     ## Memory tracking
     def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running memory tracking")
-        self.memory_tracker.reset()
+        self.memory_tracker = MemoryTracker(
+            backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
+        )
+
         with self.memory_tracker.track():
             _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
@@ -184,7 +189,10 @@ def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]):
 
     def run_image_diffusion_memory_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running memory tracking")
-        self.memory_tracker.reset()
+        self.memory_tracker = MemoryTracker(
+            backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
+        )
+
         with self.memory_tracker.track():
             _ = backend.call(self.call_inputs, self.config.call_kwargs)
 
@@ -192,16 +200,45 @@ def run_image_diffusion_memory_tracking(self, backend: Backend[BackendConfigT]):
 
     def run_inference_memory_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running memory tracking")
-        self.memory_tracker.reset()
+        self.memory_tracker = MemoryTracker(
+            backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
+        )
+
         with self.memory_tracker.track():
             _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         self.report.forward.memory = self.memory_tracker.get_max_memory()
 
     ## Latency tracking
+    def run_fine_grained_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
+        LOGGER.info("\t+ Running fine-grained Text Generation latency tracking")
+        self.logits_processor = LatencyLogitsProcessor(device=backend.config.device, backend=backend.config.name)
+        self.config.generate_kwargs["logits_processor"] = LogitsProcessorList(
+            [self.logits_processor, *self.config.generate_kwargs.get("logits_processor", [])]
+        )
+
+        while self.logits_processor.get_elapsed_time() < self.config.duration:
+            with self.logits_processor.track():
+                _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
+
+        self.report.per_token.latency = self.logits_processor.get_per_token_latency()
+        self.report.prefill.latency = self.logits_processor.get_prefill_latency()
+        self.report.decode.latency = self.logits_processor.get_decode_latency()
+
+        self.report.per_token.throughput = Throughput.from_latency(
+            self.report.per_token.latency, self.text_generation_per_token_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
+        )
+        self.report.prefill.throughput = Throughput.from_latency(
+            self.report.prefill.latency, self.text_generation_prefill_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
+        )
+        self.report.decode.throughput = Throughput.from_latency(
+            self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
+        )
+
     def run_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
-        LOGGER.info("\t+ Running latency tracking")
-        self.latency_tracker.reset()
+        LOGGER.info("\t+ Running Text Generation latency tracking")
+        self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)
+
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
             with self.latency_tracker.track():
                 _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
@@ -212,40 +249,21 @@ def run_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
             self.report.prefill.latency, self.text_generation_prefill_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
         )
 
-        if backend.config.name in PER_TOKEN_BACKENDS:
-            self.logits_processor = LatencyLogitsProcessor(device=backend.config.device, backend=backend.config.name)
-            self.config.generate_kwargs["logits_processor"] = LogitsProcessorList([self.logits_processor])
-            self.logits_processor.reset()
-
-            while self.logits_processor.get_elapsed_time() < self.config.duration:
-                with self.logits_processor.track():
-                    _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
+        self.latency_tracker.reset()
+        while self.latency_tracker.get_elapsed_time() < self.config.duration:
+            with self.latency_tracker.track():
+                _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
+        generate_latency = self.latency_tracker.get_latency()
 
-            self.report.decode.latency = self.logits_processor.get_decode_latency()
-            self.report.per_token.latency = self.logits_processor.get_per_token_latency()
-            self.report.decode.throughput = Throughput.from_latency(
-                self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
-            )
-            self.report.per_token.throughput = Throughput.from_latency(
-                self.report.per_token.latency,
-                self.text_generation_per_token_volume,
-                unit=TEXT_GENERATION_THROUGHPUT_UNIT,
-            )
-        else:
-            self.latency_tracker.reset()
-            while self.latency_tracker.get_elapsed_time() < self.config.duration:
-                with self.latency_tracker.track():
-                    _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
-            generate_latency = self.latency_tracker.get_latency()
-
-            self.report.decode.latency = generate_latency - forward_latency
-            self.report.decode.throughput = Throughput.from_latency(
-                self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
-            )
+        self.report.decode.latency = generate_latency - forward_latency
+        self.report.decode.throughput = Throughput.from_latency(
+            self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
+        )
 
     def run_image_diffusion_latency_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running latency tracking")
-        self.latency_tracker.reset()
+        self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)
+
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
             with self.latency_tracker.track():
                 _ = backend.call(self.call_inputs, self.config.call_kwargs)
@@ -257,7 +275,8 @@ def run_image_diffusion_latency_tracking(self, backend: Backend[BackendConfigT]):
 
     def run_latency_inference_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running latency tracking")
-        self.latency_tracker.reset()
+        self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)
+
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
             with self.latency_tracker.track():
                 _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
@@ -270,7 +289,8 @@ def run_latency_inference_tracking(self, backend: Backend[BackendConfigT]):
     ## Energy tracking
     def run_text_generation_energy_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running energy tracking")
-        self.energy_tracker.reset()
+        self.energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)
+
         with self.energy_tracker.track():
             _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
         forward_energy = self.energy_tracker.get_energy()
@@ -292,7 +312,8 @@ def run_text_generation_energy_tracking(self, backend: Backend[BackendConfigT]):
 
     def run_image_diffusion_energy_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running energy tracking")
-        self.energy_tracker.reset()
+        self.energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)
+
         with self.energy_tracker.track():
             _ = backend.call(self.call_inputs, self.config.call_kwargs)
 
@@ -303,7 +324,8 @@ def run_image_diffusion_energy_tracking(self, backend: Backend[BackendConfigT]):
 
     def run_inference_energy_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running energy tracking")
-        self.energy_tracker.reset()
+        self.energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)
+
         with self.energy_tracker.track():
             _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
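Note: the fine-grained path above hinges on the `transformers` `LogitsProcessor` hook, which `generate` invokes once per sampled token. A minimal standalone sketch of that mechanism, separate from the benchmark's own `LatencyLogitsProcessor` (the model name and CPU-only timing here are illustrative assumptions, not the patch's code):

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList


class TimestampLogitsProcessor(LogitsProcessor):
    """Records a wall-clock timestamp every time generate() samples a token."""

    def __init__(self):
        self.timestamps = []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # called once per generated token, right before sampling
        self.timestamps.append(time.perf_counter())
        return scores  # pass the scores through unmodified


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

processor = TimestampLogitsProcessor()
inputs = tokenizer("Hello", return_tensors="pt")

start = time.perf_counter()
_ = model.generate(**inputs, max_new_tokens=8, logits_processor=LogitsProcessorList([processor]))

prefill = processor.timestamps[0] - start  # time from generate() start to the first sampled token
decode_steps = [t1 - t0 for t0, t1 in zip(processor.timestamps, processor.timestamps[1:])]
print(prefill, decode_steps)
```

This hook is also why the patch merges the tracker's processor with any user-supplied `logits_processor` in `generate_kwargs` instead of silently overwriting it.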
diff --git a/optimum_benchmark/trackers/latency.py b/optimum_benchmark/trackers/latency.py
index 55d51a71..38bc1b99 100644
--- a/optimum_benchmark/trackers/latency.py
+++ b/optimum_benchmark/trackers/latency.py
@@ -218,6 +218,14 @@ def track(self):
 
         self.tok_events: List[Union[float, torch.cuda.Event]] = []
+        if self.device == "cuda" and self.backend == "pytorch":
+            prefill_event = torch.cuda.Event(enable_timing=True)
+            prefill_event.record()
+            self.tok_events.append(prefill_event)
+        else:
+            # on other device/backend combinations, record a CPU timestamp so that
+            # index 0 is always the start of generate, i.e. the start of prefill
+            self.tok_events.append(time.perf_counter())
+
         yield  # this is where generate is called, and for each token, we record an event
 
         self.run_events.append(self.tok_events)
@@ -235,6 +243,18 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
 
         return scores
 
+    def get_prefill_latency(self) -> Latency:
+        if self.device == "cuda" and self.backend == "pytorch":
+            # synchronize the device to make sure all events have been recorded
+            torch.cuda.synchronize()
+            latencies_list = [
+                self.run_events[i][0].elapsed_time(self.run_events[i][1]) / 1e3 for i in range(len(self.run_events))
+            ]
+        else:
+            latencies_list = [(self.run_events[i][1] - self.run_events[i][0]) for i in range(len(self.run_events))]
+
+        return Latency.from_values(latencies_list, unit=LATENCY_UNIT)
+
     def get_per_token_latency(self) -> Latency:
         latencies_list = []
         for tok_events in self.run_events:
@@ -242,12 +262,13 @@ def get_per_token_latency(self) -> Latency:
         latencies_list = []
         for tok_events in self.run_events:
             if self.device == "cuda" and self.backend == "pytorch":
                 # synchronize the device to make sure all events have been recorded
                 torch.cuda.synchronize()
                 latencies_list.extend(
-                    [tok_events[i - 1].elapsed_time(tok_events[i]) / 1e3 for i in range(1, len(tok_events))]
+                    # skip index 0 (the prefill start); consecutive pairs are decode steps
+                    [tok_events[i].elapsed_time(tok_events[i + 1]) / 1e3 for i in range(1, len(tok_events) - 1)]
                 )
             else:
-                latencies_list.extend([(tok_events[i] - tok_events[i - 1]) for i in range(1, len(tok_events))])
+                latencies_list.extend([(tok_events[i + 1] - tok_events[i]) for i in range(1, len(tok_events) - 1)])
 
         return Latency.from_values(latencies_list, unit=LATENCY_UNIT)
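For reference, a minimal standalone sketch of the CUDA-event pattern the tracker uses on `cuda` + `pytorch` (illustrative only; assumes a CUDA device is available): `record()` enqueues the event on the current stream rather than timing on the host, and `elapsed_time()` returns milliseconds that are only valid once the events have completed, hence the `torch.cuda.synchronize()` before reading them.

```python
import torch

assert torch.cuda.is_available(), "this sketch requires a CUDA device"

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

x = torch.randn(1024, 1024, device="cuda")

start.record()  # enqueued on the current stream, not measured on the host
y = x @ x       # asynchronous kernel launch
end.record()

torch.cuda.synchronize()  # wait until both events have actually completed
seconds = start.elapsed_time(end) / 1e3  # elapsed_time() returns milliseconds
print(f"{seconds:.6f}s")
```

The same `/ 1e3` conversion is what keeps the CUDA branches above in seconds, consistent with the `time.perf_counter()` branches.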