Compute the real prefill latency using the logits processor (#150)
IlyasMoutawwakil authored Mar 13, 2024
1 parent 6e83384 commit d2d1e62
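
On backends that support per-token tracking, the prefill latency is now measured inside the generate() call itself: a logits processor records a timing event for every generated token, plus one event taken just before generation starts, so the gap up to the first token is the real prefill latency rather than the latency of a separate forward() pass. Below is a minimal, self-contained sketch of that mechanism with placeholder names (TimestampLogitsProcessor, gpt2, 8 tokens); it is not the benchmark's own LatencyLogitsProcessor.

import time
from typing import List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList


class TimestampLogitsProcessor(LogitsProcessor):
    """Records a timestamp each time a new token's logits are produced."""

    def __init__(self):
        # In the benchmark, the equivalent "start" event is recorded just before generate() is called.
        self.events: List[float] = [time.perf_counter()]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        self.events.append(time.perf_counter())  # one event per generated token
        return scores  # scores pass through unchanged; the processor only observes


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

processor = TimestampLogitsProcessor()
_ = model.generate(**inputs, max_new_tokens=8, logits_processor=LogitsProcessorList([processor]))

prefill_latency = processor.events[1] - processor.events[0]  # start -> first token
decode_latencies = [b - a for a, b in zip(processor.events[1:], processor.events[2:])]  # per decode step
print(f"prefill: {prefill_latency:.4f}s, decode steps: {len(decode_latencies)}")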
Showing 2 changed files with 89 additions and 49 deletions.
117 changes: 70 additions & 47 deletions optimum_benchmark/benchmarks/inference/benchmark.py
@@ -127,11 +127,14 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
else:
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)

if backend.config.task in TEXT_GENERATION_TASKS:
LOGGER.info("\t+ Additional warmup for Text Generation")
_ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
elif backend.config.task in IMAGE_DIFFUSION_TASKS:
LOGGER.info("\t+ Additional warmup for Image Diffusion")
_ = backend.call(self.call_inputs, self.config.call_kwargs)

if self.config.memory:
LOGGER.info("\t+ Creating inference memory tracker")
self.memory_tracker = MemoryTracker(
backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
)
if backend.config.task in TEXT_GENERATION_TASKS:
self.run_text_generation_memory_tracking(backend)
elif backend.config.task in IMAGE_DIFFUSION_TASKS:
@@ -142,10 +145,11 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
self.report.log_memory()

if self.config.latency:
LOGGER.info("\t+ Creating inference latency tracker")
self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)
if backend.config.task in TEXT_GENERATION_TASKS:
self.run_text_generation_latency_tracking(backend)
if backend.config.name in PER_TOKEN_BACKENDS:
self.run_fine_grained_text_generation_latency_tracking(backend)
else:
self.run_text_generation_latency_tracking(backend)
elif backend.config.task in IMAGE_DIFFUSION_TASKS:
self.run_image_diffusion_latency_tracking(backend)
else:
@@ -155,8 +159,6 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
self.report.log_throughput()

if self.config.energy:
LOGGER.info("\t+ Creating inference energy tracker")
self.energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)
if backend.config.task in TEXT_GENERATION_TASKS:
self.run_text_generation_energy_tracking(backend)
elif backend.config.task in IMAGE_DIFFUSION_TASKS:
@@ -170,7 +172,11 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
## Memory tracking
def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running memory tracking")
self.memory_tracker = MemoryTracker(
backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
)
self.memory_tracker.reset()

with self.memory_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)

@@ -184,24 +190,56 @@ def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]):

def run_image_diffusion_memory_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running memory tracking")
self.memory_tracker.reset()
self.memory_tracker = MemoryTracker(
backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
)

with self.memory_tracker.track():
_ = backend.call(self.call_inputs, self.config.call_kwargs)

self.report.call.memory = self.memory_tracker.get_max_memory()

def run_inference_memory_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running memory tracking")
self.memory_tracker.reset()
self.memory_tracker = MemoryTracker(
backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
)

with self.memory_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)

self.report.forward.memory = self.memory_tracker.get_max_memory()

## Latency tracking
def run_fine_grained_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running fine-grained Text Generation latency tracking")
self.logits_processor = LatencyLogitsProcessor(device=backend.config.device, backend=backend.config.name)
self.config.generate_kwargs["logits_processor"] = LogitsProcessorList(
[self.logits_processor, *self.config.generate_kwargs.get("logits_processor", [])]
)

while self.logits_processor.get_elapsed_time() < self.config.duration:
with self.logits_processor.track():
_ = backend.generate(self.generate_inputs, self.config.generate_kwargs)

self.report.per_token.latency = self.logits_processor.get_per_token_latency()
self.report.prefill.latency = self.logits_processor.get_prefill_latency()
self.report.decode.latency = self.logits_processor.get_decode_latency()

self.report.per_token.throughput = Throughput.from_latency(
self.report.per_token.latency, self.text_generation_per_token_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)
self.report.prefill.throughput = Throughput.from_latency(
self.report.prefill.latency, self.text_generation_prefill_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)
self.report.decode.throughput = Throughput.from_latency(
self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)

def run_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running latency tracking")
self.latency_tracker.reset()
LOGGER.info("\t+ Running Text Generation latency tracking")
self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)

while self.latency_tracker.get_elapsed_time() < self.config.duration:
with self.latency_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
@@ -212,40 +250,21 @@ def run_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
self.report.prefill.latency, self.text_generation_prefill_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)

if backend.config.name in PER_TOKEN_BACKENDS:
self.logits_processor = LatencyLogitsProcessor(device=backend.config.device, backend=backend.config.name)
self.config.generate_kwargs["logits_processor"] = LogitsProcessorList([self.logits_processor])
self.logits_processor.reset()

while self.logits_processor.get_elapsed_time() < self.config.duration:
with self.logits_processor.track():
_ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
self.latency_tracker.reset()
while self.latency_tracker.get_elapsed_time() < self.config.duration:
with self.latency_tracker.track():
_ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
generate_latency = self.latency_tracker.get_latency()

self.report.decode.latency = self.logits_processor.get_decode_latency()
self.report.per_token.latency = self.logits_processor.get_per_token_latency()
self.report.decode.throughput = Throughput.from_latency(
self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)
self.report.per_token.throughput = Throughput.from_latency(
self.report.per_token.latency,
self.text_generation_per_token_volume,
unit=TEXT_GENERATION_THROUGHPUT_UNIT,
)
else:
self.latency_tracker.reset()
while self.latency_tracker.get_elapsed_time() < self.config.duration:
with self.latency_tracker.track():
_ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
generate_latency = self.latency_tracker.get_latency()

self.report.decode.latency = generate_latency - forward_latency
self.report.decode.throughput = Throughput.from_latency(
self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)
self.report.decode.latency = generate_latency - forward_latency
self.report.decode.throughput = Throughput.from_latency(
self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
)

def run_image_diffusion_latency_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running latency tracking")
self.latency_tracker.reset()
self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)

while self.latency_tracker.get_elapsed_time() < self.config.duration:
with self.latency_tracker.track():
_ = backend.call(self.call_inputs, self.config.call_kwargs)
@@ -257,7 +276,8 @@ def run_image_diffusion_latency_tracking(self, backend: Backend[BackendConfigT]):

def run_latency_inference_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running latency tracking")
self.latency_tracker.reset()
self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)

while self.latency_tracker.get_elapsed_time() < self.config.duration:
with self.latency_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
@@ -270,7 +290,8 @@ def run_latency_inference_tracking(self, backend: Backend[BackendConfigT]):
## Energy tracking
def run_text_generation_energy_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running energy tracking")
self.energy_tracker.reset()
self.energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)

with self.energy_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
forward_energy = self.energy_tracker.get_energy()
@@ -292,7 +313,8 @@ def run_text_generation_energy_tracking(self, backend: Backend[BackendConfigT]):

def run_image_diffusion_energy_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running energy tracking")
self.energy_tracker.reset()
self.energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)

with self.energy_tracker.track():
_ = backend.call(self.call_inputs, self.config.call_kwargs)

@@ -303,7 +325,8 @@ def run_image_diffusion_energy_tracking(self, backend: Backend[BackendConfigT]):

def run_inference_energy_tracking(self, backend: Backend[BackendConfigT]):
LOGGER.info("\t+ Running energy tracking")
self.energy_tracker.reset()
self.energy_tracker = EnergyTracker(device=backend.config.device, device_ids=backend.config.device_ids)

with self.energy_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)

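For backends outside PER_TOKEN_BACKENDS, the decode latency is still inferred by subtracting the standalone forward() latency from the full generate() latency, as the fallback branch above shows; only the per-token path measures prefill and decode inside a single generate() call. A toy comparison of the two paths, using invented numbers purely for illustration:

# Illustrative only: hypothetical timings, not measured values.

# Fallback path (non per-token backends): time forward() and generate() separately
# and infer decode latency by subtraction.
forward_latency = 0.032    # seconds for one prefill-only forward() pass
generate_latency = 0.540   # seconds for one full generate() call (prefill + decode)
decode_latency_estimate = generate_latency - forward_latency  # ~0.508 s

# Per-token path: events recorded inside a single generate() call.
# events[0] is taken right before generation, events[1:] once per generated token.
events = [0.000, 0.033, 0.065, 0.097, 0.129, 0.161]
prefill_latency = events[1] - events[0]   # 0.033 s, the "real" prefill latency
decode_latency = events[-1] - events[1]   # 0.128 s, measured rather than inferred
per_token_latencies = [b - a for a, b in zip(events[1:-1], events[2:])]  # one value per decode step

print(decode_latency_estimate, prefill_latency, decode_latency, per_token_latencies)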
21 changes: 19 additions & 2 deletions optimum_benchmark/trackers/latency.py
@@ -218,6 +218,11 @@ def track(self):

self.tok_events: List[Union[float, torch.cuda.Event]] = []

if self.device == "cuda" and self.backend == "pytorch":
prefill_event = torch.cuda.Event(enable_timing=True)
prefill_event.record()
self.tok_events.append(prefill_event)

yield # this is where generate is called, and for each token, we record an event

self.run_events.append(self.tok_events)
@@ -235,17 +240,29 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):

return scores

def get_prefill_latency(self) -> Latency:
if self.device == "cuda" and self.backend == "pytorch":
# synchronize the device to make sure all events have been recorded
torch.cuda.synchronize()
latencies_list = [
self.run_events[i][0].elapsed_time(self.run_events[i][1]) / 1e3 for i in range(len(self.run_events))
]
else:
latencies_list = [(self.run_events[i][1] - self.run_events[i][0]) for i in range(len(self.run_events))]

return Latency.from_values(latencies_list, unit=LATENCY_UNIT)

def get_per_token_latency(self) -> Latency:
latencies_list = []
for tok_events in self.run_events:
if self.device == "cuda" and self.backend == "pytorch":
# synchronize the device to make sure all events have been recorded
torch.cuda.synchronize()
latencies_list.extend(
[tok_events[i - 1].elapsed_time(tok_events[i]) / 1e3 for i in range(1, len(tok_events))]
[tok_events[i].elapsed_time(tok_events[i + 1]) / 1e3 for i in range(1, len(tok_events) - 1)]
)
else:
latencies_list.extend([(tok_events[i] - tok_events[i - 1]) for i in range(1, len(tok_events))])
latencies_list.extend([(tok_events[i + 1] - tok_events[i]) for i in range(1, len(tok_events) - 1)])

return Latency.from_values(latencies_list, unit=LATENCY_UNIT)

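The CUDA branch above relies on torch.cuda events rather than wall-clock timestamps, since CUDA kernel launches are asynchronous. A standalone sketch of that pattern (a matrix multiply stands in for the prefill forward pass; this is not the tracker's code):

import torch

assert torch.cuda.is_available()  # this pattern only applies on CUDA devices

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

x = torch.randn(4096, 4096, device="cuda")

start.record()
y = x @ x  # stands in for the prefill forward pass
end.record()

torch.cuda.synchronize()  # make sure both events have completed before reading them
latency_s = start.elapsed_time(end) / 1e3  # elapsed_time() returns milliseconds
print(f"latency: {latency_s:.4f}s")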
