@@ -39,8 +39,9 @@ def __init__(
3939
4040 self .enable_caching = enable_caching
4141 self .caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
42- # FIXME: make prefix cache stats conditional on log_stats
4342 self .log_stats = log_stats
43+ # Prefix cache stats are only collected when log_stats is enabled.
44+ self .prefix_cache_stats = PrefixCacheStats () if log_stats else None
4445 # NOTE(woosuk): To avoid frequent block allocation, we preallocate some
4546 # blocks for each request. For example, when a request reaches the end
4647 # of its block table, we preallocate N blocks in advance. This way, we
@@ -79,7 +80,6 @@ def __init__(
7980 # This is only used to track the RUNNING requests, we do not track the
8081 # data for preempted ones.
8182 self .num_cached_block : dict [str , int ] = {}
82- self .prefix_cache_stats = PrefixCacheStats ()
8383
8484 @property
8585 def usage (self ) -> float :
@@ -90,12 +90,14 @@ def usage(self) -> float:
9090 """
9191 return self .block_pool .get_usage ()
9292
93- def make_prefix_cache_stats (self ) -> PrefixCacheStats :
93+ def make_prefix_cache_stats (self ) -> Optional [ PrefixCacheStats ] :
9494 """Get (and reset) the prefix cache stats.
9595
9696 Returns:
97- The current prefix caching stats.
97+ The current prefix caching stats, or None if logging is disabled .
9898 """
99+ if not self .log_stats :
100+ return None
99101 stats = self .prefix_cache_stats
100102 self .prefix_cache_stats = PrefixCacheStats ()
101103 return stats
@@ -125,7 +127,9 @@ def get_computed_blocks(
125127 self .block_size , request )
126128 self .req_to_block_hashes [request .request_id ] = block_hashes
127129
128- self .prefix_cache_stats .requests += 1
130+ if self .log_stats :
131+ assert self .prefix_cache_stats is not None
132+ self .prefix_cache_stats .requests += 1
129133 # When the request requires prompt logprobs, we skip prefix caching.
130134 if request .sampling_params .prompt_logprobs is not None :
131135 return [], 0
@@ -145,8 +149,10 @@ def get_computed_blocks(
145149
146150 computed_blocks = (
147151 self .specialized_manager .find_longest_cache_hit (block_hashes ))
148- self .prefix_cache_stats .queries += len (block_hashes )
149- self .prefix_cache_stats .hits += len (computed_blocks )
152+ if self .log_stats :
153+ assert self .prefix_cache_stats is not None
154+ self .prefix_cache_stats .queries += len (block_hashes )
155+ self .prefix_cache_stats .hits += len (computed_blocks )
150156
151157 if last_block_hash is not None :
152158 # Add back the last block hash if it was removed.
@@ -317,17 +323,19 @@ def free(self, request: Request) -> None:
317323
318324 def reset_prefix_cache (self ) -> bool :
319325 """Reset prefix cache. This function may be used in RLHF
320- flows to invalid prefix caching after the weights are updated,
326+ flows to invalidate prefix caching after the weights are updated,
321327 or used for resetting prefix caching status for benchmarking.
322328
323329 Returns:
324330 bool: True if the prefix cache is successfully reset,
325331 False otherwise.
326332 """
327- if self .block_pool .reset_prefix_cache ():
333+ if not self .block_pool .reset_prefix_cache ():
334+ return False
335+ if self .log_stats :
336+ assert self .prefix_cache_stats is not None
328337 self .prefix_cache_stats .reset = True
329- return True
330- return False
338+ return True
331339
332340 def get_num_common_prefix_blocks (
333341 self ,
0 commit comments