Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manual model warmup to resolve AOT model warmup performance degradation #126

Merged
7 commits merged on Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions jetstream/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,11 +502,6 @@ def _prefill_thread(self, idx: int):
padded_tokens, true_length = self._process_prefill_content(
request, tokenizer, is_bos, prefill_engine.max_prefill_length
)
if isinstance(prefill_engine, engine_api.JetStreamEngine):
request.padded_token_length = token_utils.take_nearest_length(
vivianrwu marked this conversation as resolved.
Show resolved Hide resolved
prefill_engine.prefill_buckets, true_length
)
prefill_engine.set_padded_token_length(request.padded_token_length)

# Compute new kv cache for the prefill_content.
prefill_result, first_token = prefill_engine.prefill(
Expand Down Expand Up @@ -678,11 +673,6 @@ def _generate_thread(self, idx: int):
generate_timestep,
)

if isinstance(generate_engine, engine_api.JetStreamEngine):
vivianrwu marked this conversation as resolved.
Show resolved Hide resolved
generate_engine.set_padded_token_length(
new_request.padded_token_length
)

decode_state = generate_engine.insert(
new_request.prefill_result, decode_state, slot=slot
)
Expand Down
87 changes: 23 additions & 64 deletions jetstream/engine/aot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@

"""AOT compilation utils."""

import jax
import jax.numpy as jnp
import concurrent.futures
from typing import Any, Optional, cast
from typing import Any, Optional
import logging
from jetstream.engine import engine_api, token_utils

Expand All @@ -44,34 +43,30 @@ def layout_params_and_compile_executables(
any_prefill_engine = None
any_prefill_params = None

prefill_executables = []
inserts_generate_executables = []
prefills_compiled = []
inserts_generate_compiled = []

for i, pe in enumerate(prefill_engines):
any_prefill_engine = pe
any_prefill_params = prefill_params[i]
prefill_executable = initialize_prefill_jit_cache(
prefill_compiled = initialize_prefill_jit_cache(
prefill_engine=pe,
prefill_params=prefill_params[i],
prefill_idx=i,
)
prefill_executables.append(prefill_executable)
prefills_compiled.append(prefill_compiled)

for i, ge in enumerate(generate_engines):
insert_executable, generate_executable = (
initialize_insert_generate_jit_cache(
prefill_engine=any_prefill_engine,
generate_engine=ge,
prefill_params=any_prefill_params,
generate_params=generate_params[i],
generate_idx=i,
)
)
inserts_generate_executables.append(
[insert_executable, generate_executable]
insert_generate_compiled = initialize_insert_generate_jit_cache(
prefill_engine=any_prefill_engine,
generate_engine=ge,
prefill_params=any_prefill_params,
generate_params=generate_params[i],
generate_idx=i,
)
inserts_generate_compiled.append([insert_generate_compiled])

if prefill_executables and inserts_generate_executables:
if prefills_compiled and inserts_generate_compiled:
return True
return False

Expand Down Expand Up @@ -104,47 +99,32 @@ def initialize_prefill_jit_cache(
def compile_prefill(length):
padded_tokens, true_length = jnp.ones((length), dtype="int32"), length

lowered = jax.jit(
prefill_engine._downstream_engine.prefill, # pylint: disable=protected-access
out_shardings=prefill_engine.get_prefix_destination_sharding(),
).lower(
_, _ = prefill_engine._downstream_engine.prefill( # pylint: disable=protected-access
params=prefill_params,
padded_tokens=padded_tokens,
true_length=true_length,
)
logging.info(
"---------Prefill engine %d lowered for prefill length %d.---------",
prefill_idx,
length,
)
compiled = lowered.compile()

logging.info(
"---------Prefill engine %d compiled for prefill length %d.---------",
prefill_idx,
length,
)
return compiled

logging.info("---------Prefill compilation %d begun.---------", prefill_idx)

with concurrent.futures.ThreadPoolExecutor(
max_workers=len(prefill_buckets)
) as executor:
prefill_executable = list(executor.map(compile_prefill, prefill_buckets))

prefill_executable = {
k: cast(jax.stages.Compiled, e)
for k, e in zip(prefill_buckets, prefill_executable)
}
_ = executor.map(compile_prefill, prefill_buckets)

prefill_engine.prefill_executable = prefill_executable
prefill_engine.warm = True

logging.info(
"---------Prefill compilation %d complete.---------", prefill_idx
)

return prefill_executable
return prefill_engine.warm


def initialize_insert_generate_jit_cache(
Expand Down Expand Up @@ -184,39 +164,25 @@ def compile_insert(length):
true_length=true_length,
)

lowered = jax.jit(generate_engine._downstream_engine.insert).lower( # pylint: disable=protected-access
prefix=prefill, decode_state=decode_state, slot=1
)
logging.info(
"---------Generate engine %d lowered for insert length %d.---------",
generate_idx,
length,
)
compiled = lowered.compile()
generate_engine.insert(prefix=prefill, decode_state=decode_state, slot=0)

logging.info(
"---------Generate engine %d compiled for insert length %d.---------",
generate_idx,
length,
)
return compiled

def compile_generate():

logging.info(
"---------Generate compilation %d begun.---------", generate_idx
)

lowered = jax.jit(generate_engine._downstream_engine.generate).lower( # pylint: disable=protected-access
generate_engine._downstream_engine.generate( # pylint: disable=protected-access
params=generate_params,
decode_state=decode_state,
)
logging.info(
"---------Generate engine %d lowered.---------",
generate_idx,
)

compiled = lowered.compile()
logging.info(
"---------Generate engine %d compiled.---------",
generate_idx,
Expand All @@ -226,35 +192,28 @@ def compile_generate():
"---------Generate compilation %d complete.---------", generate_idx
)

return compiled

logging.info(
"---------Insertion generation compilation %d begun.---------",
generate_idx,
)

generate_executable = compile_generate()
compile_generate()

logging.info(
"---------Generate engine %d compiled generation step.---------",
generate_idx,
)
generate_engine.generate_executable = generate_executable

with concurrent.futures.ThreadPoolExecutor(
max_workers=len(prefill_buckets)
) as executor:
insert_executable = list(executor.map(compile_insert, prefill_buckets))
_ = executor.map(compile_insert, prefill_buckets)

insert_executable = {
k: cast(jax.stages.Compiled, e)
for k, e in zip(prefill_buckets, insert_executable)
}
generate_engine.insert_executable = insert_executable
generate_engine.warm = True

logging.info(
"---------Insertion generation compilation %d complete.---------",
generate_idx,
)

return insert_executable, generate_executable
return generate_engine.warm
20 changes: 3 additions & 17 deletions jetstream/engine/engine_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,7 @@ class JetStreamEngine(Engine):
def __init__(self, downstream_engine: Engine):
self._downstream_engine = downstream_engine

# Executables
self.prefill_executable = None
self.insert_executable = None
self.generate_executable = None

self.prefill_buckets = None

# Nearest right token length
self._padded_token_length = None

self.warm = False

def prefill(
Expand All @@ -284,9 +275,7 @@ def prefill(
true_length: int,
) -> Tuple[Prefix, ResultTokens]:

prefill_result, first_token = self.prefill_executable[
self.padded_token_length
](
prefill_result, first_token = self._downstream_engine.prefill(
params=params,
padded_tokens=padded_tokens,
true_length=true_length,
Expand All @@ -300,7 +289,7 @@ def insert(
slot: int,
) -> DecodeState:

decode_state = self.insert_executable[self.padded_token_length](
decode_state = self._downstream_engine.insert(
prefix=prefix,
decode_state=decode_state,
slot=slot,
Expand All @@ -310,7 +299,7 @@ def insert(
def generate(
self, params: Params, decode_state: DecodeState
) -> Tuple[DecodeState, ResultTokens]:
decode_state, sampled_tokens = self.generate_executable( # pylint: disable=not-callable
decode_state, sampled_tokens = self._downstream_engine.generate(
params=params, decode_state=decode_state
)
return decode_state, sampled_tokens
Expand Down Expand Up @@ -355,6 +344,3 @@ def mesh(self) -> jax.sharding.Mesh:
@property
def colocated_cpus(self) -> Union[list[CpuDevices], None]:
return self._downstream_engine.colocated_cpus

def set_padded_token_length(self, padded_token_length: int):
self.padded_token_length = padded_token_length
Loading