diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py
index 9d290af6..8b5268e5 100644
--- a/jetstream/core/orchestrator.py
+++ b/jetstream/core/orchestrator.py
@@ -188,17 +188,23 @@ class Driver:
   # Stage 1
   _prefill_backlog: queue.Queue[ActiveRequest | None]
   # Stage 2
+  _transfer_backlogs: list[queue.Queue[ActiveRequest]] = []
+  # Stage 3
   # We keep this as a dict to avoid a possibly expensive object comparison
   # when logging the index of the generate engine we send a prefill result
   # to, it allows us to natively have the index from the min operation, rather
   # than have to call .index()
-  _generate_backlogs: dict[int, queue.Queue[ActiveRequest | None]] = {}
-  # Stage 3
+  _generate_backlogs: dict[int, queue.Queue[ActiveRequest]] = {}
+  # Stage 4
   # This can be a list because we can pass it as an arg to generate and
   # detokenize threads. It is a list of tokens to be detokenized.
   _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
   _generate_slots: list[queue.Queue[int]] = []
-  _active_requests: list[queue.Queue[tuple[int, ActiveRequest | None]]] = []
+  _active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []
+
+  # For interleaved_mode, only generate if all slots are full
+  # or the corresponding prefill queue is empty.
+  _interleaved_mode: bool = False

   # todo: remove jax_padding after all then engine migrate to np padding
   _jax_padding = True

@@ -209,6 +215,7 @@ def __init__(
       self,
       prefill_engines: Optional[list[engine_api.Engine]] = None,
       generate_engines: Optional[list[engine_api.Engine]] = None,
       prefill_params: Optional[list[Any]] = None,
       generate_params: Optional[list[Any]] = None,
+      interleaved_mode: bool = False,
       jax_padding: bool = True,
   ):
     if prefill_engines is None:
@@ -229,22 +236,39 @@
     self._generate_engines = generate_engines
     self._prefill_params = prefill_params
     self._generate_params = generate_params
+    self._interleaved_mode = interleaved_mode
+    # Stages 1-4 represent the life cycle of a request.
     # Stage 1
     # At first, a request is placed here in order to get prefilled.
     self._prefill_backlog = queue.Queue()
-    # _ready_to_prefill event will block the prefill thread until there is
-    # available decode slot to insert the prefill result.
-    self._ready_to_prefill = threading.Event()
     # Stage 2
+    # After prefilling, it is placed here in order to get transferred to
+    # one of the generate backlogs.
+    # Interleaved Mode: Max size is 1 to increase the HBM utilization
+    # during generate.
+    # Disaggregated Mode: Max size is 4 so that up to 2 prefills can be
+    # enqueued while 1 transfer is queued and 1 is in flight.
+    # TODO: Make queue size configurable.
+    self._transfer_backlogs = [
+        queue.Queue(1 if self._interleaved_mode else 4)
+        for i in range(len(self._prefill_engines))
+    ]
+    # Stage 3
     # Each generate engine accesses its own generate backlog.
+    # Interleaved Mode: Max size is 1 to increase the HBM utilization
+    # during generate.
+    # Disaggregated Mode: Set to 1/3 the number of concurrent decodes.
+    # TODO: Calculate the backlog size that saturates the generate engine
+    # while minimizing memory usage in disaggregated mode.
+    # TODO: Make queue size configurable.
     self._generate_backlogs = {
-        # Don't receive more than 1/3 the number of concurrent decodes to avoid
-        # OOM for single host.
-        idx: queue.Queue(engine.max_concurrent_decodes // 3)
+        idx: queue.Queue(
+            1 if self._interleaved_mode else engine.max_concurrent_decodes // 3
+        )
         for idx, engine in enumerate(self._generate_engines)
     }
-    # Stage 3
+    # Stage 4
     # After generation, ActiveRequests are placed on the detokenization backlog
     # for tokens to be sent into each ActiveRequest's return channel.
     # We have one of these per generate engine to simplify the logic keeping
@@ -293,6 +317,18 @@
         JetThread(
             target=functools.partial(self._prefill_thread, idx),
             name=f"prefill-{idx}",
+            daemon=True,
        )
        for idx in range(len(self._prefill_engines))
    ]
+    self._transfer_threads = [
+        JetThread(
+            target=functools.partial(
+                self._transfer_thread,
+                idx,
+            ),
+            name=f"transfer-{idx}",
+            daemon=True,
+        )
+        for idx in range(len(self._prefill_engines))
+    ]
@@ -303,6 +339,7 @@
                 idx,
             ),
             name=f"generate-{idx}",
+            daemon=True,
        )
        for idx in range(len(self._generate_engines))
    ]
@@ -319,6 +356,7 @@
    self._all_threads = list(
        itertools.chain(
            self._prefill_threads,
+            self._transfer_threads,
            self._generate_threads,
            self.detokenize_threads,
        )
@@ -336,6 +374,7 @@ def stop(self):
    all_backlogs = list(
        itertools.chain(
            [self._prefill_backlog],
+            self._transfer_backlogs,
            self._generate_backlogs.values(),
            self._detokenize_backlogs,
        )
@@ -400,24 +439,11 @@ def _prefill_thread(self, idx: int):
    logging.info("---------Prefill params %d loaded.---------", idx)

    while self.live:
-      # The prefill thread can wait until there is available decode slot to
-      # insert.
-      if self._generate_slots[idx].qsize() == 0:
-        logging.info(
-            "Prefill waits for available slot; prefill queue size %d",
-            self._prefill_backlog.qsize(),
-        )
-        self._ready_to_prefill.wait()
-        logging.info(
-            "Prefill continues; prefill queue size %d",
-            self._prefill_backlog.qsize(),
-        )
+      my_transfer_backlog = self._transfer_backlogs[idx]
      # The prefill thread can just sleep until it has work to do.
      request = self._prefill_backlog.get(block=True)
      if request is None:
        break
-      # TODO: Implement hot/cold cache for history.
-      history = self._load_cache_history(request.history_path)  # pylint: disable = assignment-from-none
      # Tokenize, and introduce a leading dimension
      is_bos = not bool(request.history_path)
      logging.info(
@@ -434,21 +460,60 @@
          max_prefill_length=prefill_engine.max_prefill_length,
          jax_padding=self._jax_padding,
      )
-      # Compute new kv cache for the prefill_text, conditional on
-      # history.
+      # Compute new kv cache for the prefill_text.
      prefill_result = prefill_engine.prefill(
          params=prefill_params,
-          existing_prefix=history,
          padded_tokens=padded_tokens,
          true_length=true_length,
      )
      request.prefill_result = prefill_result
      # Once prefill is complete, place it on the generation queue and block if
      # full.
-      self._generate_backlogs[idx].put(request, block=True)
+      my_transfer_backlog.put(request, block=True)
+      logging.info(
+          "Placed request on transfer queue %d, %d queued requests.",
+          idx,
+          my_transfer_backlog.qsize(),
+      )
+      del prefill_result
+      del request
+
+  def _transfer_thread(self, idx: int):
+    """Transfers the kv cache on an active request to the least full
+    generate backlog."""
+    transfer_backlog = self._transfer_backlogs[idx]
+
+    while self.live:
+      # The transfer thread can just sleep until it has work to do.
+      new_request = transfer_backlog.get(block=True)
+      target_idx = min(
+          self._generate_backlogs.items(), key=lambda q: q[1].qsize()
+      )[0]
+      # Only transfer the KV cache in disaggregated serving.
+      # TODO: Remove the conditional after fixing the compatibility.
+      if not self._interleaved_mode:
+        logging.info(
+            "Transferring prefill from prefill engine %d "
+            "to generate engine %d.",
+            idx,
+            target_idx,
+        )
+        # Transfer the info to the relevant generate slice.
+        new_request.prefill_result = jax.device_put(
+            new_request.prefill_result,
+            self._generate_engines[
+                target_idx
+            ].get_prefix_destination_sharding(),
+        )
+        # Block here so we don't block the generate thread while it steps.
+        jax.block_until_ready(new_request.prefill_result)
+      # Place the request on the correct generate backlog and block if full.
+      self._generate_backlogs[target_idx].put(new_request, block=True)
      logging.info(
-          "Placed request on the generate queue, generate_backlogs=%d",
-          self._generate_backlogs[idx].qsize(),
+          "Successfully transferred prefill "
+          "from prefill engine %d to generate engine %d.",
+          idx,
+          target_idx,
      )
@@ -463,6 +528,7 @@ def _generate_thread(self, idx: int):
    generate_timestep = 0
    # State to store things like running kv cache in.
    decode_state = generate_engine.init_decode_state()
+    generate_params = self._generate_params[idx]
    logging.info("---------Generate params %d loaded.---------", idx)
    time_of_last_generate = time.time()
@@ -480,7 +546,6 @@

      max_concurrent_decodes = generate_engine.max_concurrent_decodes

-      # TODO: Move insert to prefill thread.
      # Check if there are any free my_slots. We don't want to block here since
      # we can still generate if we can't insert. We do this in a while loop to
      # insert as many sequences as possible.
@@ -499,6 +564,11 @@
      # the case when the prefill backlog is cancelled and we end up with no
      # more useful prefill work to do.
      block = my_slots_size == max_concurrent_decodes
+      if self._interleaved_mode:
+        # For interleaved mode, we also block when the prefill backlog
+        # is not empty or there is transfer work to do.
+        block |= not self._prefill_backlog.empty()
+        block |= not self._transfer_backlogs[idx].empty()
      try:
        new_request = my_generate_backlog.get(block=block, timeout=1.0)
        # Got free slot and new request, use them.
@@ -598,7 +668,6 @@ def _detokenize_thread(self, idx: int):
          # Place the slot back on the free queue.
          my_live_requests[slot] = None
          my_slots.put(slot, block=False)  # This should always have space.
-          self._ready_to_prefill.set()
          logging.info(
              "Detokenizing generate step %d took %.2fms",
              generate_timestep_added,
diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py
index 8911b9f6..d66af518 100644
--- a/jetstream/core/server_lib.py
+++ b/jetstream/core/server_lib.py
@@ -112,11 +112,15 @@ def run(
  generate_params = [ge.load_params() for ge in engines.generate_engines]
  shared_params = [ie.load_params() for ie in engines.interleaved_engines]
  logging.info("Loaded all weights.")
+  interleaved_mode = (
+      len(config.prefill_slices) + len(config.generate_slices) == 0
+  )
  driver = orchestrator.Driver(
      prefill_engines=engines.prefill_engines + engines.interleaved_engines,
      generate_engines=engines.generate_engines + engines.interleaved_engines,
      prefill_params=prefill_params + shared_params,
      generate_params=generate_params + shared_params,
+      interleaved_mode=interleaved_mode,
      jax_padding=jax_padding,
  )
  # We default threads to the total number of concurrent allowed decodes,
@@ -130,8 +134,8 @@


 def get_devices() -> Any:
-  """Gets devices locally."""
-  # Run interleaved engine on local device.
+ """Gets devices.""" + # TODO: Add more logs for the devices. devices = jax.devices() - logging.info("Using local devices for interleaved serving: %d", len(devices)) + logging.info("Using devices: %d", len(devices)) return devices