diff --git a/examples/diffusers/quantization/diffusion_trt.py b/examples/diffusers/quantization/diffusion_trt.py
index ad71fd354..508951565 100644
--- a/examples/diffusers/quantization/diffusion_trt.py
+++ b/examples/diffusers/quantization/diffusion_trt.py
@@ -49,6 +49,7 @@
 }
 
 
+@torch.inference_mode()
 def generate_image(pipe, prompt, image_name):
     seed = 42
     image = pipe(
@@ -61,56 +62,57 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")
 
 
-def benchmark_model(
-    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype=torch.float16
+@torch.inference_mode()
+def benchmark_backbone_standalone(
+    pipe,
+    num_warmup=10,
+    num_benchmark=100,
+    model_name="flux-dev",
 ):
-    """Benchmark the backbone model inference time."""
+    """Benchmark the backbone model directly without running the full pipeline."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
 
-    backbone_times = []
+    # Generate dummy inputs for the backbone
+    dummy_inputs, _, _ = generate_dummy_inputs_and_dynamic_axes_and_shapes(model_name, backbone)
+
+    # Extract the dict from the tuple and move to cuda
+    dummy_inputs_dict = {
+        k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in dummy_inputs[0].items()
+    }
+
+    # Warmup
+    print(f"Warming up: {num_warmup} iterations")
+    for _ in tqdm(range(num_warmup), desc="Warmup"):
+        _ = backbone(**dummy_inputs_dict)
+
+    # Benchmark
+    torch.cuda.synchronize()
 
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)
 
-    def forward_pre_hook(_module, _input):
+    print(f"Benchmarking: {num_benchmark} iterations")
+    times = []
+    for _ in tqdm(range(num_benchmark), desc="Benchmark"):
+        torch.cuda.profiler.cudart().cudaProfilerStart()
         start_event.record()
-
-    def forward_hook(_module, _input, _output):
+        _ = backbone(**dummy_inputs_dict)
         end_event.record()
         torch.cuda.synchronize()
-        backbone_times.append(start_event.elapsed_time(end_event))
-
-    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
-    post_handle = backbone.register_forward_hook(forward_hook)
-
-    try:
-        print(f"Starting warmup: {num_warmup} runs")
-        for _ in tqdm(range(num_warmup), desc="Warmup"):
-            with torch.amp.autocast("cuda", dtype=model_dtype):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-
-        backbone_times.clear()
-
-        print(f"Starting benchmark: {num_runs} runs")
-        for _ in tqdm(range(num_runs), desc="Benchmark"):
-            with torch.amp.autocast("cuda", dtype=model_dtype):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-    finally:
-        pre_handle.remove()
-        post_handle.remove()
-
-    total_backbone_time = sum(backbone_times)
-    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
-    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
+        torch.cuda.profiler.cudart().cudaProfilerStop()
+        times.append(start_event.elapsed_time(end_event))
+
+    avg_latency = sum(times) / len(times)
+    times = sorted(times)
+    p50 = times[len(times) // 2]
+    p95 = times[int(len(times) * 0.95)]
+    p99 = times[int(len(times) * 0.99)]
+
+    print("\nBackbone-only inference latency:")
+    print(f"  Average: {avg_latency:.2f} ms")
+    print(f"  P50: {p50:.2f} ms")
+    print(f"  P95: {p95:.2f} ms")
+    print(f"  P99: {p99:.2f} ms")
+    return avg_latency
 
 
@@ -196,7 +198,12 @@ def main():
         pipe.to("cuda")
 
     if args.benchmark:
-        benchmark_model(pipe, args.prompt, model_dtype=model_dtype)
+        benchmark_backbone_standalone(
+            pipe,
+            num_warmup=10,
+            num_benchmark=100,
+            model_name=args.model,
+        )
 
     if not args.skip_image:
         generate_image(pipe, args.prompt, image_name)
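
Note: the measurement pattern introduced by benchmark_backbone_standalone (CUDA events around each forward pass, explicit synchronization, then percentile reporting) can be sanity-checked in isolation. The sketch below is illustrative only and is not part of this patch; the time_cuda_module helper, the toy nn.Linear workload, and the tensor shapes are placeholders, not code from the repository.

# Standalone sketch of the CUDA-event timing and percentile reporting pattern
# used above. Illustrative only; the toy workload is a placeholder.
import torch


@torch.inference_mode()
def time_cuda_module(module, inputs, num_warmup=10, num_benchmark=100):
    # Warmup iterations stabilize clocks and populate caches before timing.
    for _ in range(num_warmup):
        _ = module(*inputs)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(num_benchmark):
        start_event.record()
        _ = module(*inputs)
        end_event.record()
        # elapsed_time() is only valid once both events have completed.
        torch.cuda.synchronize()
        times.append(start_event.elapsed_time(end_event))  # milliseconds

    times.sort()
    print(f"Average: {sum(times) / len(times):.2f} ms")
    print(f"P50: {times[len(times) // 2]:.2f} ms")
    print(f"P95: {times[int(len(times) * 0.95)]:.2f} ms")


if __name__ == "__main__":
    module = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.float16)
    x = torch.randn(64, 4096, device="cuda", dtype=torch.float16)
    time_cuda_module(module, (x,))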