[Refactor] Clean-up Management of Model/Artifact/Engine Info #66

Merged
15 changes: 12 additions & 3 deletions mlc_llm/build.py
@@ -40,8 +40,17 @@ def main():
# Post processing of arguments
parsed_args = core._parse_args(parsed_args) # pylint: disable=protected-access

core.build_model_from_args(parsed_args)


# if num_shards > 1 without --convert-weight-only or --build-model-only, we implicitly run both phases sequentially
if parsed_args.num_shards > 1 and not (parsed_args.build_model_only or parsed_args.convert_weight_only):
parsed_args.build_model_only = True
parsed_args.convert_weight_only = False # just to be explicit
core.build_model_from_args(parsed_args)

parsed_args.build_model_only = False
parsed_args.convert_weight_only = True
core.build_model_from_args(parsed_args)
else:
core.build_model_from_args(parsed_args)

if __name__ == "__main__":
main()
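Note on the build.py change above: when num_shards > 1 and neither --build-model-only nor --convert-weight-only is passed, the script now chains the two phases itself. A minimal sketch of the equivalent explicit flow (the helper name and the `core` import are illustrative assumptions, not part of this PR):

```python
import copy

from mlc_llm import core  # assumed to mirror build.py's existing import


def build_then_convert(parsed_args):
    """Sketch: run the two phases that build.py now chains implicitly."""
    # Phase 1: trace and compile the model library only.
    build_args = copy.copy(parsed_args)
    build_args.build_model_only = True
    build_args.convert_weight_only = False
    core.build_model_from_args(build_args)

    # Phase 2: convert (and, when sharded, pre-shard) the weights only.
    convert_args = copy.copy(parsed_args)
    convert_args.build_model_only = False
    convert_args.convert_weight_only = True
    core.build_model_from_args(convert_args)
```

Copying the namespace keeps the caller's flags untouched; the in-place mutation in build.py is acceptable because `parsed_args` is not reused afterwards.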
105 changes: 79 additions & 26 deletions mlc_llm/core.py
@@ -4,6 +4,7 @@
import json
import os
import pickle
import shutil
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Optional

@@ -262,6 +263,15 @@ class BuildArgs:
"action": "store_true",
},
)
no_cache_dump: bool = field(
default=False,
metadata={
"help": (
"Disable dumping `mod_cache_before_build.pkl`. When this flag is set, the cached build will not be available."
),
"action": "store_true",
},
)
use_cuda_graph: bool = field(
default=False,
metadata={
@@ -383,6 +393,8 @@ def _parse_args(parsed) -> argparse.Namespace:
model_name.append(f"presharded-{parsed.num_shards}gpu")

parsed.artifact_path = os.path.join(parsed.artifact_path, "-".join(model_name))
parsed.lib_name = f"{parsed.model}-{parsed.quantization.name}-{parsed.target_kind}.{parsed.lib_format}"
parsed.lib_path = os.path.join(parsed.artifact_path, parsed.lib_name)

return parsed

@@ -590,6 +602,18 @@ def mod_transform_before_build(

return mod_deploy

def dump_build_config(
args: argparse.Namespace
):
build_config_path = os.path.join(args.artifact_path, "build_config.json")
config: Dict[str, Any] = {
"num_shards": args.num_shards,
"quantization": args.quantization.name,
"library_name": args.lib_name,
"build_options": str(args)
}
with open(build_config_path, "w", encoding="utf-8") as outfile:
json.dump(config, outfile, indent=4)

def dump_mlc_chat_config(
args: argparse.Namespace,
@@ -689,10 +713,7 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
mod_deploy["decode"] = mod_deploy["decode"].with_attr({"num_input": 3})
ex = relax.build(mod_deploy, args.target, system_lib=args.system_lib)

output_filename = f"{args.model}-{args.quantization.name}-{target_kind}.{args.lib_format}"

utils.debug_dump_shader(ex, f"{args.model}_{args.quantization.name}_{target_kind}", args)
args.lib_path = os.path.join(args.artifact_path, output_filename)
ex.export_library(args.lib_path, **args.export_kwargs)
print(f"Finish exporting to {args.lib_path}")

@@ -732,7 +753,7 @@ def build_model_from_args(args: argparse.Namespace):
with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f:
config = json.load(i_f)

if not use_cache or args.convert_weight_only:
if not use_cache or args.convert_weight_only or not os.path.exists(cache_path):
model_generators = {
"llama": llama,
"mistral": mistral,
@@ -824,25 +845,56 @@

utils.save_params(params, args.artifact_path, args.num_shards if args.use_presharded_weights else 1)

if args.model_category != "minigpt":
utils.copy_tokenizer(args)
if args.model_category == "rwkv" or args.model_category == "rwkv_world":
# TODO: refactor config into model definition
dump_mlc_chat_config(
args,
vocab_size=config["vocab_size"],
max_window_size=model_config.max_sequence_length,
top_p=0.6,
temperature=1.2,
repetition_penalty=0.996,
rwkv_world=True,
)
else:
dump_mlc_chat_config(
args,
vocab_size=config["vocab_size"],
max_window_size=model_config.max_sequence_length,
)
if not args.enable_batching:
if args.model_category == "rwkv" or args.model_category == "rwkv_world":
# TODO: refactor config into model definition
dump_mlc_chat_config(
args,
vocab_size=config["vocab_size"],
max_window_size=model_config.max_sequence_length,
top_p=0.6,
temperature=1.2,
repetition_penalty=0.996,
rwkv_world=True,
)
else:
dump_mlc_chat_config(
args,
vocab_size=config["vocab_size"],
max_window_size=model_config.max_sequence_length,
)

if args.enable_batching:
# when batching is enabled, we dump info for mlc_serve runtime
dump_build_config(args)
model_info_path = os.path.join(args.artifact_path, "model")
os.makedirs(model_info_path, exist_ok=True)
mlc_model_config_path = os.path.join(model_info_path, "mlc-model-config.json")

max_context_length = args.max_seq_len
if args.max_seq_len == -1:
# for llama-1 family
if "max_sequence_length" in config:
max_context_length = config["max_sequence_length"]
# for llama-2, mistral, etc.
elif "max_position_embeddings" in config:
max_context_length = config["max_position_embeddings"]
else:
raise Exception("The model config should contain information about maximum context length.")

# Overwrite some configs
config["max_context_length"] = max_context_length
if args.sliding_window != -1 and "sliding_window" in config:
config["sliding_window"] = args.sliding_window

# copy hf config into mlc_model_config
mlc_model_config = config.copy()

with open(mlc_model_config_path, "w", encoding="utf-8") as outfile:
json.dump(mlc_model_config, outfile, indent=4)

if args.model_category != "minigpt":
utils.copy_tokenizer(args)

if args.convert_weight_only:
exit(0)
@@ -856,9 +908,10 @@ def build_model_from_args(args: argparse.Namespace):
sharding_module = create_shard_info_func(param_manager, args, model_config)
mod.update(sharding_module)

with open(cache_path, "wb") as outfile:
pickle.dump(mod, outfile)
print(f"Save a cached module to {cache_path}.")
if not args.no_cache_dump:
with open(cache_path, "wb") as outfile:
pickle.dump(mod, outfile)
print(f"Save a cached module to {cache_path}.")
else:
print(
f"Load cached module from {cache_path} and skip tracing. "
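Taken together, the core.py changes mean a batching-enabled build now dumps `build_config.json` at the artifact root and `model/mlc-model-config.json` (the HF config plus the resolved `max_context_length`, and `sliding_window` when overridden) next to the copied tokenizer files. A minimal sketch of how a downstream consumer such as the mlc_serve runtime might read these back; the helper name and example path are illustrative, not part of this PR:

```python
import json
import os
from typing import Any, Dict, Tuple


def load_artifact_info(artifact_path: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Read the two JSON files that core.py dumps for batching-enabled builds."""
    with open(os.path.join(artifact_path, "build_config.json"), encoding="utf-8") as f:
        # keys: num_shards, quantization, library_name, build_options
        build_config = json.load(f)
    model_config_path = os.path.join(artifact_path, "model", "mlc-model-config.json")
    with open(model_config_path, encoding="utf-8") as f:
        # HF config plus max_context_length (and possibly sliding_window)
        model_config = json.load(f)
    return build_config, model_config


# Example usage (the artifact directory name is illustrative):
# build_cfg, model_cfg = load_artifact_info("dist/Llama-2-7b-chat-hf-q4f16_1")
# lib_path = os.path.join("dist/Llama-2-7b-chat-hf-q4f16_1", build_cfg["library_name"])
```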
2 changes: 1 addition & 1 deletion mlc_llm/utils.py
@@ -332,7 +332,7 @@ def copy_tokenizer(args: argparse.Namespace) -> None:
]:
shutil.copy(
os.path.join(args.model_path, filename),
os.path.join(args.artifact_path, "params"),
os.path.join(args.artifact_path, "model") if args.enable_batching else os.path.join(args.artifact_path, "params"),
)


60 changes: 27 additions & 33 deletions serve/benchmarks/benchmark_throughput.py
@@ -2,7 +2,7 @@
import argparse
import json
import random
import time
import time, os
from typing import List, Tuple

import pandas as pd
@@ -12,6 +12,7 @@
Request,
SamplingParams,
StoppingCriteria,
get_engine_config
)
from mlc_serve.engine.staging_engine import StagingInferenceEngine
from mlc_serve.engine.sync_engine import SynchronousInferenceEngine
@@ -98,54 +99,46 @@ def run_mlc(
def create_engine_and_tokenizer_module(
args: argparse.Namespace,
):
engine_config = get_engine_config({
"use_staging_engine": args.use_staging_engine,
"max_num_batched_tokens": args.max_num_batched_tokens,
"max_input_len": args.max_input_len,
"min_decode_steps": args.min_decode_steps,
"max_decode_steps": args.max_decode_steps,
"prompt_allocate_ratio": args.prompt_allocate_ratio
})

if args.use_staging_engine:
tokenizer_module = HfTokenizerModule(args.model, args.artifact_path)
engine = StagingInferenceEngine(
tokenizer_module=tokenizer_module,
tokenizer_module=HfTokenizerModule(args.model_artifact_path),
model_module_loader=PagedCacheModelModule,
model_module_loader_kwargs={
"model_name": args.model,
"artifact_path": args.artifact_path,
"quantization": args.quantization.name,
"num_shards": args.num_shards,
"max_num_batched_tokens": args.max_num_batched_tokens,
"max_input_len": args.max_input_len,
"model_artifact_path": args.model_artifact_path,
"engine_config": engine_config,
},
max_batched_tokens=args.max_num_batched_tokens,
min_decode_steps=args.min_decode_steps,
max_decode_steps=args.max_decode_steps,
)
engine.start()
tokenizer = engine.tokenizer
else:
model_module = PagedCacheModelModule(
args.model,
args.artifact_path,
args.quantization.name,
args.num_shards,
max_num_batched_tokens=args.max_num_batched_tokens,
max_input_len=args.max_input_len,
)
tokenizer_module = model_module

engine = SynchronousInferenceEngine(
model_module,
max_batched_tokens=args.max_num_batched_tokens,
min_decode_steps=args.min_decode_steps,
max_decode_steps=args.max_decode_steps,
)
PagedCacheModelModule(
model_artifact_path = args.model_artifact_path,
engine_config = engine_config,
))
tokenizer = engine.tokenizer

return engine, tokenizer_module
return engine, tokenizer


def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)

engine, tokenizer_module = create_engine_and_tokenizer_module(args)
engine, tokenizer = create_engine_and_tokenizer_module(args)

# Sample the requests.
requests = sample_requests(
args.dataset, args.num_prompts, tokenizer_module.tokenizer._tokenizer
args.dataset, args.num_prompts, tokenizer._tokenizer
)

elapsed_time = run_mlc(
@@ -185,12 +178,12 @@ def main(args: argparse.Namespace):
)
parser.add_argument("--local-id", type=str, required=True)
parser.add_argument("--artifact-path", type=str, default="dist")
parser.add_argument("--num-shards", type=int, default=1)
parser.add_argument("--use-staging-engine", action="store_true")
parser.add_argument("--max-num-batched-tokens", type=int, default=-1)
parser.add_argument("--max-input-len", type=int, default=-1)
parser.add_argument("--min-decode-steps", type=int, default=32)
parser.add_argument("--max-decode-steps", type=int, default=56)
parser.add_argument("--prompt-allocate-ratio", type=float, default=2.0)
parser.add_argument(
"--num-prompts", type=int, default=1000, help="Number of prompts to process."
)
@@ -205,7 +198,8 @@
args = parser.parse_args()
setup_logging(args)

args.model, args.quantization = args.local_id.rsplit("-", 1)
utils.argparse_postproc_common(args)
args.model_artifact_path = os.path.join(args.artifact_path, args.local_id)
if not os.path.exists(args.model_artifact_path):
raise Exception(f"Invalid local id: {args.local_id}")

main(args)
2 changes: 2 additions & 0 deletions serve/mlc_serve/engine/__init__.py
@@ -10,5 +10,7 @@
RequestId,
RequestOutput,
StoppingCriteria,
MLCServeEngineConfig,
get_engine_config
)
from .sampling_params import SamplingParams, SamplingType
46 changes: 43 additions & 3 deletions serve/mlc_serve/engine/base.py
@@ -1,11 +1,51 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

from typing import Optional, List
import json
import inspect
from .sampling_params import SamplingParams, SamplingType

RequestId = str

# TODO(@sunggg): consider transition to something like Pydantic
@dataclass
class MLCServeEngineConfig:
# TODO(@sunggg): figure out better defaults
use_staging_engine: bool = True
# The maximum number of tokens in the batch.
max_num_batched_tokens: int = 4096
max_input_len: int = 512
max_num_sequences: int = 8
min_decode_steps: int = 32
max_decode_steps: int = 48
prompt_allocate_ratio: float = 2.0

@classmethod
def _from_json(config_cls, json_obj: dict):
return config_cls(
**{
k: v
for k, v in json_obj.items()
if k in inspect.signature(config_cls).parameters
}
)

def get_engine_config(dict_config, enable_check = True):
engine_config = MLCServeEngineConfig._from_json(dict_config)
# Checks to make sure engine configs are set correctly
# since engine config is critical to the performance
if enable_check:
# TODO(@sunggg): engine allows -1 for these params. figure out the behavior and enable checks properly
# assert engine_config.max_num_batched_tokens > 0
# assert engine_config.max_input_len > 0
# assert engine_config.max_num_sequences > 0
# assert engine_config.max_num_sequences * engine_config.max_input_len == engine_config.max_num_batched_tokens

assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert engine_config.max_decode_steps > engine_config.min_decode_steps
assert engine_config.prompt_allocate_ratio > 0

return engine_config

@dataclass
class StoppingCriteria:
@@ -14,7 +54,7 @@ class StoppingCriteria:
"""

max_tokens: Optional[int]
stop_sequences: Optional[list[str]]
stop_sequences: Optional[list[str]] = None


@dataclass
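For reference, `get_engine_config` builds an `MLCServeEngineConfig` by keeping only the keys that match the dataclass fields, filling everything else from the defaults, and then sanity-checking the decode-step and allocation settings. A small usage sketch (the dictionary values are illustrative):

```python
from mlc_serve.engine import get_engine_config

# Keys that are not MLCServeEngineConfig fields (e.g. "model") are dropped by
# _from_json; omitted fields fall back to the dataclass defaults.
engine_config = get_engine_config({
    "use_staging_engine": True,
    "max_num_batched_tokens": 4096,
    "min_decode_steps": 32,
    "max_decode_steps": 56,
    "model": "Llama-2-7b-chat-hf-q4f16_1",  # ignored by the engine config
})

assert engine_config.max_num_sequences == 8  # default preserved
assert engine_config.max_decode_steps > engine_config.min_decode_steps
```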