[Refactor] Clean-up Management of Model/Artifact/Engine Info (#66)
* wip

* works

* fix

* reflect feedback and add checks for engine configs
sunggg authored Nov 16, 2023
1 parent 33c5a88 commit 858a444
Showing 15 changed files with 370 additions and 274 deletions.
15 changes: 12 additions & 3 deletions mlc_llm/build.py
@@ -40,8 +40,17 @@ def main():
# Post processing of arguments
parsed_args = core._parse_args(parsed_args) # pylint: disable=protected-access

core.build_model_from_args(parsed_args)


# If num_shards > 1 and neither --build-model-only nor --convert-weight-only is given, we implicitly run both steps sequentially (build the model library, then convert the weights).
if parsed_args.num_shards > 1 and not (parsed_args.build_model_only or parsed_args.convert_weight_only):
parsed_args.build_model_only = True
parsed_args.convert_weight_only = False # just to be explicit
core.build_model_from_args(parsed_args)

parsed_args.build_model_only = False
parsed_args.convert_weight_only = True
core.build_model_from_args(parsed_args)
else:
core.build_model_from_args(parsed_args)

if __name__ == "__main__":
main()
105 changes: 79 additions & 26 deletions mlc_llm/core.py
Expand Up @@ -4,6 +4,7 @@
import json
import os
import pickle
import shutil
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Optional

@@ -262,6 +263,15 @@ class BuildArgs:
"action": "store_true",
},
)
no_cache_dump: bool = field(
default=False,
metadata={
"help": (
"Disable dumping `mod_cache_before_build.pkl`. When this flag is set, cached build would not be available."
),
"action": "store_true",
},
)
use_cuda_graph: bool = field(
default=False,
metadata={
@@ -383,6 +393,8 @@ def _parse_args(parsed) -> argparse.Namespace:
model_name.append(f"presharded-{parsed.num_shards}gpu")

parsed.artifact_path = os.path.join(parsed.artifact_path, "-".join(model_name))
parsed.lib_name = f"{parsed.model}-{parsed.quantization.name}-{parsed.target_kind}.{parsed.lib_format}"
parsed.lib_path = os.path.join(parsed.artifact_path, parsed.lib_name)

return parsed

@@ -590,6 +602,18 @@ def mod_transform_before_build(

return mod_deploy

def dump_build_config(
args: argparse.Namespace
):
build_config_path = os.path.join(args.artifact_path, "build_config.json")
config: Dict[str, Any] = {
"num_shards": args.num_shards,
"quantization": args.quantization.name,
"library_name": args.lib_name,
"build_options": str(args)
}
with open(build_config_path, "w", encoding="utf-8") as outfile:
json.dump(config, outfile, indent=4)

def dump_mlc_chat_config(
args: argparse.Namespace,
@@ -689,10 +713,7 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
mod_deploy["decode"] = mod_deploy["decode"].with_attr({"num_input": 3})
ex = relax.build(mod_deploy, args.target, system_lib=args.system_lib)

output_filename = f"{args.model}-{args.quantization.name}-{target_kind}.{args.lib_format}"

utils.debug_dump_shader(ex, f"{args.model}_{args.quantization.name}_{target_kind}", args)
args.lib_path = os.path.join(args.artifact_path, output_filename)
ex.export_library(args.lib_path, **args.export_kwargs)
print(f"Finish exporting to {args.lib_path}")

@@ -732,7 +753,7 @@ def build_model_from_args(args: argparse.Namespace):
with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f:
config = json.load(i_f)

if not use_cache or args.convert_weight_only:
if not use_cache or args.convert_weight_only or not os.path.exists(cache_path):
model_generators = {
"llama": llama,
"mistral": mistral,
@@ -824,25 +845,56 @@ def build_model_from_args(args: argparse.Namespace):

utils.save_params(params, args.artifact_path, args.num_shards if args.use_presharded_weights else 1)

if args.model_category != "minigpt":
utils.copy_tokenizer(args)
if args.model_category == "rwkv" or args.model_category == "rwkv_world":
# TODO: refactor config into model definition
dump_mlc_chat_config(
args,
vocab_size=config["vocab_size"],
max_window_size=model_config.max_sequence_length,
top_p=0.6,
temperature=1.2,
repetition_penalty=0.996,
rwkv_world=True,
)
else:
dump_mlc_chat_config(
args,
vocab_size=config["vocab_size"],
max_window_size=model_config.max_sequence_length,
)
if not args.enable_batching:
if args.model_category == "rwkv" or args.model_category == "rwkv_world":
# TODO: refactor config into model definition
dump_mlc_chat_config(
args,
vocab_size=config["vocab_size"],
max_window_size=model_config.max_sequence_length,
top_p=0.6,
temperature=1.2,
repetition_penalty=0.996,
rwkv_world=True,
)
else:
dump_mlc_chat_config(
args,
vocab_size=config["vocab_size"],
max_window_size=model_config.max_sequence_length,
)

if args.enable_batching:
# when batching is enabled, we dump info for mlc_serve runtime
dump_build_config(args)
model_info_path = os.path.join(args.artifact_path, "model")
os.makedirs(model_info_path, exist_ok=True)
mlc_model_config_path = os.path.join(model_info_path, "mlc-model-config.json")

max_context_length = args.max_seq_len
if args.max_seq_len == -1:
# for llama-1 family
if "max_sequence_length" in config:
max_context_length = config["max_sequence_length"]
# for llama-2, mistral, etc.
elif "max_position_embeddings" in config:
max_context_length = config["max_position_embeddings"]
else:
raise Exception("The model config should contain information about maximum context length.")

# Overwrite some configs
config["max_context_length"] = max_context_length
if args.sliding_window != -1 and "sliding_window" in config:
config["sliding_window"] = args.sliding_window

# copy hf config into mlc_model_config
mlc_model_config = config.copy()

with open(mlc_model_config_path, "w", encoding="utf-8") as outfile:
json.dump(mlc_model_config, outfile, indent=4)

if args.model_category != "minigpt":
utils.copy_tokenizer(args)

if args.convert_weight_only:
exit(0)
@@ -856,9 +908,10 @@ def build_model_from_args(args: argparse.Namespace):
sharding_module = create_shard_info_func(param_manager, args, model_config)
mod.update(sharding_module)

with open(cache_path, "wb") as outfile:
pickle.dump(mod, outfile)
print(f"Save a cached module to {cache_path}.")
if not args.no_cache_dump:
with open(cache_path, "wb") as outfile:
pickle.dump(mod, outfile)
print(f"Save a cached module to {cache_path}.")
else:
print(
f"Load cached module from {cache_path} and skip tracing. "
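For illustration, a minimal sketch of how a downstream consumer might read the two metadata files that the batching path above writes: build_config.json at the artifact root (from dump_build_config) and model/mlc-model-config.json. The file names mirror this commit; the artifact path and printed keys below are illustrative assumptions, not part of the change itself.

# Sketch: read the metadata emitted by the batching build path above.
# The artifact path is hypothetical; only the file names come from this commit.
import json
import os

artifact_path = "dist/example-model-q4f16_1"  # hypothetical local artifact directory

with open(os.path.join(artifact_path, "build_config.json"), encoding="utf-8") as f:
    build_config = json.load(f)
print(build_config["num_shards"], build_config["quantization"], build_config["library_name"])

with open(os.path.join(artifact_path, "model", "mlc-model-config.json"), encoding="utf-8") as f:
    mlc_model_config = json.load(f)
print(mlc_model_config["max_context_length"])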
2 changes: 1 addition & 1 deletion mlc_llm/utils.py
@@ -332,7 +332,7 @@ def copy_tokenizer(args: argparse.Namespace) -> None:
]:
shutil.copy(
os.path.join(args.model_path, filename),
os.path.join(args.artifact_path, "params"),
os.path.join(args.artifact_path, "model") if args.enable_batching else os.path.join(args.artifact_path, "params"),
)


60 changes: 27 additions & 33 deletions serve/benchmarks/benchmark_throughput.py
@@ -2,7 +2,7 @@
import argparse
import json
import random
import time
import time, os
from typing import List, Tuple

import pandas as pd
@@ -12,6 +12,7 @@
Request,
SamplingParams,
StoppingCriteria,
get_engine_config
)
from mlc_serve.engine.staging_engine import StagingInferenceEngine
from mlc_serve.engine.sync_engine import SynchronousInferenceEngine
@@ -98,54 +99,46 @@ def run_mlc(
def create_engine_and_tokenizer_module(
args: argparse.Namespace,
):
engine_config = get_engine_config({
"use_staging_engine": args.use_staging_engine,
"max_num_batched_tokens": args.max_num_batched_tokens,
"max_input_len": args.max_input_len,
"min_decode_steps": args.min_decode_steps,
"max_decode_steps": args.max_decode_steps,
"prompt_allocate_ratio": args.prompt_allocate_ratio
})

if args.use_staging_engine:
tokenizer_module = HfTokenizerModule(args.model, args.artifact_path)
engine = StagingInferenceEngine(
tokenizer_module=tokenizer_module,
tokenizer_module=HfTokenizerModule(args.model_artifact_path),
model_module_loader=PagedCacheModelModule,
model_module_loader_kwargs={
"model_name": args.model,
"artifact_path": args.artifact_path,
"quantization": args.quantization.name,
"num_shards": args.num_shards,
"max_num_batched_tokens": args.max_num_batched_tokens,
"max_input_len": args.max_input_len,
"model_artifact_path": args.model_artifact_path,
"engine_config": engine_config,
},
max_batched_tokens=args.max_num_batched_tokens,
min_decode_steps=args.min_decode_steps,
max_decode_steps=args.max_decode_steps,
)
engine.start()
tokenizer = engine.tokenizer
else:
model_module = PagedCacheModelModule(
args.model,
args.artifact_path,
args.quantization.name,
args.num_shards,
max_num_batched_tokens=args.max_num_batched_tokens,
max_input_len=args.max_input_len,
)
tokenizer_module = model_module

engine = SynchronousInferenceEngine(
model_module,
max_batched_tokens=args.max_num_batched_tokens,
min_decode_steps=args.min_decode_steps,
max_decode_steps=args.max_decode_steps,
)
PagedCacheModelModule(
model_artifact_path = args.model_artifact_path,
engine_config = engine_config,
))
tokenizer = engine.tokenizer

return engine, tokenizer_module
return engine, tokenizer


def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)

engine, tokenizer_module = create_engine_and_tokenizer_module(args)
engine, tokenizer = create_engine_and_tokenizer_module(args)

# Sample the requests.
requests = sample_requests(
args.dataset, args.num_prompts, tokenizer_module.tokenizer._tokenizer
args.dataset, args.num_prompts, tokenizer._tokenizer
)

elapsed_time = run_mlc(
@@ -185,12 +178,12 @@ def main(args: argparse.Namespace):
)
parser.add_argument("--local-id", type=str, required=True)
parser.add_argument("--artifact-path", type=str, default="dist")
parser.add_argument("--num-shards", type=int, default=1)
parser.add_argument("--use-staging-engine", action="store_true")
parser.add_argument("--max-num-batched-tokens", type=int, default=-1)
parser.add_argument("--max-input-len", type=int, default=-1)
parser.add_argument("--min-decode-steps", type=int, default=32)
parser.add_argument("--max-decode-steps", type=int, default=56)
parser.add_argument("--prompt-allocate-ratio", type=float, default=2.0)
parser.add_argument(
"--num-prompts", type=int, default=1000, help="Number of prompts to process."
)
@@ -205,7 +198,8 @@
args = parser.parse_args()
setup_logging(args)

args.model, args.quantization = args.local_id.rsplit("-", 1)
utils.argparse_postproc_common(args)
args.model_artifact_path = os.path.join(args.artifact_path, args.local_id)
if not os.path.exists(args.model_artifact_path):
raise Exception(f"Invalid local id: {args.local_id}")

main(args)
2 changes: 2 additions & 0 deletions serve/mlc_serve/engine/__init__.py
@@ -10,5 +10,7 @@
RequestId,
RequestOutput,
StoppingCriteria,
MLCServeEngineConfig,
get_engine_config
)
from .sampling_params import SamplingParams, SamplingType
46 changes: 43 additions & 3 deletions serve/mlc_serve/engine/base.py
@@ -1,11 +1,51 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

from typing import Optional, List
import json
import inspect
from .sampling_params import SamplingParams, SamplingType

RequestId = str

# TODO(@sunggg): consider transition to something like Pydantic
@dataclass
class MLCServeEngineConfig:
# The maximum number of tokens in the batch.
# TODO(@sunggg): figure out better defaults
use_staging_engine: bool = True
max_num_batched_tokens: int = 4096
max_input_len: int = 512
max_num_sequences: int = 8
min_decode_steps: int = 32
max_decode_steps: int = 48
prompt_allocate_ratio: float = 2.0

@classmethod
def _from_json(config_cls, json_obj: dict):
return config_cls(
**{
k: v
for k, v in json_obj.items()
if k in inspect.signature(config_cls).parameters
}
)

def get_engine_config(dict_config, enable_check = True):
engine_config = MLCServeEngineConfig._from_json(dict_config)
# Checks to make sure engine configs are set correctly
# since engine config is critical to the performance
if enable_check:
# TODO(@sunggg): engine allows -1 for these params. figure out the behavior and enable checks properly
# assert engine_config.max_num_batched_tokens > 0
# assert engine_config.max_input_len > 0
# assert engine_config.max_num_sequences > 0
# assert engine_config.max_num_sequences * engine_config.max_input_len == engine_config.max_num_batched_tokens

assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert engine_config.max_decode_steps > engine_config.min_decode_steps
assert engine_config.prompt_allocate_ratio > 0

return engine_config

@dataclass
class StoppingCriteria:
@@ -14,7 +54,7 @@ class StoppingCriteria:
"""

max_tokens: Optional[int]
stop_sequences: Optional[list[str]]
stop_sequences: Optional[list[str]] = None


@dataclass
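For illustration, a minimal usage sketch of the new get_engine_config helper exported above; the values are arbitrary examples that satisfy the enabled checks (max_decode_steps > min_decode_steps > 0, prompt_allocate_ratio > 0), and keys not present on MLCServeEngineConfig are ignored by _from_json. This mirrors how serve/benchmarks/benchmark_throughput.py constructs its engine config in this commit.

from mlc_serve.engine import get_engine_config

# Arbitrary example values; unknown keys would simply be dropped by _from_json.
engine_config = get_engine_config({
    "use_staging_engine": True,
    "max_num_batched_tokens": 4096,
    "max_input_len": 512,
    "min_decode_steps": 32,
    "max_decode_steps": 48,
    "prompt_allocate_ratio": 2.0,
})
print(engine_config)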