@@ -18,6 +18,7 @@ def __init__(
1818 self ,
1919 kv_cache_spec : KVCacheSpec ,
2020 block_pool : BlockPool ,
21+ use_eagle : bool ,
2122 ) -> None :
2223 """
2324 Initializes the SpecializedManager.
@@ -30,12 +31,17 @@ def __init__(
3031 self .kv_cache_spec = kv_cache_spec
3132 self .block_pool = block_pool
3233
34+ # Needs special handling for find_longest_cache_hit if eagle is enabled
35+ self .use_eagle = use_eagle
36+
3337 @abstractmethod
3438 def find_longest_cache_hit (
3539 self , block_hashes : list [BlockHashType ]) -> list [KVCacheBlock ]:
3640 """
3741 Get the longest cache hit prefix of the blocks. If no cache hit is
38- found, return an empty list.
42+ found, return an empty list. If eagle is enabled, drop the last matched
43+ block to force recomputation of that block and obtain the hidden
44+ states required by the eagle drafting head.
3945
4046 Args:
4147 block_hashes: The block hashes of the request.
@@ -79,6 +85,8 @@ def find_longest_cache_hit(
7985 computed_blocks .append (cached_block )
8086 else :
8187 break
88+ if self .use_eagle and len (computed_blocks ) > 0 :
89+ computed_blocks .pop ()
8290 return computed_blocks
8391
8492 def remove_skipped_blocks (self , blocks : list [KVCacheBlock ],
@@ -89,14 +97,20 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
8997
9098class SlidingWindowManager (SpecializedManager ):
9199
92- def __init__ (self , kv_cache_spec : SlidingWindowSpec ,
93- block_pool : BlockPool ):
94- super ().__init__ (kv_cache_spec , block_pool )
100+ def __init__ (self , kv_cache_spec : SlidingWindowSpec , block_pool : BlockPool ,
101+ use_eagle : bool ):
102+ super ().__init__ (kv_cache_spec , block_pool , use_eagle )
95103 self .sliding_window = kv_cache_spec .sliding_window
96104 # The number of contiguous blocks needed for prefix cache hit.
97105 # -1 since the input token itself is also included in the window
98106 self .sliding_window_contiguous_blocks = cdiv (
99107 (kv_cache_spec .sliding_window - 1 ), self .block_size )
108+ if self .use_eagle :
109+ # Need to drop the last matched block if eagle is enabled. For
110+ # sliding window layers, we achieve this by increasing the number of
111+ # contiguous blocks needed for a prefix cache hit by one and dropping
112+ # the last matched block.
113+ self .sliding_window_contiguous_blocks += 1
100114 self ._null_block = block_pool .null_block
101115
102116 def find_longest_cache_hit (
@@ -109,6 +123,7 @@ def find_longest_cache_hit(
109123 computed_blocks = [self ._null_block ] * len (block_hashes )
110124 num_contiguous_blocks = 0
111125
126+ match_found = False
112127 # Search from right to left and early stop when a match is found.
113128 for i in range (len (block_hashes ) - 1 , - 1 , - 1 ):
114129 if cached_block := self .block_pool .get_cached_block (
@@ -121,12 +136,16 @@ def find_longest_cache_hit(
121136 # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
122137 # when sliding_window_contiguous_blocks=2.
123138 del computed_blocks [i + num_contiguous_blocks :]
124- return computed_blocks
139+ match_found = True
140+ break
125141 else :
126142 num_contiguous_blocks = 0
127- # The first `num_contiguous_blocks` is a cache hit even if
128- # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
129- del computed_blocks [num_contiguous_blocks :]
143+ if not match_found :
144+ # The first `num_contiguous_blocks` is a cache hit even if
145+ # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
146+ del computed_blocks [num_contiguous_blocks :]
147+ if self .use_eagle and len (computed_blocks ) > 0 :
148+ computed_blocks .pop ()
130149 return computed_blocks
131150
132151 def remove_skipped_blocks (self , blocks : list [KVCacheBlock ],
@@ -155,7 +174,7 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
155174
156175
157176def get_specialized_manager (kv_cache_spec : KVCacheSpec ,
158- block_pool : BlockPool ) -> SpecializedManager :
177+ ** kwargs ) -> SpecializedManager :
159178 manager_class = spec_manager_map [type (kv_cache_spec )]
160- manager = manager_class (kv_cache_spec , block_pool )
179+ manager = manager_class (kv_cache_spec , ** kwargs )
161180 return manager
0 commit comments