From 85647bcf82fd7ea7d0dfe33750934509ca8a8ff6 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 10 Dec 2024 19:33:35 -0800
Subject: [PATCH 1/7] add timing statistics for tracking

Signed-off-by: youkaichao
---
 vllm/forward_context.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 41 insertions(+), 18 deletions(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index cd136f43c0c57..24b2b0eac4a0c 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -1,9 +1,11 @@
 import time
-from collections import Counter
+from collections import Counter, defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, Optional
 
+import torch
+
 import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -13,7 +15,9 @@
 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
 batchsize_counter: Counter = Counter()
 last_logging_time: float = 0
+forward_start_time: float = 0
 batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
+batchsize_forward_time: defaultdict = defaultdict(list)
 
 
 @dataclass
@@ -40,23 +44,10 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
     """
-    global track_batchsize, batchsize_counter
-    global last_logging_time, batchsize_logging_interval
-    if track_batchsize and context is not None:
-        if hasattr(context, "num_prefill_tokens"):
-            # for v0 attention backends
-            batchsize = context.num_prefill_tokens + context.num_decode_tokens
-        else:
-            # for v1 attention backends
-            batchsize = context.num_input_tokens
-        batchsize_counter[batchsize] += 1
-        if time.monotonic() - last_logging_time > batchsize_logging_interval:
-            last_logging_time = time.monotonic()
-            sorted_data = sorted(batchsize_counter.items(),
-                                 key=lambda x: x[1],
-                                 reverse=True)
-            logger.info("Batchsize distribution (batchsize, count): %s",
-                        sorted_data)
+    global forward_start_time
+    need_to_track_batchsize = track_batchsize and context is not None
+    if need_to_track_batchsize:
+        forward_start_time = time.monotonic()
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
@@ -66,4 +57,36 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
     try:
         yield
     finally:
+        global batchsize_counter
+        global last_logging_time, batchsize_logging_interval
+        if need_to_track_batchsize:
+            if hasattr(context, "num_prefill_tokens"):
+                # for v0 attention backends
+                batchsize = context.num_prefill_tokens + \
+                    context.num_decode_tokens
+            else:
+                # for v1 attention backends
+                batchsize = context.num_input_tokens
+            batchsize_counter[batchsize] += 1
+            # we use synchronous scheduling right now,
+            # adding a sync point here should not affect
+            # scheduling of the next batch
+            torch.cuda.synchronize()
+            now = time.monotonic()
+            batchsize_forward_time[batchsize].append(now - forward_start_time)
+            if now - last_logging_time > batchsize_logging_interval:
+                last_logging_time = now
+                sorted_by_count = sorted(batchsize_counter.items(),
+                                         key=lambda x: x[1],
+                                         reverse=True)
+                logger.info("Batchsize distribution (batchsize, count): %s",
+                            sorted_by_count)
+                forward_stats = []
+                for bs, _ in sorted_by_count:
+                    times = batchsize_forward_time[bs]
+                    forward_stats.append(
+                        (bs, len(times), torch.quantile(times, q=0.5).item()))
+                logger.info(("Batchsize forward time stats "
+                             "(batchsize, count, median_time): %s"),
+                            forward_stats)
         _forward_context = prev_context
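For context, patch 1/7 derives the batch size from whichever attention metadata object is active: v0 backends expose num_prefill_tokens / num_decode_tokens, while v1 backends expose num_input_tokens. Below is a minimal standalone sketch of that dispatch; the two metadata classes are illustrative stand-ins, and only the attribute names come from the patch.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the real v0 / v1 attention-metadata objects.
@dataclass
class V0Metadata:
    num_prefill_tokens: int
    num_decode_tokens: int


@dataclass
class V1Metadata:
    num_input_tokens: int


def infer_batchsize(context) -> int:
    """Mirror the hasattr-based dispatch used in the patch."""
    if hasattr(context, "num_prefill_tokens"):
        # v0 attention backends
        return context.num_prefill_tokens + context.num_decode_tokens
    # v1 attention backends
    return context.num_input_tokens


assert infer_batchsize(V0Metadata(3, 5)) == 8
assert infer_batchsize(V1Metadata(8)) == 8
```
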
From c1d9e2446bd5877629f90dcb61055ceb6b41a0fd Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 10 Dec 2024 19:37:31 -0800
Subject: [PATCH 2/7] fix

Signed-off-by: youkaichao
---
 vllm/forward_context.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 24b2b0eac4a0c..b458e19861742 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -84,8 +84,9 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
                 forward_stats = []
                 for bs, _ in sorted_by_count:
                     times = batchsize_forward_time[bs]
-                    forward_stats.append(
-                        (bs, len(times), torch.quantile(times, q=0.5).item()))
+                    forward_stats.append((bs, len(times),
+                                          torch.quantile(torch.tensor(times),
+                                                         q=0.5).item()))
                 logger.info(("Batchsize forward time stats "
                              "(batchsize, count, median_time): %s"),
                             forward_stats)

From 6a6fbfd2de33dc1537a2ef73635ec934c3255827 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 10 Dec 2024 19:44:05 -0800
Subject: [PATCH 3/7] merge logs

Signed-off-by: youkaichao
---
 vllm/forward_context.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index b458e19861742..204b1ec56eab9 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -1,5 +1,5 @@
 import time
-from collections import Counter, defaultdict
+from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, Optional
@@ -13,7 +13,6 @@
 logger = init_logger(__name__)
 
 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
-batchsize_counter: Counter = Counter()
 last_logging_time: float = 0
 forward_start_time: float = 0
 batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
@@ -67,27 +66,26 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
             else:
                 # for v1 attention backends
                 batchsize = context.num_input_tokens
-            batchsize_counter[batchsize] += 1
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
             torch.cuda.synchronize()
             now = time.monotonic()
-            batchsize_forward_time[batchsize].append(now - forward_start_time)
+            # time measurement is in milliseconds
+            batchsize_forward_time[batchsize].append(
+                (now - forward_start_time) * 1000)
             if now - last_logging_time > batchsize_logging_interval:
                 last_logging_time = now
-                sorted_by_count = sorted(batchsize_counter.items(),
-                                         key=lambda x: x[1],
-                                         reverse=True)
-                logger.info("Batchsize distribution (batchsize, count): %s",
-                            sorted_by_count)
                 forward_stats = []
-                for bs, _ in sorted_by_count:
-                    times = batchsize_forward_time[bs]
+                for bs, times in batchsize_forward_time.items():
+                    if len(times) <= 1:
+                        # can be cudagraph / profiling run
+                        continue
                     forward_stats.append((bs, len(times),
                                           torch.quantile(torch.tensor(times),
                                                          q=0.5).item()))
+                forward_stats.sort(key=lambda x: x[1], reverse=True)
                 logger.info(("Batchsize forward time stats "
-                             "(batchsize, count, median_time): %s"),
+                             "(batchsize, count, median_time(ms)): %s"),
                             forward_stats)
         _forward_context = prev_context

From 34d4ad644def839f260a0936b5b816485b8ade56 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 10 Dec 2024 19:49:30 -0800
Subject: [PATCH 4/7] fix

Signed-off-by: youkaichao
---
 vllm/forward_context.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 204b1ec56eab9..ee6bdb74c14a2 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -81,9 +81,9 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
                     if len(times) <= 1:
                         # can be cudagraph / profiling run
                         continue
-                    forward_stats.append((bs, len(times),
-                                          torch.quantile(torch.tensor(times),
-                                                         q=0.5).item()))
+                    medium = torch.quantile(torch.tensor(times), q=0.5).item()
+                    medium = f"{medium:.2f}"
+                    forward_stats.append((bs, len(times), medium))
                 forward_stats.sort(key=lambda x: x[1], reverse=True)
                 logger.info(("Batchsize forward time stats "
                              "(batchsize, count, median_time(ms)): %s"),

From f416469d8a881b83d76974f972585de979abb03a Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 10 Dec 2024 19:51:08 -0800
Subject: [PATCH 5/7] use round

Signed-off-by: youkaichao
---
 vllm/forward_context.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index ee6bdb74c14a2..71cd44d98c836 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -82,7 +82,7 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
                         # can be cudagraph / profiling run
                         continue
                     medium = torch.quantile(torch.tensor(times), q=0.5).item()
-                    medium = f"{medium:.2f}"
+                    medium = round(medium, 2)
                     forward_stats.append((bs, len(times), medium))
                 forward_stats.sort(key=lambda x: x[1], reverse=True)
                 logger.info(("Batchsize forward time stats "

From f7016c568bfa888b83c00161cd130756a8c34352 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 10 Dec 2024 19:54:51 -0800
Subject: [PATCH 6/7] remove empty logging

Signed-off-by: youkaichao
---
 vllm/forward_context.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 71cd44d98c836..491869e2ff1c5 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -85,7 +85,8 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
                     medium = round(medium, 2)
                     forward_stats.append((bs, len(times), medium))
                 forward_stats.sort(key=lambda x: x[1], reverse=True)
-                logger.info(("Batchsize forward time stats "
-                             "(batchsize, count, median_time(ms)): %s"),
-                            forward_stats)
+                if forward_stats:
+                    logger.info(("Batchsize forward time stats "
+                                 "(batchsize, count, median_time(ms)): %s"),
+                                forward_stats)
         _forward_context = prev_context

From b6590ed67092346bd17cc306b3dbc1e51031a5d0 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sat, 14 Dec 2024 14:32:42 -0800
Subject: [PATCH 7/7] use perf_counter

Signed-off-by: youkaichao
---
 vllm/forward_context.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 491869e2ff1c5..7f56575279e9b 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -46,7 +46,7 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
     global forward_start_time
     need_to_track_batchsize = track_batchsize and context is not None
     if need_to_track_batchsize:
-        forward_start_time = time.monotonic()
+        forward_start_time = time.perf_counter()
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
@@ -70,7 +70,7 @@ def set_forward_context(context: Any, vllm_config: VllmConfig):
             # adding a sync point here should not affect
             # scheduling of the next batch
             torch.cuda.synchronize()
-            now = time.monotonic()
+            now = time.perf_counter()
             # time measurement is in milliseconds
             batchsize_forward_time[batchsize].append(
                 (now - forward_start_time) * 1000)
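Taken together, the series converges on per-batchsize forward-time tracking: elapsed times are recorded in milliseconds with time.perf_counter(), batch sizes with at most one sample (likely cudagraph capture or profiling runs) are skipped, and the per-batchsize median is logged at most once per VLLM_LOG_BATCHSIZE_INTERVAL seconds. Below is a minimal standalone sketch of the aggregation step only; the dummy forward pass, the function names, and the sample values are illustrative assumptions, and the interval-gated logger.info call of the real code is omitted.

```python
import time
from collections import defaultdict

import torch

# Per-batchsize forward times in milliseconds, keyed by batch size.
batchsize_forward_time: defaultdict = defaultdict(list)


def record(batchsize: int, start: float) -> None:
    """Record one forward pass, measured with time.perf_counter()."""
    batchsize_forward_time[batchsize].append(
        (time.perf_counter() - start) * 1000)


def forward_time_stats() -> list:
    """Return (batchsize, count, median_time(ms)) tuples, most frequent first."""
    stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # a single sample is likely a cudagraph capture / profiling run
            continue
        median = round(torch.quantile(torch.tensor(times), q=0.5).item(), 2)
        stats.append((bs, len(times), median))
    stats.sort(key=lambda x: x[1], reverse=True)
    return stats


# Illustrative usage: time two dummy "forward passes" for batch size 8.
for _ in range(2):
    start = time.perf_counter()
    time.sleep(0.01)  # stand-in for the model forward
    record(8, start)

print(forward_time_stats())  # e.g. [(8, 2, 10.05)]
```
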