Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manual model warmup to resolve AOT model warmup performance degradation #126

Merged
merged 7 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions jetstream/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ class ActiveRequest:
prefill_result: Any = None
#################### Information relevant for prefill ########################
prefill_content: Optional[str | list[int]] = None
padded_token_length: Optional[int] = None
################## Information relevant for detokenization ###################
# Which generate step this was added at.
generate_timestep_added: Optional[int] = None
Expand Down Expand Up @@ -502,19 +501,13 @@ def _prefill_thread(self, idx: int):
padded_tokens, true_length = self._process_prefill_content(
request, tokenizer, is_bos, prefill_engine.max_prefill_length
)
if isinstance(prefill_engine, engine_api.JetStreamEngine):
request.padded_token_length = token_utils.take_nearest_length(
vivianrwu marked this conversation as resolved.
Show resolved Hide resolved
prefill_engine.prefill_buckets, true_length
)
prefill_engine.set_padded_token_length(request.padded_token_length)

# Compute new kv cache for the prefill_content.
prefill_result, first_token = prefill_engine.prefill(
params=prefill_params,
padded_tokens=padded_tokens,
true_length=true_length,
)

request.prefill_result = prefill_result

# put first token to detokenize queue
Expand Down Expand Up @@ -678,11 +671,6 @@ def _generate_thread(self, idx: int):
generate_timestep,
)

if isinstance(generate_engine, engine_api.JetStreamEngine):
vivianrwu marked this conversation as resolved.
Show resolved Hide resolved
generate_engine.set_padded_token_length(
new_request.padded_token_length
)

decode_state = generate_engine.insert(
new_request.prefill_result, decode_state, slot=slot
)
Expand Down
8 changes: 4 additions & 4 deletions jetstream/core/server_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from jetstream.core import orchestrator
from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
from jetstream.core.proto import jetstream_pb2_grpc
from jetstream.engine import aot_utils, engine_api
from jetstream.engine import warmup_utils, engine_api

from prometheus_client import start_http_server

Expand Down Expand Up @@ -107,7 +107,7 @@ def create_driver(
devices: Device objects, will be used to get engine with proper slicing.
jax_padding: The flag to enable JAX padding during tokenization.
metrics_collector: The JetStream Promethus metric collector.
enable_model_warmup: The flag to enable model server warmup with AOT.
enable_model_warmup: The flag to enable model server warmup.

Returns:
An orchestrator driver.
Expand Down Expand Up @@ -142,7 +142,7 @@ def create_driver(
]

try:
_ = aot_utils.layout_params_and_compile_executables(
_ = warmup_utils.layout_params_and_compile_executables(
prefill_engines, # pylint: disable=protected-access
generate_engines, # pylint: disable=protected-access
prefill_params, # pylint: disable=protected-access
Expand Down Expand Up @@ -191,7 +191,7 @@ def run(
metrics_server_config: The config to enable Promethus metric server.
enable_jax_profiler: The flag to enable JAX profiler server.
jax_profiler_port: The port JAX profiler server (default to 9999).
enable_model_warmup: The flag to enable model server warmup with AOT.
enable_model_warmup: The flag to enable model server warmup.

Returns:
JetStreamServer that wraps the grpc server and orchestrator driver.
Expand Down
22 changes: 4 additions & 18 deletions jetstream/engine/engine_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,22 +257,13 @@ def colocated_cpus(self) -> Union[list[CpuDevices], None]:
class JetStreamEngine(Engine):
"""A wrapper engine of the Engine class.

JetStreamEngine defines the AOT warmed up model server engine.
JetStreamEngine defines the warmed up model server engine.
"""

def __init__(self, downstream_engine: Engine):
self._downstream_engine = downstream_engine

# Executables
self.prefill_executable = None
self.insert_executable = None
self.generate_executable = None

self.prefill_buckets = None

# Nearest right token length
self._padded_token_length = None

self.warm = False

def prefill(
Expand All @@ -284,9 +275,7 @@ def prefill(
true_length: int,
) -> Tuple[Prefix, ResultTokens]:

prefill_result, first_token = self.prefill_executable[
self.padded_token_length
](
prefill_result, first_token = self._downstream_engine.prefill(
params=params,
padded_tokens=padded_tokens,
true_length=true_length,
Expand All @@ -300,7 +289,7 @@ def insert(
slot: int,
) -> DecodeState:

decode_state = self.insert_executable[self.padded_token_length](
decode_state = self._downstream_engine.insert(
prefix=prefix,
decode_state=decode_state,
slot=slot,
Expand All @@ -310,7 +299,7 @@ def insert(
def generate(
self, params: Params, decode_state: DecodeState
) -> Tuple[DecodeState, ResultTokens]:
decode_state, sampled_tokens = self.generate_executable( # pylint: disable=not-callable
decode_state, sampled_tokens = self._downstream_engine.generate(
params=params, decode_state=decode_state
)
return decode_state, sampled_tokens
Expand Down Expand Up @@ -355,6 +344,3 @@ def mesh(self) -> jax.sharding.Mesh:
@property
def colocated_cpus(self) -> Union[list[CpuDevices], None]:
return self._downstream_engine.colocated_cpus

def set_padded_token_length(self, padded_token_length: int):
self.padded_token_length = padded_token_length
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""AOT compilation utils."""
"""Model server warmup utils."""

import jax
import jax.numpy as jnp
import concurrent.futures
from typing import Any, Optional, cast
from typing import Any, Optional
import logging
from jetstream.engine import engine_api, token_utils

Expand All @@ -44,34 +43,30 @@ def layout_params_and_compile_executables(
any_prefill_engine = None
any_prefill_params = None

prefill_executables = []
inserts_generate_executables = []
prefills_compiled = []
inserts_generate_compiled = []

for i, pe in enumerate(prefill_engines):
any_prefill_engine = pe
any_prefill_params = prefill_params[i]
prefill_executable = initialize_prefill_jit_cache(
prefill_compiled = initialize_prefill_jit_cache(
prefill_engine=pe,
prefill_params=prefill_params[i],
prefill_idx=i,
)
prefill_executables.append(prefill_executable)
prefills_compiled.append(prefill_compiled)

for i, ge in enumerate(generate_engines):
insert_executable, generate_executable = (
initialize_insert_generate_jit_cache(
prefill_engine=any_prefill_engine,
generate_engine=ge,
prefill_params=any_prefill_params,
generate_params=generate_params[i],
generate_idx=i,
)
)
inserts_generate_executables.append(
[insert_executable, generate_executable]
insert_generate_compiled = initialize_insert_generate_jit_cache(
prefill_engine=any_prefill_engine,
generate_engine=ge,
prefill_params=any_prefill_params,
generate_params=generate_params[i],
generate_idx=i,
)
inserts_generate_compiled.append([insert_generate_compiled])

if prefill_executables and inserts_generate_executables:
if any(prefills_compiled) and any(inserts_generate_compiled):
vivianrwu marked this conversation as resolved.
Show resolved Hide resolved
return True
return False

Expand Down Expand Up @@ -104,47 +99,32 @@ def initialize_prefill_jit_cache(
def compile_prefill(length):
padded_tokens, true_length = jnp.ones((length), dtype="int32"), length

lowered = jax.jit(
prefill_engine._downstream_engine.prefill, # pylint: disable=protected-access
out_shardings=prefill_engine.get_prefix_destination_sharding(),
).lower(
_, _ = prefill_engine._downstream_engine.prefill( # pylint: disable=protected-access
params=prefill_params,
padded_tokens=padded_tokens,
true_length=true_length,
)
logging.info(
"---------Prefill engine %d lowered for prefill length %d.---------",
prefill_idx,
length,
)
compiled = lowered.compile()

logging.info(
"---------Prefill engine %d compiled for prefill length %d.---------",
prefill_idx,
length,
)
return compiled

logging.info("---------Prefill compilation %d begun.---------", prefill_idx)

with concurrent.futures.ThreadPoolExecutor(
max_workers=len(prefill_buckets)
) as executor:
prefill_executable = list(executor.map(compile_prefill, prefill_buckets))

prefill_executable = {
k: cast(jax.stages.Compiled, e)
for k, e in zip(prefill_buckets, prefill_executable)
}
_ = executor.map(compile_prefill, prefill_buckets)

prefill_engine.prefill_executable = prefill_executable
prefill_engine.warm = True

logging.info(
"---------Prefill compilation %d complete.---------", prefill_idx
)

return prefill_executable
return prefill_engine.warm


def initialize_insert_generate_jit_cache(
Expand Down Expand Up @@ -184,39 +164,25 @@ def compile_insert(length):
true_length=true_length,
)

lowered = jax.jit(generate_engine._downstream_engine.insert).lower( # pylint: disable=protected-access
prefix=prefill, decode_state=decode_state, slot=1
)
logging.info(
"---------Generate engine %d lowered for insert length %d.---------",
generate_idx,
length,
)
compiled = lowered.compile()
generate_engine.insert(prefix=prefill, decode_state=decode_state, slot=0)

logging.info(
"---------Generate engine %d compiled for insert length %d.---------",
generate_idx,
length,
)
return compiled

def compile_generate():

logging.info(
"---------Generate compilation %d begun.---------", generate_idx
)

lowered = jax.jit(generate_engine._downstream_engine.generate).lower( # pylint: disable=protected-access
generate_engine._downstream_engine.generate( # pylint: disable=protected-access
params=generate_params,
decode_state=decode_state,
)
logging.info(
"---------Generate engine %d lowered.---------",
generate_idx,
)

compiled = lowered.compile()
logging.info(
"---------Generate engine %d compiled.---------",
generate_idx,
Expand All @@ -226,35 +192,28 @@ def compile_generate():
"---------Generate compilation %d complete.---------", generate_idx
)

return compiled

logging.info(
"---------Insertion generation compilation %d begun.---------",
generate_idx,
)

generate_executable = compile_generate()
compile_generate()

logging.info(
"---------Generate engine %d compiled generation step.---------",
generate_idx,
)
generate_engine.generate_executable = generate_executable

with concurrent.futures.ThreadPoolExecutor(
max_workers=len(prefill_buckets)
) as executor:
insert_executable = list(executor.map(compile_insert, prefill_buckets))
_ = executor.map(compile_insert, prefill_buckets)

insert_executable = {
k: cast(jax.stages.Compiled, e)
for k, e in zip(prefill_buckets, insert_executable)
}
generate_engine.insert_executable = insert_executable
generate_engine.warm = True

logging.info(
"---------Insertion generation compilation %d complete.---------",
generate_idx,
)

return insert_executable, generate_executable
return generate_engine.warm
Loading