1010 stop the prefill instance when the decode instance is slow.
1111"""
1212import threading
13- import time
1413from collections import deque
1514from typing import Deque , List , Optional , Union
1615
@@ -29,21 +28,21 @@ class SimpleBuffer(KVLookupBufferBase):
2928 def __init__ (self , signal_pipe : KVPipeBase , data_pipe : KVPipeBase ,
3029 buffer_size_thresh : float ):
3130 """
32- signal_pipe: on CPU
33-
34- NOTE: on-device recv will block all threads in the process, making the
35- KV cache producer unable to listen to new request while transmitting
36- KV cache. Luckily CPU recv only blocks the current thread so we use
31+ signal_pipe: on CPU
32+
33+ NOTE: on-device recv will block all threads in the process, making the
34+ KV cache producer unable to listen to new request while transmitting
35+ KV cache. Luckily CPU recv only blocks the current thread so we use
3736 CPU recv to listen to new request.
38-
37+
3938 data_pipe: on device (e.g. GPU)
4039 """
4140
4241 self .buffer : Deque [List [torch .Tensor ]] = deque ()
4342
4443 self .buffer_size = 0
4544 self .buffer_size_threshold = buffer_size_thresh
46- self .buffer_lock = threading .Lock ()
45+ self .buffer_cv = threading .Condition ()
4746 self .signal_pipe = signal_pipe
4847 self .data_pipe = data_pipe
4948 self .request_handling_thread : Optional [threading .Thread ] = None
@@ -116,11 +115,19 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
116115 hidden = hidden .clone ()
117116
118117 buffer_item = [input_tokens , roi , key , value , hidden ]
118+ data_size = sum ([self ._get_element_size (data ) for data in buffer_item ])
119+
120+ with self .buffer_cv :
121+ if self .buffer_size + data_size > self .buffer_size_threshold :
122+ # log outside the while loop to avoid this message being logged
123+ # repeatedly.
124+ logger .debug ("KV transfer buffer is full. Handling..." )
125+ while self .buffer_size + data_size > self .buffer_size_threshold :
126+ self .buffer_cv .wait ()
119127
120- with self .buffer_lock :
121- for data in buffer_item :
122- self .buffer_size += self ._get_element_size (data )
128+ self .buffer_size += data_size
123129 self .buffer .append (buffer_item )
130+ self .buffer_cv .notify ()
124131
125132 def _is_end_signal (self , signal ):
126133 return signal is None
@@ -143,35 +150,31 @@ def drop_select_handler(self):
143150 roi = (roi > 0.5 )
144151 tokens_roi_recver = [input_tokens , roi ]
145152
146- matched_length = 0
147-
148- # perform input tokens and roi matching
149- # FIXME: this matching is O(n), ideally it should be O(1)
150- # but this buffer size won't (and shouldn't) be too large so
151- # the fix is not urgent.
152- with self .buffer_lock :
153-
153+ def is_buffer_available (
154+ tokens_roi_recver : List [torch .Tensor ], ) -> bool :
155+ # perform input tokens and roi matching
156+ # FIXME: this matching is O(n), ideally it should be O(1)
157+ # but this buffer size won't (and shouldn't) be too large so
158+ # the fix is not urgent.
154159 for _ in range (len (self .buffer )):
155-
156- temp_length = self ._matches (self .buffer [0 ],
157- tokens_roi_recver )
158- if temp_length > 0 :
159- matched_length = temp_length
160- break
160+ if self ._matches (self .buffer [0 ],
161+ tokens_roi_recver ) > 0 :
162+ return True
161163 # rotate the element we just accessed to the end
162164 self .buffer .rotate (- 1 )
163-
164- if matched_length > 0 :
165- # need to clone the tensor
166- # in case the tensor is freed before sending finishes
167- matched_item = self .buffer .popleft ()
168- for tensor in matched_item :
169- self ._send_tensor_and_dec_size (tensor )
170-
171- else :
172- # no match, just send None
173- for _ in range (5 ):
174- self .data_pipe .send_tensor (None )
165+ return False
166+
167+ with self .buffer_cv :
168+ while not is_buffer_available (tokens_roi_recver ):
169+ logger .debug (
170+ "KV transfer buffer is not available. Waiting..." )
171+ self .buffer_cv .wait ()
172+ # need to clone the tensor
173+ # in case the tensor is freed before sending finishes
174+ matched_item = self .buffer .popleft ()
175+ for tensor in matched_item :
176+ self ._send_tensor_and_dec_size (tensor )
177+ self .buffer_cv .notify ()
175178
176179 except RuntimeError as e :
177180 if 'Connection closed by peer' not in str (e ):
@@ -208,20 +211,10 @@ def drop_select(
208211
209212 return [input_tokens , roi , key , value , hidden ]
210213
211- def full_handler (self ):
212- time .sleep (0.001 )
213-
214214 def insert (self , input_tokens : torch .Tensor , roi : torch .Tensor ,
215215 key : torch .Tensor , value : torch .Tensor ,
216216 hidden : torch .Tensor ) -> None :
217217
218- if self .buffer_size > self .buffer_size_threshold :
219- # log outside the while loop to avoid this message being logged
220- # repeatedly.
221- logger .debug ("KV transfer buffer is full. Handling..." )
222- while self .buffer_size > self .buffer_size_threshold :
223- self .full_handler ()
224-
225218 self ._add_to_buffer (input_tokens , roi , key , value , hidden )
226219
227220 # when calling the insert, the current process is a sender
0 commit comments