
Commit ba3d736: Further logging
1 parent f1454d9

File tree: 3 files changed, +46 -14 lines changed

src/transformers/generation/continuous_batching/cache.py
13 additions & 7 deletions

@@ -208,8 +208,8 @@ class PagedAttentionMemoryHandler:
     _activation_dtype = torch.bfloat16
     _activation_safety_factor = 2
     _input_dtype = torch.int32
-    _upper_bound_max_batch_tokens = 2048
-    _upper_bound_num_blocks = 16384
+    _upper_bound_max_batch_tokens = 256
+    _upper_bound_num_blocks = 4096

     def __init__(
         self,
@@ -271,14 +271,20 @@ def compute_num_blocks_and_max_batch_tokens(
         m: float = 0.1,
     ) -> tuple[int, int]:
         cache_memory = self.get_available_memory(max_memory_percent)
+        logger.info(f"Cache memory: {cache_memory}")
+
+        # Compute memory footprints  # TODO: check and explain better
+        mem_per_activation_token = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor
+        mem_per_cache_token = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize
+        mem_per_input_token = 8 * m * self._input_dtype.itemsize
+        logger.info(f"Memory per activation token: {mem_per_activation_token}")
+        logger.info(f"Memory per cache token: {mem_per_cache_token}")
+        logger.info(f"Memory per input token: {mem_per_input_token}")

         # Compute second-degree polynomial coefficients
         a = m * self._activation_dtype.itemsize
-        b = 8 * m * self._input_dtype.itemsize
-        b += 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize
-        c = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor
-        c += 2 * self._input_dtype.itemsize
-        c -= cache_memory
+        b = mem_per_input_token + mem_per_cache_token
+        c = mem_per_activation_token + 2 * self._input_dtype.itemsize - cache_memory

         # Compute discriminant and greatest solution
         discriminant = b**2 - 4 * a * c
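
Note on the hunk above: the refactor does not change the sizing math, it only names the per-token memory footprints before folding them into the quadratic a*T^2 + b*T + c = 0 that the handler solves for the number of tokens T fitting in the available cache memory (while the upper bounds are tightened to 256 batch tokens and 4096 blocks). Below is a minimal standalone sketch of that calculation; the model dimensions, the 8 GiB budget, and the greatest-root step after the `discriminant` line are illustrative assumptions, not values from this commit:

import math

# Illustrative model dimensions and memory budget (assumptions, not from the commit).
num_heads, head_dim, num_layers = 32, 128, 32
hidden_size, vocab_size = 4096, 32000
activation_itemsize = 2   # bfloat16
cache_itemsize = 2        # bfloat16 KV cache
input_itemsize = 4        # int32
safety_factor = 2
m = 0.1
cache_memory = 8 * 1024**3  # pretend 8 GiB is left for the cache

# Same per-token footprints as in the diff above.
mem_per_activation_token = activation_itemsize * (hidden_size + vocab_size) * safety_factor
mem_per_cache_token = 2 * num_heads * head_dim * num_layers * cache_itemsize
mem_per_input_token = 8 * m * input_itemsize

# Second-degree polynomial in the token budget T: a*T**2 + b*T + c = 0.
a = m * activation_itemsize
b = mem_per_input_token + mem_per_cache_token
c = mem_per_activation_token + 2 * input_itemsize - cache_memory

discriminant = b**2 - 4 * a * c
greatest_root = (-b + math.sqrt(discriminant)) / (2 * a)  # assumed continuation of the method
print(f"~{int(greatest_root):,} cache tokens fit in the assumed 8 GiB budget")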

src/transformers/generation/continuous_batching/continuous_api.py
7 additions & 5 deletions

@@ -262,8 +262,10 @@ def prepare_next_batch(self):
             self.max_seqlen_k = max(self.max_seqlen_k, key_length)
             state.position_offset += query_length

-        logger.info(
-            f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. cum KV: {cumulative_seqlens_k[-1]}, free blocks: {self.cache.get_num_free_blocks()}"
+        logger.debug(
+            f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, "
+            f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. "
+            f"cum KV: {cumulative_seqlens_k[-1]}, free blocks: {self.cache.get_num_free_blocks()}"
         )
         self._build_tensors(
             input_ids,
@@ -666,7 +668,7 @@ def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor):
             torch.cuda.synchronize()
         batch_processor.prepare_next_batch()
         device, total, reserved, allocated = get_device_and_memory_breakdown()
-        logger.info(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}")
+        logger.debug(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}")
         if torch.cuda.is_available() and self.use_cuda_graph:
             if self.current_batch == 0:
                 self.warmup(batch_processor)
@@ -780,8 +782,8 @@ def generate_batch(
         """
         if not inputs:
             return []
-        if logger.getEffectiveLevel() <= logging.INFO:
-            logger.warning("Progress bar is disabled when logger level is less than INFO")
+        if logger.getEffectiveLevel() <= logging.DEBUG:
+            logger.warning("Progress bar is disabled when logger level is less than DEBUG")
             progress_bar = False

         # Initialize manager with the batch inputs
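
Taken together with the logger level pinned to INFO in core.py below, these hunks demote the per-batch scheduling summary and the per-step memory breakdown from INFO to DEBUG, so the default output carries only per-request summaries and the progress bar stays enabled unless the logger is lowered to DEBUG. A minimal sketch of how a caller could opt back into the verbose per-step output; the basicConfig call is just one possible way to attach a handler:

import logging

# Attach a handler and lower the shared continuous-batching logger to DEBUG.
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("ContinuousBatchingLogger").setLevel(logging.DEBUG)
# Per generate_batch above, the progress bar is then disabled in favor of the log lines.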

src/transformers/generation/continuous_batching/core.py
26 additions & 2 deletions

@@ -11,6 +11,7 @@

 # We centralize the logger here to coordinate between logging and progress bar
 logger = logging.getLogger("ContinuousBatchingLogger")
+logger.setLevel(logging.INFO)


 @staticmethod
@@ -102,12 +103,35 @@ class RequestState:
     static_outputs: list[int] = field(default_factory=list)  # Generated tokens
     allocated_blocks: list[int] = field(default_factory=list)  # Block IDs allocated to the request
     position_offset: int = 0  # Current position in the sequence for position_ids
-    status: RequestStatus = RequestStatus.PENDING  # Status of the request
+    _status: RequestStatus = RequestStatus.PENDING  # Status of the request, hidden behind a property
     max_new_tokens: int = 20  # Maximum number of new tokens to generate
     eos_token_id: int = -1  # ID of the end-of-sequence token
     created_time: float = field(default_factory=time.time)  # Time the request was created
     error: Optional[str] = None  # Error message if the request failed
     next_token: Optional[str] = None  # Next token to be generated
+    lifespan: tuple[float, float] = (-1, -1)  # (time request was no longer pending, time request finished)
+
+    @property
+    def status(self) -> RequestStatus:
+        return self._status
+
+    @status.setter
+    def status(self, value: RequestStatus):
+        if self._status == RequestStatus.PENDING:
+            self.lifespan = (time.time(), -1)
+        elif value == RequestStatus.FINISHED:
+            self.lifespan = (self.lifespan[0], time.time())
+            self.log_end_of_request()
+        self._status = value
+
+    def log_end_of_request(self):
+        prefill_len = len(self.full_prompt_ids)
+        decode_len = self.generated_len()
+        start_time = self.lifespan[0] - self.created_time
+        end_time = self.lifespan[1] - self.created_time
+        logger.info(
+            f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }"
+        )

     def current_len(self) -> int:
         """Get the current length of the sequence (prompt + generated tokens)."""
@@ -148,7 +172,7 @@ def update_with_token(self, token_id: int) -> bool:
     def __repr__(self):
         msg = [
             f"request_id={self.request_id}",
-            f"status={self.status}",
+            f"status={self._status}",
             f"out_tokens={self.generated_len()}",
             f"query_length={len(self.prompt_ids)}",
             f"remaining_tokens={len(self.remaining_prompt_ids)}",
