jit the entire prefill_insert function
sixiang-google committed Jan 15, 2025
1 parent fd21bd6 commit 05ddd3f
Showing 3 changed files with 175 additions and 164 deletions.
2 changes: 2 additions & 0 deletions MaxText/inference_mlperf/llama_offline_run.sh
@@ -109,6 +109,7 @@ fi
# LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
# makes subsequent runs faster
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2"
export JAX_TRACEBACK_FILTERING=off
export LIBTPU_INIT_ARGS

run_loadgen() {
@@ -170,6 +171,7 @@ run_loadgen_accuracy () {
fi

${CMD} python3 ${EVAL_SCRIPT} \
--checkpoint-path meta-llama/Llama-2-70b-chat-hf \
--tokenizer-path ${TOKENIZER_PATH} \
--mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \
--dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log
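The two exports added above enable JAX's persistent compilation cache (so later runs skip recompilation) and turn off traceback filtering for easier debugging. For reference, a minimal sketch of setting the same options from inside the Python process, assuming the standard jax.config knobs that mirror these environment variables:

import jax

# In-process equivalents of the environment variables exported above (sketch).
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache2")  # persist compiled executables across runs
jax.config.update("jax_traceback_filtering", "off")  # keep full frames in tracebacks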
203 changes: 104 additions & 99 deletions MaxText/inference_mlperf/offline_inference.py
@@ -40,6 +40,7 @@ class InputData:
tokens: jax.Array
true_length: int


class JetThread(threading.Thread):

def run(self):
@@ -50,6 +51,7 @@ def run(self):
traceback.print_exc()
os.kill(os.getpid(), signal.SIGKILL)
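# Any uncaught exception in a worker thread SIGKILLs the whole process, so a
# failing worker aborts the benchmark run instead of leaving it hung.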


class OfflineInference:

def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.Engine, enable_batch_prefill: bool):
@@ -75,6 +77,7 @@ def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.En
self._cached_generate = None
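# Bounded backlog: puts use block=True, so decode and prefill wait once
# detokenization falls 10 batches behind, applying gentle backpressure.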
self.detokenize_backlog = queue.Queue(10)
self.prefill_buckets = defaultdict(list)

def init_decode_state(self):
if self.decode_state is None:
self.decode_state = self.engine.init_decode_state()
@@ -96,79 +99,84 @@ def warmup(self, max_length, warmup_samples):
log.info(f"Compiling prefill: {length}")
input_data = jax.ShapeDtypeStruct((length,), jnp.dtype("int32"))
self._cached_pref[length] = (
jax.jit(self._prefill_insert, donate_argnums=(4,))
.lower(
self.params,
tokens=input_data,
slot=0,
true_length=length - 1,
decode_state=self.decode_state)
.compile()
jax.jit(self._prefill_insert, donate_argnums=(4,))
.lower(self.params, tokens=input_data, slot=0, true_length=length - 1, decode_state=self.decode_state)
.compile()
)
if length == 64 or length == 1024:
continue
log.info(f"Compiling batched prefill: {length}")
input_data_batch = jax.ShapeDtypeStruct((max_length,), jnp.dtype("int32"))
num_prompts = max_length // length
self._cached_pref_batch[length] = (
jax.jit(
self._prefill_insert_batch,
static_argnames=(
"num_prompts",
"padded_length",
),
donate_argnames=("decode_state",),
)
.lower(
self.params,
tokens=input_data_batch,
slots=jnp.arange(0, 8, dtype=int),
num_prompts=num_prompts,
decoder_positions=jnp.arange(0, max_length, dtype=int),
decoder_segment_ids=jnp.ones(max_length, dtype=int),
start_pos=jnp.arange(0, max_length, 128, dtype=int),
padded_length=length,
true_lengths=jnp.full(8, length, dtype=int),
decode_state=self.decode_state,
)
.compile()
)
# input_data_batch = jax.ShapeDtypeStruct((max_length,), jnp.dtype("int32"))
# example_seq_len=16
# num_prompts = max_length//length
# self._cached_pref_batch[length] = (
# jax.jit(self._prefill_insert_batch, donate_argnums=(4,))
# .lower(
# self.params,
# tokens=input_data_batch,
# slots=jnp.arange(0, example_seq_len),
# num_prompts = 16,
# decoder_positions = jnp.arange(0, max_length),
# decoder_segment_ids = jnp.ones(max_length),
# start_pos=jnp.arange(0, max_length, max_length//example_seq_len),
# padded_lengths=jnp.arange(0, max_length, max_length//example_seq_len),
# true_lengths=jnp.arange(0, max_length, max_length//example_seq_len),
# decode_state=self.decode_state)
# .compile()
# )
self.batch_inference(warmup_samples, desc="warmup")
self._cached_generate = (
jax.jit(self.engine.generate, donate_argnums=(1,))
.lower(self.params, self.decode_state)
.compile()
jax.jit(self.engine.generate, donate_argnums=(1,)).lower(self.params, self.decode_state).compile()
)

self.batch_inference(warmup_samples, desc="warmup")
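warmup() ahead-of-time compiles one prefill executable per bucket length, plus a batched-prefill executable for the lengths it packs, by chaining jax.jit(...).lower(...).compile() and donating the decode state, so the hot path can look up precompiled executables instead of tracing. Below is a toy sketch of that pattern with a stand-in function; the names and shapes are hypothetical and only the compilation flow mirrors the code above:

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the engine's prefill-and-insert step.
def toy_prefill(params, tokens, decode_state):
  decode_state = decode_state + jnp.sum(tokens) * params
  return tokens[0], decode_state


params = jnp.float32(2.0)
decode_state = jnp.zeros((4,), jnp.float32)

cached = {}
for length in (128, 256, 512):
  tokens_spec = jax.ShapeDtypeStruct((length,), jnp.dtype("int32"))
  # Donating decode_state lets XLA reuse its buffer for the updated state.
  cached[length] = (
      jax.jit(toy_prefill, donate_argnums=(2,))
      .lower(params, tokens_spec, decode_state)
      .compile()
  )

# Serving path: look up the executable for the padded length; no retracing.
first_token, decode_state = cached[128](params, jnp.ones((128,), jnp.int32), decode_state)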

def _prefill_insert(self, params, tokens, slot, true_length, decode_state):
"""return decodestate."""
padded_len = tokens.shape[0]
prefill_result, first_token = self.engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = self.engine.insert(prefill_result, decode_state, slot)
return first_token, decode_state
def _prefill_insert_batch(self, params, tokens, slots, num_prompts,
decoder_positions, decoder_segment_ids,
start_pos, padded_lengths, true_lengths,
decode_state):

def _prefill_insert_batch(
self,
params,
tokens,
slots,
num_prompts,
decoder_positions,
decoder_segment_ids,
start_pos,
padded_length,
true_lengths,
decode_state,
):
"""return decodestate."""
cache, prefill_results, first_tokens = self.engine.prefill_concat(
params = params,
padded_tokens = tokens,
decoder_positions = decoder_positions,
decoder_segment_ids = decoder_segment_ids,
start_pos = start_pos,
true_lengths = true_lengths,
num_prompts = num_prompts)
# decode_state = jax.lax.fori_loop(
# 0, num_prompts,
# lambda i, state: self.engine.insert(
# prefill_results[i],
# state,
# slot=slots[i],
# start_idx = start_pos[i],
# seq_len = padded_lengths[i]),
# decode_state
# )
for i in range(num_prompts):
decode_state = self.engine.insert_partial(
prefill_results[i],
params=params,
padded_tokens=tokens,
decoder_positions=decoder_positions,
decoder_segment_ids=decoder_segment_ids,
start_pos=start_pos,
true_lengths=true_lengths,
num_prompts=num_prompts,
)
decode_state = self.engine.insert_partial(
prefill_results,
decode_state,
cache,
slots[i],
start_idx = start_pos[i].item(),
seq_len = padded_lengths[i].item()
)
cache,
slots,
num_prompts=num_prompts,
start_indices=start_pos,
seq_len=padded_length,
)
return first_tokens, decode_state

def batch_inference_with_callback(
self,
data: List[InputData],
Expand All @@ -186,18 +194,14 @@ def prefill(prefill_bucket, prefill_len):
if self.dummy:
log.info("dummy prefill")
return 123
if not self.enable_batch_prefill or prefill_len * len(prefill_bucket) != 1024:
if not self.enable_batch_prefill or prefill_len in (64, 1024) or prefill_len * len(prefill_bucket) != 1024:
prefill_result = []
prefill_fn = self._prefill_insert
if (cached := self._cached_pref.get(prefill_len)) is not None:
prefill_fn = cached
for (slot, row) in prefill_bucket:
for slot, row in prefill_bucket:
first_token, self.decode_state = prefill_fn(
self.params,
tokens=row.tokens,
slot=slot,
true_length=row.true_length,
decode_state=self.decode_state
self.params, tokens=row.tokens, slot=slot, true_length=row.true_length, decode_state=self.decode_state
)
prefill_result.append((first_token, slot, row))
return prefill_result
@@ -212,45 +216,43 @@ def prefill(prefill_bucket, prefill_len):
for idx, (slot, row) in enumerate(prefill_bucket):
zero_to_n = np.arange(0, row.tokens.shape[0])
ones_to_keep = zero_to_n < row.true_length
one_d_output = (zero_to_n < row.true_length).astype(int) * (idx * 2 + 1) + \
(zero_to_n >= row.true_length).astype(int) * (idx + 1) * 2
one_d_output = (zero_to_n < row.true_length).astype(int) * (idx * 2 + 1) + (zero_to_n >= row.true_length).astype(
int
) * (idx + 1) * 2
sequence_indicators.append(one_d_output)
sequence_indicator = jnp.array(np.concatenate(sequence_indicators))
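# Example: with two prompts of padded length 4 and true length 3, prompt 0
# contributes [1, 1, 1, 2] and prompt 1 contributes [3, 3, 3, 4]; each prompt's
# real tokens get segment id 2*idx + 1 and its padding gets 2*idx + 2, so no
# two packed prompts (or their padding) share a segment id.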

tokens = jnp.concat([row.tokens for (slot, row) in prefill_bucket])

slots = [slot for (slot, row) in prefill_bucket]
padded_lengths = [row.tokens.shape[0] for (slot, row) in prefill_bucket]
true_lengths = [row.true_length for (slot, row) in prefill_bucket]
start_pos = np.cumsum([0]+[row.tokens.shape[0] for (slot, row) in prefill_bucket])[:-1]
start_pos = np.cumsum([0] + [row.tokens.shape[0] for (slot, row) in prefill_bucket])[:-1]
start_pos = start_pos.tolist()
#pad slots to keep static shape of jitted function input

# pad slots to keep static shape of jitted function input
def pad_num_prompts_len_array(array_to_pad, pad_len):
if len(array_to_pad) < pad_len:
array_to_pad.extend([0] * (pad_len - len(array_to_pad)))
return jnp.array(array_to_pad)

slots = pad_num_prompts_len_array(slots, 8)
padded_lengths = pad_num_prompts_len_array(padded_lengths, 8)
true_lengths = pad_num_prompts_len_array(true_lengths, 8)
start_pos = pad_num_prompts_len_array(start_pos, 8)
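# All four arrays above are padded to 8 entries because the batched prefill
# executable was lowered in warmup() against fixed 8-element inputs; keeping
# the input shapes static avoids retracing.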

first_tokens, self.decode_state = prefill_fn(
self.params,
tokens=tokens,
slots=slots,
num_prompts = len(prefill_bucket),
decoder_positions = positions,
decoder_segment_ids = sequence_indicator,
start_pos = start_pos,
padded_lengths = padded_lengths,
true_lengths = true_lengths,
decode_state=self.decode_state,
self.params,
tokens=tokens,
slots=slots,
decoder_positions=positions,
decoder_segment_ids=sequence_indicator,
start_pos=start_pos,
true_lengths=true_lengths,
decode_state=self.decode_state,
)
prefill_result = [(first_tokens[idx], slot, row) for (idx, (slot, row)) in enumerate(prefill_bucket)]

return prefill_result



empty_slots = list(range(self.batch_size))
slot_to_id = {}
num_prefills = {}
@@ -283,7 +285,7 @@ def decode():
for i in range(5):
# result_tokens.copy_to_host_async()
result_tokens = result_tokens_l[i].convert_to_numpy()
self.detokenize_backlog.put((result_tokens, False, 0, 0), block = True)
self.detokenize_backlog.put((result_tokens, False, 0, 0), block=True)
# log.info(f"Decode put result {i} to queue")

def detokenize():
@@ -319,13 +321,16 @@ def detokenize():
empty_slots.append(slot)
if newly_empty and self.detokenize_backlog.qsize() == 0 and len(slot_to_id.items()) == 0:
break

detokenize_thread = JetThread(
target=functools.partial(detokenize,),
target=functools.partial(
detokenize,
),
name="detokenize",
)
self.live = True
detokenize_thread.start()
num_prefill = 0
total_num_prefills = 0
for row in data:
while not empty_slots:
# If slots are all full, decode until there are free slots
@@ -339,29 +344,29 @@ def detokenize():
log.info(
f"prefill-{desc}-{num_prefills} num_prefills {sum(num_prefills.values())} num_tokens {num_tokens} true_length {row.true_length} num_empty_slots {len(empty_slots)} num_decodes {num_decodes}"
)
num_prefill += 1
log.info(f"Total num prefill: {num_prefill}")
total_num_prefills += 1
log.info(f"Total num prefill: {total_num_prefills}")
slot = empty_slots.pop()
#directly prefill prompts with 64 or less tokens
# directly prefill prompts with 64 or less tokens, and with 1024 tokens
if num_tokens in (64, 1024) or not self.enable_batch_prefill:
first_token, slot, row = prefill([(slot, row)], num_tokens)[0]
self.detokenize_backlog.put((first_token, True, row.id, slot), block = True)
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
continue
self.prefill_buckets[num_tokens].append((slot, row))
prefill_buckets_len = {k: len(self.prefill_buckets[k]) for k in self.prefill_buckets}
log.info(f"prefill buckets {prefill_buckets_len}")
if len(self.prefill_buckets[num_tokens]) * num_tokens == 1024:
prefill_results = prefill(self.prefill_buckets[num_tokens], num_tokens)
for (first_token, slot, row) in prefill_results:
for first_token, slot, row in prefill_results:
log.info(f"Put row of len {row.tokens.shape[0]} true length {row.true_length} slot {slot} to detokenize backlog")
self.detokenize_backlog.put((first_token, True, row.id, slot), block = True)
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
self.prefill_buckets[num_tokens] = []

# For leftover requests in buckets at the end of computation, do prefill individually.
for num_tokens in self.prefill_buckets.keys():
prefill_results = prefill(self.prefill_buckets[num_tokens], num_tokens)
for (first_token, slot, row) in prefill_results:
for first_token, slot, row in prefill_results:
log.info(f"Put row of len {row.tokens.shape[0]} true length {row.true_length} slot {slot} to detokenize backlog")
self.detokenize_backlog.put((first_token, True, row.id, slot), block = True)
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
self.prefill_buckets = defaultdict(list)
while slot_to_id:
log.info(f"decode-{desc}-{num_decodes} num_filled_slots {len(slot_to_id)}")
