Skip to content

Commit

Permalink
Prerequisite work for supporting disaggregation: (#68)
Browse files Browse the repository at this point in the history
1. Add transfer thread to transfer KV Cache.
2. For interleaved mode, prioritize prefill and improve the HBM
   utilization.

Co-authored-by: Zhihao Shan <zhihaoshan@google.com>
  • Loading branch information
zhihaoshan-google and Zhihao Shan authored May 1, 2024
1 parent 2db6c14 commit a3546e8
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 35 deletions.
133 changes: 101 additions & 32 deletions jetstream/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,17 +188,23 @@ class Driver:
# Stage 1
_prefill_backlog: queue.Queue[ActiveRequest | None]
# Stage 2
_transfer_backlogs: list[queue.Queue[ActiveRequest]] = []
# Stage 3
# We keep this as a dict to avoid a possibly expensive object comparison
# when logging the index of the generate engine we send a prefill result
# to, it allows us to natively have the index from the min operation, rather
# than have to call .index()
_generate_backlogs: dict[int, queue.Queue[ActiveRequest | None]] = {}
# Stage 3
_generate_backlogs: dict[int, queue.Queue[ActiveRequest]] = {}
# Stage 4
# This can be a list because we can pass it as an arg to generate and
# detokenize threads. It is a list of tokens to be detokenized.
_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
_generate_slots: list[queue.Queue[int]] = []
_active_requests: list[queue.Queue[tuple[int, ActiveRequest | None]]] = []
_active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []

# For interleaved_mode, only generate if all slots are full
# or corresponding prefill queue is empty.
_interleaved_mode: bool = False

# todo: remove jax_padding after all then engine migrate to np padding
_jax_padding = True
Expand All @@ -209,6 +215,7 @@ def __init__(
generate_engines: Optional[list[engine_api.Engine]] = None,
prefill_params: Optional[list[Any]] = None,
generate_params: Optional[list[Any]] = None,
interleaved_mode: bool = False,
jax_padding: bool = True,
):
if prefill_engines is None:
Expand All @@ -229,22 +236,39 @@ def __init__(
self._generate_engines = generate_engines
self._prefill_params = prefill_params
self._generate_params = generate_params
self._interleaved_mode = interleaved_mode

# Stages 1-4 represent the life cycle of a request.
# Stage 1
# At first, a request is placed here in order to get prefilled.
self._prefill_backlog = queue.Queue()
# _ready_to_prefill event will block the prefill thread until there is
# available decode slot to insert the prefill result.
self._ready_to_prefill = threading.Event()
# Stage 2
# After prefilling, it is placed here in order to get transferred to
# one of the generate backlogs.
# Interleaved Mode: Max size is 1 to increase the HBM utilization
# during generate.
# Disaggregated Mode: Max size is 4 to allow for 2 prefills to be enqueued
# while 1 transfer is enqueued while 1 is being transferred.
# TODO: Make queue size configurable.
self._transfer_backlogs = [
queue.Queue(1 if self._interleaved_mode else 4)
for i in range(len(self._prefill_engines))
]
# Stage 3
# Each generate engine accesses its own generate backlog.
# Interleaved Mode: Max size is 1 to increase the HBM utilization
# during generate.
# Disaggregated Mode: Set as 1/3 the number of concurrent decodes.
# TODO: Calculate the backlog to saturate the generate engine while
# minimizing the memory usage for disaggregated mode.
# TODO: Make queue size configurable.
self._generate_backlogs = {
# Don't receive more than 1/3 the number of concurrent decodes to avoid
# OOM for single host.
idx: queue.Queue(engine.max_concurrent_decodes // 3)
idx: queue.Queue(
1 if self._interleaved_mode else engine.max_concurrent_decodes // 3
)
for idx, engine in enumerate(self._generate_engines)
}
# Stage 3
# Stage 4
# After generation, ActiveRequests are placed on the detokenization backlog
# for tokens to be sent into each ActiveRequest's return channel.
# We have one of these per generate engine to simplify the logic keeping
Expand Down Expand Up @@ -293,6 +317,18 @@ def __init__(
JetThread(
target=functools.partial(self._prefill_thread, idx),
name=f"prefill-{idx}",
daemon=True,
)
for idx in range(len(self._prefill_engines))
]
self._transfer_threads = [
JetThread(
target=functools.partial(
self._transfer_thread,
idx,
),
name=f"transfer-{idx}",
daemon=True,
)
for idx in range(len(self._prefill_engines))
]
Expand All @@ -303,6 +339,7 @@ def __init__(
idx,
),
name=f"generate-{idx}",
daemon=True,
)
for idx in range(len(self._generate_engines))
]
Expand All @@ -319,6 +356,7 @@ def __init__(
self._all_threads = list(
itertools.chain(
self._prefill_threads,
self._transfer_threads,
self._generate_threads,
self.detokenize_threads,
)
Expand All @@ -336,6 +374,7 @@ def stop(self):
all_backlogs = list(
itertools.chain(
[self._prefill_backlog],
self._transfer_backlogs,
self._generate_backlogs.values(),
self._detokenize_backlogs,
)
Expand Down Expand Up @@ -400,24 +439,11 @@ def _prefill_thread(self, idx: int):
logging.info("---------Prefill params %d loaded.---------", idx)

while self.live:
# The prefill thread can wait until there is available decode slot to
# insert.
if self._generate_slots[idx].qsize() == 0:
logging.info(
"Prefill waits for available slot; prefill queue size %d",
self._prefill_backlog.qsize(),
)
self._ready_to_prefill.wait()
logging.info(
"Prefill continues; prefill queue size %d",
self._prefill_backlog.qsize(),
)
my_transfer_backlog = self._transfer_backlogs[idx]
# The prefill thread can just sleep until it has work to do.
request = self._prefill_backlog.get(block=True)
if request is None:
break
# TODO: Implement hot/cold cache for history.
history = self._load_cache_history(request.history_path) # pylint: disable = assignment-from-none
# Tokenize, and introduce a leading dimension
is_bos = not bool(request.history_path)
logging.info(
Expand All @@ -434,21 +460,60 @@ def _prefill_thread(self, idx: int):
max_prefill_length=prefill_engine.max_prefill_length,
jax_padding=self._jax_padding,
)
# Compute new kv cache for the prefill_text, conditional on
# history.
# Compute new kv cache for the prefill_text.
prefill_result = prefill_engine.prefill(
params=prefill_params,
existing_prefix=history,
padded_tokens=padded_tokens,
true_length=true_length,
)
request.prefill_result = prefill_result
# Once prefill is complete, place it on the generation queue and block if
# full.
self._generate_backlogs[idx].put(request, block=True)
my_transfer_backlog.put(request, block=True)
logging.info(
"Placed request on transfer queue %d, %d queued requests.",
idx,
my_transfer_backlog.qsize(),
)
del prefill_result
del request

def _transfer_thread(self, idx: int):
"""Transfers the kv cache on an active request to the least full
generate backlog."""
transfer_backlog = self._transfer_backlogs[idx]

while self.live:
# The transfer thread can just sleep until it has work to do.
new_request = transfer_backlog.get(block=True)
target_idx = min(
self._generate_backlogs.items(), key=lambda q: q[1].qsize()
)[0]
# Only transfer the KVCache for the disaggregated serving.
# TODO: Remove the conditional after fixing the compatibility.
if not self._interleaved_mode:
logging.info(
"Transferring prefill from prefill engine %d "
"to generate engine %d.",
idx,
target_idx,
)
# Transfer the info to the relevant generate slice.
new_request.prefill_result = jax.device_put(
new_request.prefill_result,
self._generate_engines[
target_idx
].get_prefix_destination_sharding(),
)
# Block here so we don't block on the generate thread that steps.
jax.block_until_ready(new_request.prefill_result)
# Place the request on the correct generate backlog and block if full.
self._generate_backlogs[target_idx].put(new_request, block=True)
logging.info(
"Placed request on the generate queue, generate_backlogs=%d",
self._generate_backlogs[idx].qsize(),
"Successfully transferred prefill "
"from prefill engine %d to generate engine %d.",
idx,
target_idx,
)

def _generate_thread(self, idx: int):
Expand All @@ -463,6 +528,7 @@ def _generate_thread(self, idx: int):
generate_timestep = 0
# State to store things like running kv cache in.
decode_state = generate_engine.init_decode_state()

generate_params = self._generate_params[idx]
logging.info("---------Generate params %d loaded.---------", idx)
time_of_last_generate = time.time()
Expand All @@ -480,7 +546,6 @@ def _generate_thread(self, idx: int):

max_concurrent_decodes = generate_engine.max_concurrent_decodes

# TODO: Move insert to prefill thread.
# Check if there are any free my_slots. We don't want to block here since
# we can still generate if we can't insert. We do this in a while loop to
# insert as many sequences as possible.
Expand All @@ -499,6 +564,11 @@ def _generate_thread(self, idx: int):
# the case when the prefill backlog is cancelled and we end up with no
# more useful prefill work to do.
block = my_slots_size == max_concurrent_decodes
if self._interleaved_mode:
# For interleaved mode, we also blocks when prefill backlog
# is not empty or there are transfer work to do.
block |= not self._prefill_backlog.empty()
block |= not self._transfer_backlogs[idx].empty()
try:
new_request = my_generate_backlog.get(block=block, timeout=1.0)
# Got free slot and new request, use them.
Expand Down Expand Up @@ -598,7 +668,6 @@ def _detokenize_thread(self, idx: int):
# Place the slot back on the free queue.
my_live_requests[slot] = None
my_slots.put(slot, block=False) # This should always have space.
self._ready_to_prefill.set()
logging.info(
"Detokenizing generate step %d took %.2fms",
generate_timestep_added,
Expand Down
10 changes: 7 additions & 3 deletions jetstream/core/server_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,15 @@ def run(
generate_params = [ge.load_params() for ge in engines.generate_engines]
shared_params = [ie.load_params() for ie in engines.interleaved_engines]
logging.info("Loaded all weights.")
interleaved_mode = (
len(config.prefill_slices) + len(config.generate_slices) == 0
)
driver = orchestrator.Driver(
prefill_engines=engines.prefill_engines + engines.interleaved_engines,
generate_engines=engines.generate_engines + engines.interleaved_engines,
prefill_params=prefill_params + shared_params,
generate_params=generate_params + shared_params,
interleaved_mode=interleaved_mode,
jax_padding=jax_padding,
)
# We default threads to the total number of concurrent allowed decodes,
Expand All @@ -130,8 +134,8 @@ def run(


def get_devices() -> Any:
"""Gets devices locally."""
# Run interleaved engine on local device.
"""Gets devices."""
# TODO: Add more logs for the devices.
devices = jax.devices()
logging.info("Using local devices for interleaved serving: %d", len(devices))
logging.info("Using devices: %d", len(devices))
return devices

0 comments on commit a3546e8

Please sign in to comment.