diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml deleted file mode 100644 index 478f75e8fd..0000000000 --- a/.github/workflows/lint.yml +++ /dev/null @@ -1,87 +0,0 @@ -name: Lint -on: [push, pull_request] -env: - IMAGE: 'mlcaidev/ci-cpu:caab922' - -jobs: - isort: - name: Python / isort - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - submodules: '' - - name: Version - run: | - wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh - chmod u+x ./ci/bash.sh - ./ci/bash.sh $IMAGE "conda env export --name ci-lint" - - name: Lint - run: | - ./ci/bash.sh $IMAGE bash ./ci/task/isort.sh - - black: - name: Python / black - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - submodules: '' - - name: Version - run: | - wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh - chmod u+x ./ci/bash.sh - ./ci/bash.sh $IMAGE "conda env export --name ci-lint" - - name: Lint - run: | - ./ci/bash.sh $IMAGE bash ./ci/task/black.sh - - mypy: - name: Python / mypy - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - submodules: '' - - name: Version - run: | - wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh - chmod u+x ./ci/bash.sh - ./ci/bash.sh $IMAGE "conda env export --name ci-lint" - - name: Lint - run: | - ./ci/bash.sh $IMAGE bash ./ci/task/mypy.sh - - pylint: - name: Python / pylint - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - submodules: '' - - name: Version - run: | - wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh - chmod u+x ./ci/bash.sh - ./ci/bash.sh $IMAGE "conda env export --name ci-lint" - - name: Lint - run: | - ./ci/bash.sh $IMAGE bash ./ci/task/pylint.sh - - clang-format: - name: C++ / clang-format - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - submodules: '' - ref: ${{ github.event.pull_request.head.sha }} - fetch-depth: 0 - - name: Version - run: | - wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh - chmod u+x ./ci/bash.sh - ./ci/bash.sh $IMAGE "conda env export --name ci-lint" - - name: Lint - run: | - ./ci/bash.sh $IMAGE bash ./ci/task/clang-format.sh diff --git a/ci/jenkinsfile.groovy b/ci/jenkinsfile.groovy new file mode 100644 index 0000000000..ed34d874c2 --- /dev/null +++ b/ci/jenkinsfile.groovy @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
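Both the retired GitHub Actions jobs and the new Jenkins stages drive the same entry point: ci/bash.sh wraps a docker run of the mlcaidev/ci-cpu image and executes one ci/task/*.sh script inside it. A rough local-reproduction sketch in Python (illustrative only; it assumes docker is installed and ci/bash.sh is present in the checkout, which the GitHub jobs used to download first):

```python
import subprocess

IMAGE = "mlcaidev/ci-cpu:caab922"

def run_lint_task(task_script: str) -> None:
    """Run one ci/task/*.sh lint script inside the CI docker image, roughly as
    both the old GitHub Actions jobs and the new Jenkins stages do."""
    # Log the ci-lint conda environment, then execute the task itself.
    subprocess.run(["bash", "ci/bash.sh", IMAGE,
                    "conda", "env", "export", "--name", "ci-lint"], check=True)
    subprocess.run(["bash", "ci/bash.sh", IMAGE, "bash", task_script], check=True)

if __name__ == "__main__":
    for task in ("ci/task/isort.sh", "ci/task/black.sh", "ci/task/mypy.sh",
                 "ci/task/pylint.sh", "ci/task/clang-format.sh"):
        run_lint_task(task)
```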
+ +import org.jenkinsci.plugins.pipeline.modeldefinition.Utils + +image = 'mlcaidev/ci-cpu:caab922' +docker_run = "bash ci/bash.sh ${image}" + +def per_exec_ws(folder) { + return "workspace/exec_${env.EXECUTOR_NUMBER}/" + folder +} + +def init_git(submodule = false) { + checkout scm + if (submodule) { + retry(5) { + timeout(time: 2, unit: 'MINUTES') { + sh(script: 'git submodule update --init --recursive -f', label: 'Update git submodules') + } + } + } +} + +stage('Lint') { + parallel( + 'isort': { + node('CPU-SMALL') { + ws(per_exec_ws('mlc-llm-lint-isort')) { + init_git() + sh(script: "ls", label: 'debug') + sh(script: "${docker_run} conda env export --name ci-lint", label: 'Checkout version') + sh(script: "${docker_run} bash ci/task/isort.sh", label: 'Lint') + } + } + }, + 'black': { + node('CPU-SMALL') { + ws(per_exec_ws('mlc-llm-lint-black')) { + init_git() + sh(script: "ls", label: 'debug') + sh(script: "${docker_run} conda env export --name ci-lint", label: 'Checkout version') + sh(script: "${docker_run} bash ci/task/black.sh", label: 'Lint') + } + } + }, + 'mypy': { + node('CPU-SMALL') { + ws(per_exec_ws('mlc-llm-lint-mypy')) { + init_git() + sh(script: "ls", label: 'debug') + sh(script: "${docker_run} conda env export --name ci-lint", label: 'Checkout version') + sh(script: "${docker_run} bash ci/task/mypy.sh", label: 'Lint') + } + } + }, + 'pylint': { + node('CPU-SMALL') { + ws(per_exec_ws('mlc-llm-lint-pylint')) { + init_git() + sh(script: "ls", label: 'debug') + sh(script: "${docker_run} conda env export --name ci-lint", label: 'Checkout version') + sh(script: "${docker_run} bash ci/task/pylint.sh", label: 'Lint') + } + } + }, + 'clang-format': { + node('CPU-SMALL') { + ws(per_exec_ws('mlc-llm-lint-clang-format')) { + init_git() + sh(script: "ls", label: 'debug') + sh(script: "${docker_run} conda env export --name ci-lint", label: 'Checkout version') + sh(script: "${docker_run} bash ci/task/clang-format.sh", label: 'Lint') + } + } + }, + ) +} \ No newline at end of file diff --git a/ci/task/mypy.sh b/ci/task/mypy.sh index 52da13da5f..95753c2dee 100755 --- a/ci/task/mypy.sh +++ b/ci/task/mypy.sh @@ -8,4 +8,4 @@ export PYTHONPATH="./python:$PYTHONPATH" set -x -mypy ./python/ ./tests/python/ +mypy --install-types --non-interactive ./python/ ./tests/python/ diff --git a/ci/task/pylint.sh b/ci/task/pylint.sh index 7d2a0d326b..fb07ba6087 100755 --- a/ci/task/pylint.sh +++ b/ci/task/pylint.sh @@ -9,7 +9,7 @@ export PYTHONPATH="./python:$PYTHONPATH" set -x # TVM Unity is a dependency to this testing -pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly +pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly requests pylint --jobs $NUM_THREADS ./python/ pylint --jobs $NUM_THREADS --recursive=y ./tests/python/ diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 1255c18bcc..2e9d4868d5 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -317,25 +317,33 @@ class LLMChat { return os.str(); } - bool UpdateMaxWindowSizeFromMetadata() { + void UpdateConfigFromMetadata() { if (ft_.use_disco) { - return false; - } - if (this->sliding_window_ != -1) { - return false; + return; } + PackedFunc fget_metadata = ft_.mod_get_func("get_metadata"); if (fget_metadata == nullptr) { - return false; + return; } ObjectRef ret = fget_metadata(); std::string metadata_str = std::string(Downcast(ret)); picojson::value metadata_info; picojson::parse(metadata_info, std::string(metadata_str)); auto metadata = metadata_info.get(); + ICHECK(metadata["max_window_size"].is()); 
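UpdateConfigFromMetadata now folds three fields out of the module's get_metadata JSON into the runtime settings instead of only max_window_size. A minimal Python paraphrase of the merge (key names taken from create_metadata_func later in this patch; the clamping with min mirrors the std::min calls in this function):

```python
import json

def update_config_from_metadata(cfg: dict, metadata_str: str) -> dict:
    """Paraphrase of LLMChat::UpdateConfigFromMetadata: values coming from
    mlc-chat-config.json are clamped by what the compiled module reports."""
    metadata = json.loads(metadata_str)
    cfg["max_window_size"] = min(cfg.get("max_window_size", 2**63 - 1),
                                 metadata["max_window_size"])
    for key in ("prefill_chunk_size", "sliding_window"):
        if key in metadata:
            cfg[key] = min(cfg.get(key, -1), metadata[key])  # -1 means "unset"
    return cfg

print(update_config_from_metadata(
    {"sliding_window": 4096, "prefill_chunk_size": 4096},
    '{"max_window_size": -1, "prefill_chunk_size": 4096, "sliding_window": 4096}'))
```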
max_window_size_ = std::min(max_window_size_, metadata["max_window_size"].get()); - return true; + + if (metadata.count("prefill_chunk_size")) { + ICHECK(metadata["prefill_chunk_size"].is()); + prefill_chunk_size_ = + std::min(prefill_chunk_size_, metadata["prefill_chunk_size"].get()); + } + if (metadata.count("sliding_window")) { + ICHECK(metadata["sliding_window"].is()); + sliding_window_ = std::min(sliding_window_, metadata["sliding_window"].get()); + } } /*! @@ -410,21 +418,12 @@ class LLMChat { << "Cannot specify both sliding_window and max_window_size."; this->sliding_window_ = config["sliding_window"].get(); CHECK(this->sliding_window_ > 0) << "Sliding window size needs to be positive"; - CHECK(config.count("sliding_window_chunk_size")) + CHECK(config.count("prefill_chunk_size")) << "Need to specify chunk size if using sliding window attention."; } - if (config.count("sliding_window_chunk_size")) { - CHECK(config["sliding_window_chunk_size"].is()); - this->sliding_window_chunk_size_ = config["sliding_window_chunk_size"].get(); - CHECK(this->sliding_window_chunk_size_ > 0) - << "Sliding window chunk size needs to be positive"; - CHECK(config.count("sliding_window")) << "Need to specify sliding window size."; - } - if (config.count("model_name")) { - CHECK(config["model_name"].is()); - this->model_name_ = config["model_name"].get(); - } else { - CHECK(partial_update) << "Key \"model_name\" not found."; + if (config.count("prefill_chunk_size")) { + CHECK(config["prefill_chunk_size"].is()); + this->prefill_chunk_size_ = config["prefill_chunk_size"].get(); } if (config.count("top_p")) { CHECK(config["top_p"].is()); @@ -513,8 +512,8 @@ class LLMChat { // so there is no explicit abi dependency on these extra // classes other than basic tvm runtime. this->ft_.Init(reload_lib, device_, this->num_shards_); + UpdateConfigFromMetadata(); if (this->sliding_window_ == -1) { - UpdateMaxWindowSizeFromMetadata(); CHECK(max_window_size_ != std::numeric_limits::max()) << "Key \"max_window_size\" not found."; } @@ -807,9 +806,8 @@ class LLMChat { if (ft_.use_disco) { LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model"; } - if (this->sliding_window_ != -1) { - LOG(FATAL) - << "NotImplementedError: Sliding window attention does not support separate embedding"; + if (this->prefill_chunk_size_ != -1) { + LOG(FATAL) << "NotImplementedError: Separate embedding does not support chunking"; } NDArray embedding = Downcast( EmbedStep(inp, append_conversation, place_in_prompt, generation_config_str)); @@ -832,10 +830,10 @@ class LLMChat { int32_t new_seq_len = total_seq_len_; NDArray logits_on_device; - if (this->sliding_window_ != -1) { - // Use chunking if we use sliding window attention (see Mistral paper figure 3). - for (int64_t begin = 0; begin < token_len; begin += this->sliding_window_chunk_size_) { - int64_t end = std::min(token_len, begin + this->sliding_window_chunk_size_); + if (this->prefill_chunk_size_ > 0) { + // Perform chunking. + for (int64_t begin = 0; begin < token_len; begin += this->prefill_chunk_size_) { + int64_t end = std::min(token_len, begin + this->prefill_chunk_size_); std::vector chunk = std::vector(prompt_tokens.begin() + begin, prompt_tokens.begin() + end); new_seq_len += static_cast(chunk.size()); @@ -844,6 +842,7 @@ class LLMChat { ICHECK_EQ(new_seq_len, total_seq_len_ + token_len) << "Expect chunking process all tokens"; } else { // Otherwise, prefill entire prompt at once. 
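The prefill path above now chunks on prefill_chunk_size_ rather than being tied to sliding-window attention. A small Python sketch of the same slicing, just to illustrate the loop bounds:

```python
def prefill_chunks(prompt_tokens, prefill_chunk_size):
    """Mirror of the chunked-prefill loop in LLMChat::PrefillStep: consecutive
    slices of at most prefill_chunk_size tokens; <= 0 means single-shot prefill."""
    if prefill_chunk_size <= 0:
        return [prompt_tokens]
    return [prompt_tokens[begin:begin + prefill_chunk_size]
            for begin in range(0, len(prompt_tokens), prefill_chunk_size)]

assert prefill_chunks(list(range(10)), 4) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
assert sum(len(c) for c in prefill_chunks(list(range(10)), 4)) == 10  # all tokens processed
```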
+ CHECK(sliding_window_ == -1) << "Expect chunking with sliding window attention"; new_seq_len += token_len; logits_on_device = this->ForwardTokens(prompt_tokens, new_seq_len); } @@ -1356,8 +1355,6 @@ class LLMChat { //---------------------------- // Conversation //---------------------------- - // model name - std::string model_name_; // conversation Conversation conversation_; // total sequence len, @@ -1365,7 +1362,7 @@ class LLMChat { // max window size, mean and max generation length, sliding window // If we use sliding window, max window size is its default max() value int64_t max_window_size_{std::numeric_limits::max()}, mean_gen_len_{128}, - max_gen_len_{512}, sliding_window_{-1}, sliding_window_chunk_size_{-1}; + max_gen_len_{512}, sliding_window_{-1}, prefill_chunk_size_{-1}; // size of the vocab table int64_t vocab_size_; // number of shards in distributed inference diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst index 0c2ed8535f..c11132dc23 100644 --- a/docs/deploy/android.rst +++ b/docs/deploy/android.rst @@ -33,7 +33,7 @@ Prerequisite TVM_NDK_CC: $ANDROID_NDK/toolchains/llvm/prebuilt/darwin-x86_64/bin/aarch64-linux-android24-clang # Example on Windows ANDROID_NDK: $HOME/Library/Android/sdk/ndk/25.2.9519653 - TVM_NDK_CC: $ANDROID_NDK/toolchains/llvm/prebuilt/darwin-x86_64/bin/aarch64-linux-android24-clang + TVM_NDK_CC: $ANDROID_NDK/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android24-clang **JDK**, such as OpenJDK >= 17, to compile Java bindings of TVM Unity runtime. It could be installed via Homebrew on macOS, apt on Ubuntu or other package managers. Set up the following environment variable: @@ -164,6 +164,6 @@ Instructions have been provided to build an Android App with MLC LLM in previous .. code-block:: bash adb install android/MLCChat/app/release/app-release.apk - adb push dist/${MODEL_NAME}-${QUANTIZATION}/params /data/local/tmp/${MODEL_NAME}/ + adb push dist/${MODEL_NAME}-${QUANTIZATION}/params /data/local/tmp/${MODEL_NAME}-${QUANTIZATION}/ adb shell "mkdir -p /storage/emulated/0/Android/data/ai.mlc.mlcchat/files/" - adb shell "mv /data/local/tmp/${MODEL_NAME} /storage/emulated/0/Android/data/ai.mlc.mlcchat/files/${MODEL_NAME}" + adb shell "mv /data/local/tmp/${MODEL_NAME}-${QUANTIZATION} /storage/emulated/0/Android/data/ai.mlc.mlcchat/files/" diff --git a/docs/index.rst b/docs/index.rst index 345b5d9603..b255c59fa4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -151,7 +151,7 @@ It is recommended to have at least 6GB free VRAM to run it. - Redmi Note 12 Pro with Snapdragon 685 - Google Pixel phones - **Tutorial and source code**. The source code of the iOS app is fully `open source `__, + **Tutorial and source code**. The source code of the android app is fully `open source `__, and a :doc:`tutorial ` is included in documentation. .. 
figure:: https://blog.mlc.ai/img/android/android-recording.gif diff --git a/mlc_llm/build.py b/mlc_llm/build.py index 5931fb2a2b..afbb8c6e6c 100644 --- a/mlc_llm/build.py +++ b/mlc_llm/build.py @@ -40,17 +40,18 @@ def main(): # Post processing of arguments parsed_args = core._parse_args(parsed_args) # pylint: disable=protected-access - # if num_shard>1 without -convert-weight-only or --build-model-only, we implicitly run it sequentially - if parsed_args.num_shards > 1 and not (parsed_args.build_model_only or parsed_args.convert_weight_only): + # if num_shard>1 without -convert-weight-only or --build-model-only, we implicitly run it sequentially + if parsed_args.num_shards > 1 and not (parsed_args.build_model_only or parsed_args.convert_weights_only): parsed_args.build_model_only = True - parsed_args.convert_weight_only = False # just to be explicit + parsed_args.convert_weights_only = False # just to be explicit core.build_model_from_args(parsed_args) - + parsed_args.build_model_only = False - parsed_args.convert_weight_only = True + parsed_args.convert_weights_only = True core.build_model_from_args(parsed_args) else: core.build_model_from_args(parsed_args) - + + if __name__ == "__main__": main() diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 7d0485611c..a542d971e5 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -8,15 +8,9 @@ from dataclasses import asdict, dataclass, field, fields from typing import Any, Dict, Optional +import mlc_llm import tvm import tvm.relax.backend.contrib.cublas as _ -from tvm import dlight as dl -from tvm import relax -from tvm.contrib.nvcc import parse_compute_version -from tvm.relax.backend import get_patterns_with_prefix -from tvm.relax.backend.contrib.cutlass import annotate_workspace - -import mlc_llm from mlc_llm import utils from mlc_llm.relax_model import ( chatglm, @@ -31,9 +25,20 @@ rwkv, stablelm_3b, ) -from mlc_llm.relax_model.commons import create_shard_info_func, create_shard_transformation_func -from mlc_llm.relax_model.param_manager import transform_params_for_each_rank, chain_parameter_transforms +from mlc_llm.relax_model.commons import ( + create_shard_info_func, + create_shard_transformation_func, +) +from mlc_llm.relax_model.param_manager import ( + chain_parameter_transforms, + transform_params_for_each_rank, +) from mlc_llm.transform import fuse_split_rotary_embedding, rewrite_attention +from tvm import dlight as dl +from tvm import relax +from tvm.contrib.nvcc import parse_compute_version +from tvm.relax.backend import get_patterns_with_prefix +from tvm.relax.backend.contrib.cutlass import annotate_workspace @dataclass @@ -51,75 +56,103 @@ class BuildArgs: The name of the model to build. If it is ``auto``, we will automatically set the model name according to ``--model-path``, ``hf-path``, or the model folders under ``--artifact-path/models``. + hf_path: str Hugging Face path from which to download params, tokenizer, and config. + quantization: str The quantization mode we use to compile. + max_seq_len: int The maximum allowed sequence length for the model. + target: str The target platform to compile the model for. + db_path: str Path to log database for all models. Default: ``./log_db/``. + reuse_lib: str Whether to reuse a previously generated lib. + artifact_path: str Where to store the output. + use_cache: int Whether to use previously pickled IRModule and skip trace. - convert_weight_only: bool + + convert_weights_only: bool Whether to only convert model weights and not build the model. 
If both ``convert_weight_only`` and ``build_model_only`` are set, the behavior is undefined. + build_model_only: bool Whether to only build model and do not convert model weights. + debug_dump: bool Whether to dump debugging files during compilation. + debug_load_script: bool Whether to load the script for debugging. + llvm_mingw: str ``/path/to/llvm-mingw-root``, use llvm-mingw to cross compile to windows. + system_lib: bool A parameter to ``relax.build``. + sep_embed: bool Build with separated embedding layer, only applicable to LlaMa. This feature is in testing stage, and will be formally replaced after massive overhaul of embedding feature for all models and use cases. + sliding_window: int The sliding window size in sliding window attention (SWA). This optional field overrides the `sliding_window` in config.json for those models that use SWA. Currently only useful when compiling Mistral. - sliding_window_chunk_size: int - The chunk size in sliding window attention (SWA) during prefilling. By default, - the chunk size is the same as sliding window. Currently only useful when compiling Mistral. + + prefill_chunk_size: int + The chunk size during prefilling. By default, the chunk size is the same as + max sequence length. Currently only useful when compiling Mistral. + cc_path: str ``/path/to/cross_compiler_path``; currently only used for cross-compile for nvidia/jetson device. + use_safetensors: bool Specifies whether to use ``.safetensors`` instead of the default ``.bin`` when loading in model weights. + enable_batching: bool Build the model for batched inference. This is a temporary flag used to control the model execution flow in single- sequence and batching settings for now. We will eventually merge two flows in the future and remove this flag then. + no_cutlass_attn: bool Disable offloading attention operations to CUTLASS. + no_cutlass_norm: bool Disable offloading layer and RMS norm operations to CUTLASS. + no_cublas: bool Disable the step that offloads matmul to cuBLAS. Without this flag, matmul will be offloaded to cuBLAS if quantization mode is ``q0f16`` or ``q0f32``, target is CUDA and TVM has been built with cuBLAS enabled. + use_cuda_graph: bool Specifies whether to enable CUDA Graph for the decoder. MLP and QKV projection between two attention layers are put into a graph. + num_shards: int Number of shards to split the model into in tensor parallelism multi-gpu inference. Only useful when ``build_model_only`` is set. + use_flash_attn_mqa: bool Offload multi-query attention workload to Flash Attention. + pdb: bool If set, drop into a pdb debugger on error. + use_vllm_attention: bool Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True. 
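The ``convert_weight_only`` flag described above is renamed to ``convert_weights_only`` in this patch, with the old spelling kept as an argparse alias registered further down. A standalone sketch of how such an alias behaves (flag names as in the patch, otherwise illustrative only):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--convert-weights-only", dest="convert_weights_only",
                    action="store_true",
                    help="Whether to only convert model weights and not build the model.")
parser.add_argument("--convert-weight-only", dest="convert_weights_only",
                    action="store_true",
                    help="Equivalent to --convert-weights-only, retained for backwards compatibility.")

assert parser.parse_args(["--convert-weight-only"]).convert_weights_only
assert parser.parse_args(["--convert-weights-only"]).convert_weights_only
assert not parser.parse_args([]).convert_weights_only
```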
""" @@ -148,6 +181,10 @@ class BuildArgs: default=-1, metadata={"help": "The maximum allowed sequence length for the model."}, ) + max_vocab_size: int = field( + default=40000, + metadata={"help": "The maximum allowed vocabulary size for the model."}, + ) target: str = field( default="auto", metadata={"help": "The target platform to compile the model for."}, @@ -161,11 +198,12 @@ class BuildArgs: default=1, metadata={"help": "Whether to use previously pickled IRModule and skip trace."}, ) - convert_weight_only: bool = field( + convert_weights_only: bool = field( default=False, metadata={ - "help": "Whether to only convert model weights and not build the model.", + "dest": "convert_weights_only", "action": "store_true", + "help": "Whether to only convert model weights and not build the model.", }, ) build_model_only: bool = field( @@ -239,6 +277,14 @@ class BuildArgs: "action": "store_true", }, ) + max_batch_size: int = field( + default=80, + metadata={ + "help": ( + "The maximum batch size for build. It has effect only when batching is enabled." + ), + }, + ) no_cutlass_attn: bool = field( default=False, metadata={ @@ -316,12 +362,12 @@ class BuildArgs: ), }, ) - sliding_window_chunk_size: int = field( + prefill_chunk_size: int = field( default=-1, metadata={ "help": ( - "The chunk size in sliding window attention (SWA) during prefilling. " - "By default, the chunk size is the same as sliding window. " + "The chunk size during prefilling. By default, the chunk size is " + "the same as the sliding window size or the max sequence length. " "Currently only useful when compiling Mistral." ), }, @@ -344,6 +390,11 @@ class BuildArgs: }, ) + @property + def convert_weight_only(self): + """A backwards-compatibility helper""" + return self.convert_weights_only + def convert_build_args_to_argparser() -> argparse.ArgumentParser: """Convert from BuildArgs to an equivalent ArgumentParser.""" @@ -358,6 +409,20 @@ def convert_build_args_to_argparser() -> argparse.ArgumentParser: args.add_argument(field_name, default=field.default, **kwargs) else: args.add_argument(field_name, type=field.type, default=field.default, **kwargs) + + # Most models contain more than a single parameter (citation + # needed), so "weights" should be plural. The initial use of + # "--convert-weight-only" caused enough typos that it is worth + # fixing. The old argument spelling is retained for backwards + # compatibility. + args.add_argument( + "--convert-weight-only", + default=False, + dest="convert_weights_only", + action="store_true", + help="Equivalent to --convert-weights-only, retained for backwards compatibility.", + ) + return args @@ -393,7 +458,7 @@ def _parse_args(parsed) -> argparse.Namespace: if parsed.use_presharded_weights: model_name.append(f"presharded-{parsed.num_shards}gpu") - # TODO(@sunggg): currently, we overwrite the artifact_path which forces to rely on name deduction rule. + # TODO(@sunggg): currently, we overwrite the artifact_path which forces to rely on name deduction rule. # Ideally, it is better to separate its root path and name tag. # Until we make the change in upstream, this is a temporary hack. 
artifact_tag = parsed.artifact_tag if parsed.artifact_tag else "-".join(model_name) @@ -526,7 +591,12 @@ def mod_transform_before_build( mod = param_manager.transform_dequantize()(mod) mod = relax.transform.BundleModelParams()(mod) - use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"] + use_ft_quant = args.quantization.name in [ + "q4f16_ft", + "q8f16_ft", + "q4f16_ft_group", + "q8f16_ft_group", + ] mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod) if ( @@ -660,10 +730,10 @@ def dump_mlc_chat_config( config["model_category"] = args.model_category config["model_name"] = args.model config["vocab_size"] = vocab_size + config["prefill_chunk_size"] = args.prefill_chunk_size if args.sliding_window != -1: # Do not add max window size if use sliding window config["sliding_window"] = args.sliding_window - config["sliding_window_chunk_size"] = args.sliding_window_chunk_size else: config["max_window_size"] = max_window_size @@ -702,6 +772,7 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None: mod_deploy ) ) + if not args.enable_batching: mod_deploy = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_deploy) if args.debug_load_script: @@ -731,10 +802,10 @@ def build_model_from_args(args: argparse.Namespace): "and it is highly recommended to use q4f16_1 instead" ) - use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"] + use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft", "q4f16_ft_group", "q8f16_ft_group"] if args.num_shards > 1: - if (not args.build_model_only) and (not args.convert_weight_only): + if (not args.build_model_only) and (not args.convert_weights_only): raise ValueError( "`num_shards` should be used together with " "`--build-model-only` and `--convert-weight-only`" @@ -759,7 +830,7 @@ def build_model_from_args(args: argparse.Namespace): with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: config = json.load(i_f) - if not use_cache or args.convert_weight_only or not os.path.exists(cache_path): + if not use_cache or args.convert_weights_only or not os.path.exists(cache_path): model_generators = { "llama": llama, "mistral": mistral, @@ -838,7 +909,7 @@ def build_model_from_args(args: argparse.Namespace): if args.num_shards > 1 and use_ft_quant: preprocessed = [] weight_preprocess_func = tvm.get_global_func("cutlass.ft_preprocess_weight") - is_int4 = args.quantization.name == "q4f16_ft" + is_int4 = args.quantization.name in ["q4f16_ft", "q4f16_ft_group"] sm = get_cuda_sm_version() for p in params: @@ -858,16 +929,25 @@ def build_model_from_args(args: argparse.Namespace): args, vocab_size=config["vocab_size"], max_window_size=model_config.max_sequence_length, + max_gen_len=model_config.max_sequence_length, top_p=0.6, temperature=1.2, repetition_penalty=0.996, rwkv_world=True, ) + elif args.model_category == "chatglm": + dump_mlc_chat_config( + args, + vocab_size=config["padded_vocab_size"], + max_window_size=model_config.max_sequence_length, + max_gen_len=model_config.max_sequence_length, + ) else: dump_mlc_chat_config( args, vocab_size=config["vocab_size"], max_window_size=model_config.max_sequence_length, + max_gen_len=model_config.max_sequence_length, ) if args.enable_batching: @@ -895,14 +975,14 @@ def build_model_from_args(args: argparse.Namespace): # copy hf config into mlc_model_config mlc_model_config = config.copy() - + with open(mlc_model_config_path, "w", encoding="utf-8") as outfile: json.dump(mlc_model_config, outfile, indent=4) if args.model_category != "minigpt": 
utils.copy_tokenizer(args) - if args.convert_weight_only: + if args.convert_weights_only: exit(0) mod = mod_transform_before_build(mod, param_manager, args, model_config) diff --git a/mlc_llm/quantization/__init__.py b/mlc_llm/quantization/__init__.py index e8ea71cd0f..6284df6fa8 100644 --- a/mlc_llm/quantization/__init__.py +++ b/mlc_llm/quantization/__init__.py @@ -4,7 +4,7 @@ from .quantization import QuantSpecUpdater from .group_quantization import GroupQuantizationSpec from .autogptq_quantization import AutogptqQuantizationSpec -from .ft_rowwise_quantization import FTRowwiseQuantizationSpec, FTQuantizeUpdater +from .ft_quantization import FTQuantizationSpec, FTQuantizeUpdater # The predefined quantization schemes. @@ -114,9 +114,28 @@ ), "q4f16_ft": QuantizationScheme( name="q4f16_ft", - linear_weight=FTRowwiseQuantizationSpec( + linear_weight=FTQuantizationSpec( dtype="float16", nbit=4, + group_size=-1, + ), + embedding_table=GroupQuantizationSpec( + dtype="float16", + mode="int4", + sym=True, + storage_nbit=32, + group_size=32, + transpose=False, + ), + final_fc_weight="same_as_linear_weight", + qspec_updater_class=FTQuantizeUpdater, + ), + "q4f16_ft_group": QuantizationScheme( + name="q4f16_ft_group", + linear_weight=FTQuantizationSpec( + dtype="float16", + nbit=4, + group_size=64, ), embedding_table=GroupQuantizationSpec( dtype="float16", @@ -164,9 +183,27 @@ ), "q8f16_ft": QuantizationScheme( name="q8f16_ft", - linear_weight=FTRowwiseQuantizationSpec( + linear_weight=FTQuantizationSpec( + dtype="float16", + nbit=8, + ), + embedding_table=GroupQuantizationSpec( + dtype="float16", + mode="int8", + sym=True, + storage_nbit=32, + group_size=32, + transpose=False, + ), + final_fc_weight="same_as_linear_weight", + qspec_updater_class=FTQuantizeUpdater, + ), + "q8f16_ft_group": QuantizationScheme( + name="q8f16_ft_group", + linear_weight=FTQuantizationSpec( dtype="float16", nbit=8, + group_size=64, ), embedding_table=GroupQuantizationSpec( dtype="float16", diff --git a/mlc_llm/quantization/ft_rowwise_quantization.py b/mlc_llm/quantization/ft_quantization.py similarity index 80% rename from mlc_llm/quantization/ft_rowwise_quantization.py rename to mlc_llm/quantization/ft_quantization.py index e5ba8a0c5d..3e90b00519 100644 --- a/mlc_llm/quantization/ft_rowwise_quantization.py +++ b/mlc_llm/quantization/ft_quantization.py @@ -14,12 +14,14 @@ @dataclass -class FTRowwiseQuantizationSpec(QuantizationSpec): +class FTQuantizationSpec(QuantizationSpec): """The quantization specification for the FasterTransformer kernel.""" - def __init__(self, dtype, nbit): + def __init__(self, dtype, nbit, group_size=-1): super().__init__(dtype) self.nbit = nbit + assert group_size in [-1, 64, 128], f"Group size {group_size} is not supported." 
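The renamed FTQuantizationSpec adds an optional group_size (64 in the new q4f16_ft_group and q8f16_ft_group schemes, -1 for the old row-wise behaviour). As a reference for the scale layout that encoding_func below produces, here is a small NumPy sketch (illustrative only, not part of the patch):

```python
import numpy as np

def ft_symmetric_scales(weight: np.ndarray, nbit: int, group_size: int = -1) -> np.ndarray:
    """Per-group symmetric scales for an [N, K] weight, following te_encode_sym:
    the scale tensor has shape [ceil(K / group_size), N]; group_size == -1
    degenerates to one group per row (the old row-wise mode)."""
    n, k = weight.shape
    gs = k if group_size == -1 else group_size
    max_int = (1 << (nbit - 1)) - 1
    num_groups = -(-k // gs)  # ceildiv
    padded = np.zeros((n, num_groups * gs), dtype=weight.dtype)
    padded[:, :k] = weight
    max_abs = np.abs(padded).reshape(n, num_groups, gs).max(axis=-1)  # [N, groups]
    scales = np.maximum(max_abs, 1e-4) / (max_int + 1)
    return scales.T                                                   # [groups, N]

w = np.random.randn(8, 128).astype("float16")
print(ft_symmetric_scales(w, nbit=4, group_size=64).shape)  # (2, 8)
```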
+ self.group_size = group_size if tvm.cuda(0).exist: major, minor = parse_compute_version(tvm.cuda(0).compute_version) @@ -40,6 +42,7 @@ def f_quantize(bb: relax.BlockBuilder, inputs: List[relax.Expr]): encoding_func( self.nbit, 8, + group_size=self.group_size, dtype=self.dtype, ), inputs[0], @@ -73,33 +76,47 @@ def get_dequantize_func( decoding_func( self.nbit, storage_nbit=8, + group_size=self.group_size, ), func_name="decode", ) -def encoding_func(nbit: int, storage_nbit: int, dtype: str = "float32"): +def encoding_func(nbit: int, storage_nbit: int, group_size: int, dtype: str = "float32"): def te_encode_sym(weight: te.Tensor): + """Encode the weight tensor of shape [N, K] into a quantized weight tensor of shape + [K, N // float_per_int] and a scale tensor of shape [K // group_size, N] + """ n_float_per_int = storage_nbit // nbit max_int_value = (1 << (nbit - 1)) - 1 - scale_min_shape = (weight.shape[0],) - k = te.reduce_axis((0, weight.shape[1]), name="k") + cur_group_size = weight.shape[1] if group_size == -1 else group_size + scale_min_shape = (tir.ceildiv(weight.shape[1], cur_group_size), weight.shape[0]) + k = te.reduce_axis((0, cur_group_size), name="k") max_abs_value = te.compute( shape=scale_min_shape, - fcompute=lambda i: te.max(te.abs(weight[i, k]), axis=k), + fcompute=lambda group, i: te.max( + te.abs( + tir.if_then_else( + group * cur_group_size + k < weight.shape[1], + weight[i, group * cur_group_size + k], + tir.const(0, dtype=weight.dtype), + ) + ), + axis=k, + ), name="max_abs_value", ) - def f_compute_scale(i): - max_value = tir.max(tir.Cast(dtype, max_abs_value[i]), tir.const(1e-4, dtype)) + def f_compute_scale(*idx): + max_value = tir.max(tir.Cast(dtype, max_abs_value(*idx)), tir.const(1e-4, dtype)) return max_value / tir.const(max_int_value + 1, dtype) scale = te.compute(shape=scale_min_shape, fcompute=f_compute_scale, name="scale") storage_dtype = "int" + str(storage_nbit) def f_scale_weight(i, j): - w_scaled = tir.round(tir.Cast(dtype, weight[i, j]) / scale[i]) + w_scaled = tir.round(tir.Cast(dtype, weight[i, j]) / scale[j // cur_group_size, i]) w_scaled = T.min( T.max(w_scaled, tir.const(-max_int_value - 1, dtype)), tir.const(max_int_value, dtype), @@ -142,9 +159,10 @@ def f_scale_weight(i, j): return te_encode_sym -def decoding_func(nbit: int, storage_nbit: int): +def decoding_func(nbit: int, storage_nbit: int, group_size: int): def te_decode_sym(data, scale): n_float_per_int = storage_nbit // nbit + cur_group_size = data.shape[0] if group_size == -1 else group_size def f_decode_sym(i, j): if n_float_per_int == 1: @@ -155,7 +173,7 @@ def f_decode_sym(i, j): nbit, data[i, j // n_float_per_int], j % n_float_per_int, dtype="float16" ) - scale_float = scale[j] + scale_float = scale[i // cur_group_size, j] return data_float * scale_float shape = (data.shape[0], data.shape[1] * n_float_per_int) diff --git a/mlc_llm/relax_model/chatglm.py b/mlc_llm/relax_model/chatglm.py index e21ed800bb..b12abdbfce 100644 --- a/mlc_llm/relax_model/chatglm.py +++ b/mlc_llm/relax_model/chatglm.py @@ -1,37 +1,31 @@ -import math import argparse +import math from dataclasses import dataclass -from typing import Tuple, List - -from .commons import create_metadata_func +from typing import List, Tuple import tvm from tvm import relax, te, tir -from tvm.relax.testing import nn -from tvm.relax.op.nn import softmax, silu -from tvm.script import relax as R from tvm.relax.op import ( + astype, broadcast_to, - permute_dims, expand_dims, + matmul, maximum, minimum, + permute_dims, + repeat, reshape, 
- squeeze, - astype, - matmul, split, - repeat, + squeeze, ) +from tvm.relax.op.nn import silu, softmax +from tvm.relax.testing import nn +from tvm.script import relax as R from ..quantization import ParamQuantKind, QuantizationScheme +from .commons import create_metadata_func +from .modules import Embedding, Linear, ModuleList, RotaryEmbedding from .param_manager import ParamManager -from .modules import ( - ModuleList, - Embedding, - Linear, - RotaryEmbedding, -) @dataclass @@ -92,11 +86,7 @@ def f_rms_norm(x, weight): is_float32 = x.dtype == "float32" def f_square(x): - return ( - tir.Cast("float32", x) * tir.Cast("float32", x) - if not is_float32 - else x * x - ) + return tir.Cast("float32", x) * tir.Cast("float32", x) if not is_float32 else x * x k = te.reduce_axis((0, x.shape[2]), name="k") square_sum = te.compute( @@ -137,9 +127,7 @@ def __init__(self, config: ChatGLMConfig): # Per attention head and per partition values. self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = ( - projection_size // config.num_attention_heads - ) + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -206,9 +194,7 @@ def __init__( self.projection_size = config.kv_channels * config.num_attention_heads # Per attention head and per partition values. - self.hidden_size_per_attention_head = ( - self.projection_size // config.num_attention_heads - ) + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads # Multi-query attention config @@ -256,8 +242,7 @@ def forward( split( self.query_key_value(hidden_states), indices_or_sections=[ - self.num_attention_heads_per_partition - * self.hidden_size_per_attention_head, + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, ( self.num_attention_heads_per_partition + self.num_multi_query_groups_per_partition @@ -294,9 +279,7 @@ def forward( q, k = self.rotary_pos_emb(q, k, kv_sl - sl) assert k.struct_info.shape[0] == 1 and v.struct_info.shape[0] == 1 - squeezed_k, squeezed_v = nn.emit(squeeze(k, axis=0)), nn.emit( - squeeze(v, axis=0) - ) + squeezed_k, squeezed_v = nn.emit(squeeze(k, axis=0)), nn.emit(squeeze(v, axis=0)) k_cache, v_cache = past_key_value f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") @@ -335,10 +318,7 @@ def forward( ) ) - n_rep = ( - self.num_attention_heads_per_partition - // self.num_multi_query_groups_per_partition - ) + n_rep = self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition kv_attn_shape = R.shape( [ bsz, @@ -440,9 +420,7 @@ class GLMTransformer(nn.Module): def __init__(self, config: ChatGLMConfig, rotary_pos_emb: RotaryEmbedding): self.num_layers = config.num_layers - self.layers = ModuleList( - [GLMBlock(config, rotary_pos_emb) for _ in range(self.num_layers)] - ) + self.layers = ModuleList([GLMBlock(config, rotary_pos_emb) for _ in range(self.num_layers)]) self.final_layernorm = RMSNorm( hidden_size=config.hidden_size, dtype=config.dtype, @@ -617,9 +595,7 @@ def te_slice_last(x: te.Tensor): return lm_logits, key_value_cache -def get_param_quant_kind( - name: str, param_info: relax.TensorStructInfo -) -> ParamQuantKind: +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: if "embedding.weight" in name: return 
ParamQuantKind.embedding_table elif "transformer.output_layer.weight" in name: @@ -643,19 +619,13 @@ def create_encoding_func( all_seq_len = tvm.tir.Var("m", "int64") with bb.function(func_name): model = ChatGLMForCausalLM(config) - param_manager.register_params( - model, func_name, quant_scheme, get_param_quant_kind - ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids = nn.Placeholder((bsz, sl), dtype="int32", name="input_ids") - all_seq_len_shape = relax.Var( - "all_seq_len", relax.ShapeStructInfo((all_seq_len,)) - ) + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) past_key_values = relax.Var( "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.num_layers * 2)] - ), + relax.TupleStructInfo([relax.ObjectStructInfo() for _ in range(config.num_layers * 2)]), ) with bb.dataflow(): @@ -686,23 +656,17 @@ def create_decoding_func( func_name = "decode" bsz = 1 - all_seq_len = tvm.tir.Var("n", "int64") + all_seq_len = tvm.tir.Var("m", "int64") with bb.function(func_name): model = ChatGLMForCausalLM(config) - param_manager.register_params( - model, func_name, quant_scheme, get_param_quant_kind - ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") - all_seq_len_shape = relax.Var( - "all_seq_len", relax.ShapeStructInfo((all_seq_len,)) - ) + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) past_key_values = relax.Var( "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.num_layers * 2)] - ), + relax.TupleStructInfo([relax.ObjectStructInfo() for _ in range(config.num_layers * 2)]), ) with bb.dataflow(): logits, key_value_cache = model( @@ -752,9 +716,7 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: ChatGLMConfig) -> None: def create_softmax_func(bb: relax.BlockBuilder, config: ChatGLMConfig) -> None: with bb.function("softmax_with_temperature"): - logits = nn.Placeholder( - (1, 1, config.padded_vocab_size), dtype="float32", name="logits" - ) + logits = nn.Placeholder((1, 1, config.padded_vocab_size), dtype="float32", name="logits") temperature = nn.Placeholder((), dtype="float32", name="temperature") with bb.dataflow(): div = bb.emit(relax.op.divide(logits, temperature)) @@ -785,19 +747,20 @@ def get_model(args: argparse.Namespace, hf_config): max_window_size=config.max_sequence_length, stop_tokens=[0], add_prefix_space=False, + prefill_chunk_size=args.prefill_chunk_size, ) mod = bb.get() + + tir_bound_map = dict() + tir_bound_map["n"] = ( + args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length + ) + tir_bound_map["m"] = config.max_sequence_length for gv in mod.functions: func = mod[gv] if isinstance(func, relax.Function): - mod[gv] = func.with_attr( - "tir_var_upper_bound", - { - "n": config.max_sequence_length, - "m": config.max_sequence_length, - }, - ) + mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) if args.build_model_only: return mod, param_manager, None, config @@ -805,9 +768,7 @@ def get_model(args: argparse.Namespace, hf_config): def f_convert_pname_fwd(pname: str) -> List[str]: if "transformer.embedding" in pname: return [ - pname.replace( - "transformer.embedding", "transformer.embedding.word_embeddings" - ) + pname.replace("transformer.embedding", "transformer.embedding.word_embeddings") ] else: return [pname] diff --git 
a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py index 0f2f41d93b..3eb67b5b1c 100644 --- a/mlc_llm/relax_model/commons.py +++ b/mlc_llm/relax_model/commons.py @@ -1,10 +1,9 @@ import json -from typing import List, Optional, Dict - -import tvm -from tvm import relax, tir, te, topi +from typing import Dict, List, Optional import mlc_llm +import tvm +from tvm import relax, te, tir, topi def create_metadata_func( @@ -13,6 +12,8 @@ def create_metadata_func( max_window_size: int, stop_tokens: List[int], add_prefix_space: bool, + prefill_chunk_size: int = -1, + sliding_window: int = -1, ): metadata = json.dumps( { @@ -20,6 +21,8 @@ def create_metadata_func( "max_window_size": max_window_size, "stop_tokens": stop_tokens, "add_prefix_space": add_prefix_space, + "prefill_chunk_size": prefill_chunk_size, + "sliding_window": sliding_window, } ) with bb.function("get_metadata", params=[]): @@ -94,8 +97,8 @@ def _get_shard_strategies_ft( q_heads = model_config.num_attention_heads kv_heads = model_config.get_num_key_value_heads() - def shard_qkv_weight(weight: relax.TensorStructInfo): - (red, spatial), dtype = weight.shape, weight.dtype + def shard_qkv_weight_scale(x: relax.TensorStructInfo): + (red, spatial), dtype = x.shape, x.dtype red, spatial = int(red), int(spatial) if param_shape_is_already_sharded: spatial *= num_shards @@ -114,31 +117,6 @@ def shard_qkv_weight(weight: relax.TensorStructInfo): func = te.create_prim_func([a, w]) return func - def shard_qkv_scale(scale: relax.TensorStructInfo): - (spatial,), dtype = scale.shape, scale.dtype - spatial = int(spatial) - if param_shape_is_already_sharded: - spatial *= num_shards - head_dim = spatial // (q_heads + 2 * kv_heads) - a = te.placeholder((spatial,), dtype=dtype) - w = topi.reshape(a, (spatial // head_dim, head_dim)) - q = te.compute((q_heads, head_dim), lambda i, j: w[i, j]) - k = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + i, j]) - v = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + kv_heads + i, j]) - q = topi.reshape(q, (num_shards, q_heads // num_shards, head_dim)) - k = topi.reshape(k, (num_shards, kv_heads // num_shards, head_dim)) - v = topi.reshape(v, (num_shards, kv_heads // num_shards, head_dim)) - w = topi.concatenate((q, k, v), axis=1) - w = topi.reshape(w, (num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim)) - func = te.create_prim_func([a, w]) - return func - - def shard_qkv_weight_scale(x: relax.TensorStructInfo): - if x.ndim == 2: - return shard_qkv_weight(x) - else: - return shard_qkv_scale(x) - def shard_k_weight(weight: relax.TensorStructInfo): (red, spatial), dtype = weight.shape, weight.dtype red, spatial = int(red), int(spatial) @@ -149,8 +127,8 @@ def shard_k_weight(weight: relax.TensorStructInfo): func = te.create_prim_func([a, w]) return func - def shard_gate_up_weight(weight: relax.TensorStructInfo): - (red, spatial), dtype = weight.shape, weight.dtype + def shard_gate_up_weight_scale(x: relax.TensorStructInfo): + (red, spatial), dtype = x.shape, x.dtype red, spatial = int(red), int(spatial) if param_shape_is_already_sharded: spatial *= num_shards @@ -165,27 +143,6 @@ def shard_gate_up_weight(weight: relax.TensorStructInfo): func = te.create_prim_func([a, w]) return func - def shard_gate_up_scale(weight: relax.TensorStructInfo): - (spatial,), dtype = weight.shape, weight.dtype - spatial = int(spatial) - if param_shape_is_already_sharded: - spatial *= num_shards - a = te.placeholder((spatial,), dtype=dtype) - g = te.compute((spatial // 2,), lambda i: a[i]) - u = 
te.compute((spatial // 2,), lambda i: a[spatial // 2 + i]) - g = topi.reshape(g, (num_shards, spatial // 2 // num_shards)) - u = topi.reshape(u, (num_shards, spatial // 2 // num_shards)) - w = topi.concatenate((g, u), axis=1) - w = topi.reshape(w, (num_shards, spatial // num_shards)) - func = te.create_prim_func([a, w]) - return func - - def shard_gate_up_weight_scale(x: relax.TensorStructInfo): - if x.ndim == 2: - return shard_gate_up_weight(x) - else: - return shard_gate_up_scale(x) - return { "shard_qkv": shard_qkv_weight_scale, "shard_mlp_k": shard_k_weight, @@ -246,7 +203,7 @@ def add_to_shard_info(param_name: str, func_name: Optional[str]): def create_shard_transformation_func(param_manager, args, model_config) -> tvm.IRModule: - use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"] + use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft", "q4f16_ft_group", "q8f16_ft_group"] if use_ft_quant: shard_strategy_to_func = _get_shard_strategies_ft( @@ -307,13 +264,14 @@ def create_shard_transformation_func(param_manager, args, model_config) -> tvm.I if param.shard_strategy is None or ( use_ft_quant and param.shard_strategy in ["shard_mlp_k", "shard_o_proj_k"] - and len(qparam_sinfo.shape) == 1 + and qparam_sinfo.shape[0] == 1 ): sharded = arg else: strategy_func = shard_strategy_to_func[param.shard_strategy]( qparam_sinfo ).without_attr("global_symbol") + strategy_gvar = bb.add_func( strategy_func, func_name=f"{arg_name}.sharding_func", diff --git a/mlc_llm/relax_model/gpt_bigcode.py b/mlc_llm/relax_model/gpt_bigcode.py index 35ba683992..466933f8f2 100644 --- a/mlc_llm/relax_model/gpt_bigcode.py +++ b/mlc_llm/relax_model/gpt_bigcode.py @@ -1,29 +1,28 @@ -import math import argparse +import math from dataclasses import dataclass from typing import Optional, Tuple, Union -from .commons import create_metadata_func - import tvm from tvm import relax, te -from tvm.relax.testing import nn -from tvm.relax.op.nn import gelu, softmax, layer_norm -from tvm.script import relax as R from tvm.relax.op import ( + astype, broadcast_to, - permute_dims, expand_dims, + matmul, maximum, minimum, + permute_dims, reshape, squeeze, - astype, - matmul, ) +from tvm.relax.op.nn import gelu, layer_norm, softmax +from tvm.relax.testing import nn +from tvm.script import relax as R from ..quantization import ParamQuantKind, QuantizationScheme -from .modules import ModuleList, Embedding, Linear +from .commons import create_metadata_func +from .modules import Embedding, Linear, ModuleList from .param_manager import ParamManager @@ -167,9 +166,7 @@ def __init__(self, config: GPTBigCodeConfig): self.n_head = config.n_head self.head_dim = config.n_embd // config.n_head - self.c_attn = Linear( - self.n_embd, self.n_embd + 2 * self.head_dim, config.dtype, bias=True - ) + self.c_attn = Linear(self.n_embd, self.n_embd + 2 * self.head_dim, config.dtype, bias=True) self.c_proj = Linear(self.n_embd, self.n_embd, config.dtype, bias=True) self.dtype = config.dtype @@ -198,9 +195,7 @@ def te_slice(x: te.Tensor, start: int, end: int): query_key_value = self.c_attn(hidden_states) # queries: [batch_size, seq_len, n_embd] - q = nn.emit_te( - te_slice, query_key_value, 0, self.n_embd, primfunc_name_hint="slice" - ) + q = nn.emit_te(te_slice, query_key_value, 0, self.n_embd, primfunc_name_hint="slice") # keys: [batch_size, seq_len, head_dim] k = nn.emit_te( te_slice, @@ -473,9 +468,7 @@ def te_slice_last(x: te.Tensor): return logits, key_value_cache -def get_param_quant_kind( - name: str, param_info: 
relax.TensorStructInfo -) -> ParamQuantKind: +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: if "wte.weight" in name: return ParamQuantKind.embedding_table elif "lm_head.weight" in name: @@ -499,21 +492,13 @@ def create_encoding_func( all_seq_len = tvm.tir.Var("m", "int64") with bb.function(func_name): model = GPTBigCodeForCausalLM(config) - param_manager.register_params( - model, func_name, quant_scheme, get_param_quant_kind - ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - input_ids = nn.Placeholder( - (batch_size, seq_len), dtype="int32", name="input_ids" - ) - all_seq_len_shape = relax.Var( - "all_seq_len", relax.ShapeStructInfo((all_seq_len,)) - ) + input_ids = nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) past_key_values = relax.Var( "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.n_layer * 2)] - ), + relax.TupleStructInfo([relax.ObjectStructInfo() for _ in range(config.n_layer * 2)]), ) with bb.dataflow(): @@ -545,23 +530,17 @@ def create_decoding_func( bsz = tvm.tir.IntImm("int64", 1) seq_len = tvm.tir.IntImm("int64", 1) - all_seq_len = tvm.tir.Var("n", "int64") + all_seq_len = tvm.tir.Var("m", "int64") with bb.function(func_name): model = GPTBigCodeForCausalLM(config) - param_manager.register_params( - model, func_name, quant_scheme, get_param_quant_kind - ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") - all_seq_len_shape = relax.Var( - "all_seq_len", relax.ShapeStructInfo((all_seq_len,)) - ) + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) past_key_values = relax.Var( "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.n_layer * 2)] - ), + relax.TupleStructInfo([relax.ObjectStructInfo() for _ in range(config.n_layer * 2)]), ) with bb.dataflow(): logits, key_value_cache = model( @@ -610,9 +589,7 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: GPTBigCodeConfig) -> No def create_softmax_func(bb: relax.BlockBuilder, config: GPTBigCodeConfig) -> None: with bb.function("softmax_with_temperature"): - logits = nn.Placeholder( - (1, 1, config.vocab_size), dtype="float32", name="logits" - ) + logits = nn.Placeholder((1, 1, config.vocab_size), dtype="float32", name="logits") temperature = nn.Placeholder((), dtype="float32", name="temperature") with bb.dataflow(): div = bb.emit(relax.op.divide(logits, temperature)) @@ -652,19 +629,20 @@ def get_model(args: argparse.Namespace, hf_config): max_window_size=config.max_sequence_length, stop_tokens=[0], add_prefix_space=False, + prefill_chunk_size=args.prefill_chunk_size, ) mod = bb.get() + + tir_bound_map = dict() + tir_bound_map["n"] = ( + args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length + ) + tir_bound_map["m"] = config.max_sequence_length for gv in mod.functions: func = mod[gv] if isinstance(func, relax.Function): - mod[gv] = func.with_attr( - "tir_var_upper_bound", - { - "n": config.max_sequence_length, - "m": config.max_sequence_length, - }, - ) + mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) if args.build_model_only: return mod, param_manager, None, config diff --git a/mlc_llm/relax_model/gpt_neox.py b/mlc_llm/relax_model/gpt_neox.py index 
b5b5861262..4864621a37 100644 --- a/mlc_llm/relax_model/gpt_neox.py +++ b/mlc_llm/relax_model/gpt_neox.py @@ -21,13 +21,7 @@ from ..quantization import ParamQuantKind, QuantizationScheme from .commons import create_metadata_func -from .modules import ( - Embedding, - LayerNorm, - Linear, - ModuleList, - RotaryEmbedding, -) +from .modules import Embedding, LayerNorm, Linear, ModuleList, RotaryEmbedding from .param_manager import ParamManager @@ -506,7 +500,7 @@ def create_embed_func( func_name = "embed" bsz = 1 - seq_len = tvm.tir.Var("n", "int64") + seq_len = tvm.tir.Var("m", "int64") with bb.function(func_name): model = GPTNeoXEmbedTokensWrapper(config) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) @@ -584,7 +578,7 @@ def create_decoding_func( batch_size = tvm.tir.IntImm("int64", 1) seq_len = tvm.tir.IntImm("int64", 1) - all_seq_len = tvm.tir.Var("n", "int64") + all_seq_len = tvm.tir.Var("m", "int64") with bb.function(func_name): model = GPTNeoXForCausalLM(config) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) @@ -702,18 +696,19 @@ def get_model( max_window_size=config.max_sequence_length, stop_tokens=stop_tokens, add_prefix_space=False, + prefill_chunk_size=args.prefill_chunk_size, ) mod = bb.get() + + tir_bound_map = dict() + tir_bound_map["n"] = ( + args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length + ) + tir_bound_map["m"] = config.max_sequence_length for gv in mod.functions: func = mod[gv] if isinstance(func, relax.Function): - mod[gv] = func.with_attr( - "tir_var_upper_bound", - { - "n": config.max_sequence_length, - "m": config.max_sequence_length, - }, - ) + mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) if args.build_model_only: return mod, param_manager, None, config diff --git a/mlc_llm/relax_model/gptj.py b/mlc_llm/relax_model/gptj.py index caff411143..73ffeb8122 100644 --- a/mlc_llm/relax_model/gptj.py +++ b/mlc_llm/relax_model/gptj.py @@ -23,13 +23,7 @@ from ..quantization import ParamQuantKind, QuantizationScheme from .commons import create_metadata_func from .gpt_neox import create_kv_cache_func -from .modules import ( - Embedding, - LayerNorm, - Linear, - ModuleList, - RotaryEmbedding, -) +from .modules import Embedding, LayerNorm, Linear, ModuleList, RotaryEmbedding from .param_manager import ParamManager @@ -459,21 +453,15 @@ def _slice(x: te.Tensor): def check_parameters(param_dict, param_list): relax_shape_to_list = lambda _: [s.value for s in _.values] - shape_dict_0 = { - k: relax_shape_to_list(v.struct_info.shape) for k, v in param_dict.items() - } + shape_dict_0 = {k: relax_shape_to_list(v.struct_info.shape) for k, v in param_dict.items()} shape_dict_1 = {k: list(v.shape) for (k, v) in param_list} assert len(shape_dict_0) == len(shape_dict_1) for k, v in shape_dict_0.items(): assert k in shape_dict_1, "{}".format(k) - assert v == shape_dict_1[k], "key={}, shape_0={}, shape_1={}".format( - k, v, shape_dict_1[k] - ) + assert v == shape_dict_1[k], "key={}, shape_0={}, shape_1={}".format(k, v, shape_dict_1[k]) -def get_param_quant_kind( - name: str, param_info: relax.TensorStructInfo -) -> ParamQuantKind: +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: if "wte.weight" in name: return ParamQuantKind.embedding_table elif "lm_head.weight" in name: @@ -493,12 +481,10 @@ def create_embed_func( func_name = "embed" bsz = 1 - seq_len = tvm.tir.Var("n", "int64") + seq_len = tvm.tir.Var("m", "int64") with 
bb.function(func_name): model = GPTJEmbedTokensWrapper(config) - param_manager.register_params( - model, func_name, quant_scheme, get_param_quant_kind - ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") with bb.dataflow(): @@ -527,9 +513,7 @@ def create_encoding_func( hidden_size = config.hidden_size with bb.function(func_name): model = GPTJForCausalLM(config, sep_embed) - param_manager.register_params( - model, func_name, quant_scheme, get_param_quant_kind - ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) inputs = ( nn.Placeholder( @@ -540,9 +524,7 @@ def create_encoding_func( if sep_embed else nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") ) - all_seq_len_shape = relax.Var( - "all_seq_len", relax.ShapeStructInfo((all_seq_len,)) - ) + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) past_key_values = relax.Var( "kv_cache", relax.TupleStructInfo( @@ -577,16 +559,12 @@ def create_decoding_func( batch_size = tvm.tir.IntImm("int64", 1) seq_len = tvm.tir.IntImm("int64", 1) - all_seq_len = tvm.tir.Var("n", "int64") + all_seq_len = tvm.tir.Var("m", "int64") with bb.function(func_name): model = GPTJForCausalLM(config) - param_manager.register_params( - model, func_name, quant_scheme, get_param_quant_kind - ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - input_ids = nn.Placeholder( - (batch_size, seq_len), dtype="int32", name="input_ids" - ) + input_ids = nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") all_seq_len_shape = relax.Var( "all_seq_len", relax.ShapeStructInfo((all_seq_len,)), @@ -617,9 +595,7 @@ def create_decoding_func( def create_softmax_func(bb: relax.BlockBuilder, config: GPTJConfig) -> None: with bb.function("softmax_with_temperature"): - logits = nn.Placeholder( - (1, 1, config.vocab_size), dtype="float32", name="logits" - ) + logits = nn.Placeholder((1, 1, config.vocab_size), dtype="float32", name="logits") temperature = nn.Placeholder((), dtype="float32", name="temperature") with bb.dataflow(): div = bb.emit(relax.op.divide(logits, temperature)) @@ -657,18 +633,19 @@ def get_model(args, hf_config): max_window_size=config.max_sequence_length, stop_tokens=stop_tokens, add_prefix_space=True, + prefill_chunk_size=args.prefill_chunk_size, ) mod = bb.get() + + tir_bound_map = dict() + tir_bound_map["n"] = ( + args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length + ) + tir_bound_map["m"] = config.max_sequence_length for gv in mod.functions: func = mod[gv] if isinstance(func, relax.Function): - mod[gv] = func.with_attr( - "tir_var_upper_bound", - { - "n": config.max_sequence_length, - "m": config.max_sequence_length, - }, - ) + mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) if args.build_model_only: return mod, param_manager, None, config @@ -684,9 +661,7 @@ def f_convert_pname_fwd(pname: str) -> List[str]: hidden_size = config.hidden_size - def f_convert_param_bkwd( - torch_pname: str, torch_param - ) -> Optional[List[Tuple[str, Any]]]: + def f_convert_param_bkwd(torch_pname: str, torch_param) -> Optional[List[Tuple[str, Any]]]: # torch_param: numpy.ndarray if torch_pname.endswith("qkv_proj.weight"): assert torch_param.ndim == 2 diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index 8294313324..88fde9509a 100644 --- 
a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -387,7 +387,8 @@ def __init__(self, config: LlamaConfig): super().__init__(config) ctx_mod = relax.BlockBuilder.current().get() self.kv_cache_transpose_append = ctx_mod.get_global_var("kv_cache_transpose_append") - self.attention_compute = ctx_mod.get_global_var("attention") + self.attention_compute_prefill = ctx_mod.get_global_var("attention_prefill") + self.attention_compute_decode = ctx_mod.get_global_var("attention_decode") def attention_fwd( self, @@ -416,12 +417,13 @@ def attention_fwd( ) f_kv_cache_attention = relax.extern("vm.builtin.paged_attention_kv_cache_attention") + is_decode = query_states.struct_info.shape[1] == 1 attn_output = nn.emit( relax.call_dps_packed( f_kv_cache_attention, [ past_key_values, - self.attention_compute, + self.attention_compute_decode if is_decode else self.attention_compute_prefill, query_states, relax.PrimValue(layer_id), True, @@ -456,14 +458,7 @@ def attention_fwd( attention_mask = kwargs["attention_mask"] kv_seq_len = kwargs["all_seq_len_shape"].struct_info.values[0] - from tvm.relax.op import ( - astype, - matmul, - maximum, - permute_dims, - reshape, - squeeze, - ) + from tvm.relax.op import astype, matmul, maximum, permute_dims, reshape, squeeze from tvm.relax.op.nn import softmax offset = kv_seq_len - q_len @@ -832,6 +827,7 @@ def forward( inputs: relax.Expr, all_seq_len_shape: Optional[relax.Expr], past_key_values: relax.Expr, + logit_positions: Optional[relax.Expr] = None, ): hidden_states, key_value_cache = self.model( inputs=inputs, @@ -847,7 +843,13 @@ def te_slicing(x: te.Tensor): name="slice", ) - logits = self.lm_head(nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice")) + if hidden_states.struct_info.shape[1] != 1: + if logit_positions is None: + hidden_states = nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice") + else: + hidden_states = relax.op.take(hidden_states, logit_positions, axis=1) + logits = self.lm_head(hidden_states) + if logits.struct_info.dtype != "float32": logits = nn.emit(relax.op.astype(logits, "float32")) @@ -873,13 +875,12 @@ def create_embed_func( ) -> None: func_name = "embed" - bsz = tvm.tir.Var("nseq", "int64") - seq_len = tvm.tir.Var("n", "int64") + seq_len = tvm.tir.Var("m", "int64") with bb.function(func_name): model = LlamaEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64")) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + input_ids = nn.Placeholder((1, seq_len), dtype="int32", name="input_ids") with bb.dataflow(): inputs_embeds = model(input_ids) params = [input_ids] + model.parameters() @@ -947,8 +948,8 @@ def create_prefill_func_for_batching( ) -> None: func_name = "prefill_with_embed" - bsz = 1 - seq_len = tvm.tir.Var("n", "int64") + bsz = tir.Var("nseq", "int64") + total_seq_len = tvm.tir.Var("m", "int64") hidden_size = config.hidden_size with bb.function(func_name): model = LlamaForCausalLM( @@ -957,22 +958,24 @@ def create_prefill_func_for_batching( param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) inputs = nn.Placeholder( - (bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds" + (1, total_seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds" ) + logit_pos = nn.Placeholder((bsz,), dtype="int32", name="logit_positions") past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo()) with bb.dataflow(): logits, 
key_value_cache = model( inputs, all_seq_len_shape=None, past_key_values=past_key_values, + logit_positions=logit_pos, ) - params = [inputs, past_key_values] + model.parameters() + params = [inputs, logit_pos, past_key_values] + model.parameters() gv = bb.emit_output((logits, key_value_cache)) bb.emit_func_output(gv, params) mod = bb.get() gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 2)) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) def create_decoding_func_for_single_seq( @@ -984,7 +987,7 @@ def create_decoding_func_for_single_seq( func_name = "decode" bsz = 1 - all_seq_len = tvm.tir.Var("n", "int64") + all_seq_len = tvm.tir.Var("m", "int64") with bb.function(func_name): model = LlamaForCausalLM(config, tvm.tir.Var("vocab_size", "int64")) @@ -1099,6 +1102,7 @@ def create_paged_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> N relax.PrimValue(num_key_value_heads), relax.PrimValue(head_dim), zeros, + relax.PrimValue(0), ], sinfo_args=[relax.ObjectStructInfo()], ) @@ -1136,9 +1140,13 @@ def create_softmax_func_for_batching(bb: relax.BlockBuilder, config: LlamaConfig bb.emit_func_output(gv, [logits, temperature]) -def emit_paged_kv_cache_op(bb: relax.BlockBuilder, dtype: str) -> None: +def emit_paged_kv_cache_op(bb: relax.BlockBuilder, config: LlamaConfig) -> None: from tvm.script import tir as T + num_layers = config.num_hidden_layers + num_heads = config.num_key_value_heads + head_dim = config.hidden_size // config.num_attention_heads + # fmt: off @T.prim_func def kv_cache_transpose_append( @@ -1150,31 +1158,28 @@ def kv_cache_transpose_append( var_last_page_offset: T.handle, var_append_length_indptr: T.handle, var_pos2seqidx: T.handle, - layer_id: T.int32, + layer_id: T.int64, ): - nseq = T.int32() - ntoken = T.int32() - nhead = T.int32() - nfeat = T.int32() - nlayer = T.int32() - npage = T.int32() - page_size = T.int32() - num_pages = T.int32() - - pages = T.match_buffer(var_pages, (num_pages, nlayer, 2, nhead, page_size, nfeat), dtype) - k_data = T.match_buffer(var_k_data, (ntoken, nhead, nfeat), dtype) - v_data = T.match_buffer(var_v_data, (ntoken, nhead, nfeat), dtype) + nseq = T.int64() + ntoken = T.SizeVar("ntoken", "int64") + npage = T.int64() + page_size = T.SizeVar("page_size", "int64") + num_pages = T.int64() + + pages = T.match_buffer(var_pages, (num_pages, num_layers, 2, num_heads, page_size, head_dim), config.dtype) + k_data = T.match_buffer(var_k_data, (ntoken, num_heads, head_dim), config.dtype) + v_data = T.match_buffer(var_v_data, (ntoken, num_heads, head_dim), config.dtype) last_page_offset = T.match_buffer(var_last_page_offset, (nseq,), "int32") page_table_indptr = T.match_buffer(var_page_table_indptr, (nseq + 1,), "int32") page_table_values = T.match_buffer(var_page_table_values, (npage,), "int32") append_length_indptr = T.match_buffer(var_append_length_indptr, (nseq + 1,), "int32") pos2seqidx = T.match_buffer(var_pos2seqidx, (ntoken,), "int32") - for global_pos, h, f in T.grid(ntoken, nhead, nfeat): + for global_pos, h, f in T.grid(ntoken, num_heads, head_dim): with T.block("k_transpose_append"): vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - seq_idx = pos2seqidx[vgpos] - seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx] + seq_idx: T.int64 = T.Cast("int64", pos2seqidx[vgpos]) + seqlen: T.int64 = T.Cast("int64", (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx]) 
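+ # `seqlen` above is the sequence's length after the appended tokens are accounted for: all full pages plus the offset into its last page. + # `seqlen - (append_length_indptr[seq_idx + 1] - vgpos)` gives the absolute position of token `vgpos` within its sequence; + # `T.floordiv(..., page_size)` then selects the logical page, and `page_table_values` maps it to a physical page in `pages`.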
pages[ page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)], layer_id, @@ -1185,8 +1190,8 @@ def kv_cache_transpose_append( ] = k_data[vgpos, vh, vf] with T.block("v_transpose_append"): vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - seq_idx = pos2seqidx[vgpos] - seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx] + seq_idx: T.int64 = T.Cast("int64", pos2seqidx[vgpos]) + seqlen: T.int64 = T.Cast("int64", (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx]) pages[ page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)], layer_id, @@ -1198,8 +1203,8 @@ def kv_cache_transpose_append( # fmt: on bb.add_func(kv_cache_transpose_append, "kv_cache_transpose_append") - # Todo: integrating attention TIR func/kernel. - bb.add_func(relax.extern("attention_func"), "attention") + bb.add_func(relax.extern("paged_kv_cache.attention_kernel_prefill"), "attention_prefill") + bb.add_func(relax.extern("paged_kv_cache.attention_kernel_decode"), "attention_decode") def setup_params(mod, param_manager, dtype, config, args): @@ -1325,7 +1330,9 @@ def get_model(args, hf_config): build_model_only=args.build_model_only, ) else: - raise Exception("The model config should contain information about maximum sequence length.") + raise Exception( + "The model config should contain information about maximum sequence length." + ) # If there is a user-provided maximum sequence length, override hf config. if args.max_seq_len != -1: @@ -1338,7 +1345,7 @@ def get_model(args, hf_config): create_embed_func(bb, param_manager, config, args.quantization) if enable_batching: - emit_paged_kv_cache_op(bb, dtype) + emit_paged_kv_cache_op(bb, config) create_prefill_func_for_batching(bb, param_manager, config, args.quantization) create_decoding_func_for_batching(bb, param_manager, config, args.quantization) create_paged_kv_cache_func(bb, config) @@ -1355,19 +1362,23 @@ def get_model(args, hf_config): max_window_size=config.max_sequence_length, stop_tokens=[2], add_prefix_space=False, + prefill_chunk_size=args.prefill_chunk_size, ) mod = bb.get() + + tir_bound_map = dict() + tir_bound_map["n"] = ( + args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length + ) + tir_bound_map["m"] = config.max_sequence_length + tir_bound_map["vocab_size"] = args.max_vocab_size + if enable_batching: + tir_bound_map["nseq"] = args.max_batch_size for gv in mod.functions: func = mod[gv] if isinstance(func, relax.Function): - mod[gv] = func.with_attr( - "tir_var_upper_bound", - { - "n": config.max_sequence_length, - "m": config.max_sequence_length, - }, - ) + mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) if args.build_model_only: return mod, param_manager, None, config diff --git a/mlc_llm/relax_model/mistral.py b/mlc_llm/relax_model/mistral.py index 31ed39fdb5..bd8094a83b 100644 --- a/mlc_llm/relax_model/mistral.py +++ b/mlc_llm/relax_model/mistral.py @@ -41,8 +41,7 @@ def __init__( tie_word_embeddings=False, vocab_size=32000, dtype="float32", - sliding_window_chunk_size=-1, - max_sequence_length=-1, # Does not play a role, kept for compatibility. 
+ max_sequence_length=16384, combine_matmul=True, build_model_only=False, num_shards=1, @@ -65,12 +64,7 @@ def __init__( self.tie_word_embeddings = tie_word_embeddings self.vocab_size = vocab_size self.dtype = dtype - if sliding_window_chunk_size == -1: - # chunk size same as sliding window by default - self.sliding_window_chunk_size = self.sliding_window - else: - self.sliding_window_chunk_size = sliding_window_chunk_size - self.max_sequence_length = max_sequence_length + self.max_sequence_length = sliding_window * 4 self.combine_matmul = combine_matmul if build_model_only and num_shards > 1: self.num_shards = num_shards @@ -292,7 +286,7 @@ def interleave_kv( key_cur: relax.Expr, value_cur: relax.Expr, kv_seq_len: int, - cache_len: int, + rolling_cache_len: int, cache_offset: int, past_key_value: Tuple[relax.Expr], ): @@ -303,9 +297,9 @@ def interleave_kv( kv_cur_dtype = key_cur.struct_info.dtype assert kv_cur_shape[0] == 1 # bsz kv_batched_cache_shape = R.shape( - [kv_cur_shape[0], cache_len, kv_cur_shape[2], kv_cur_shape[3]] + [kv_cur_shape[0], rolling_cache_len, kv_cur_shape[2], kv_cur_shape[3]] ) - kv_cache_shape = R.shape([cache_len, kv_cur_shape[2], kv_cur_shape[3]]) + kv_cache_shape = R.shape([rolling_cache_len, kv_cur_shape[2], kv_cur_shape[3]]) # fecth past keys and values from cache k_cache, v_cache = past_key_value @@ -328,16 +322,16 @@ def interleave_kv( key_cached = nn.emit(reshape(key_cached, kv_batched_cache_shape)) value_cached = nn.emit(reshape(value_cached, kv_batched_cache_shape)) - def te_unrotate_concat(x, x_cached, cache_offset, cache_len): + def te_unrotate_concat(x, x_cached, cache_offset, rolling_cache_len): return te.compute( (kv_cur_shape[0], kv_seq_len, kv_cur_shape[2], kv_cur_shape[3]), lambda b, s, h, d: te.if_then_else( - s < cache_len - cache_offset, + s < rolling_cache_len - cache_offset, x_cached[b, cache_offset + s, h, d], te.if_then_else( - s < cache_len, - x_cached[b, s + cache_offset - cache_len, h, d], - x[b, s - cache_len, h, d], + s < rolling_cache_len, + x_cached[b, s + cache_offset - rolling_cache_len, h, d], + x[b, s - rolling_cache_len, h, d], ), ), name="unrotate_concat_te", @@ -348,7 +342,7 @@ def te_unrotate_concat(x, x_cached, cache_offset, cache_len): key_cur, key_cached, cache_offset, - cache_len, + rolling_cache_len, primfunc_name_hint="te_unrotate_concat_key", ) value = nn.emit_te( @@ -356,7 +350,7 @@ def te_unrotate_concat(x, x_cached, cache_offset, cache_len): value_cur, value_cached, cache_offset, - cache_len, + rolling_cache_len, primfunc_name_hint="te_unrotate_concat_value", ) @@ -376,17 +370,17 @@ def te_squeeze(x): squeezed_key = nn.emit_te(te_squeeze, key_cur) squeezed_value = nn.emit_te(te_squeeze, value_cur) - f_kv_cache_overwrite = relax.extern("vm.builtin.attention_kv_cache_window_override") + f_kv_cache_override = relax.extern("vm.builtin.attention_kv_cache_window_override") k_cache = nn.emit( relax.Call( - f_kv_cache_overwrite, + f_kv_cache_override, args=[k_cache, squeezed_key, relax.PrimValue(self.sliding_window)], sinfo_args=[relax.ObjectStructInfo()], ) ) v_cache = nn.emit( relax.Call( - f_kv_cache_overwrite, + f_kv_cache_override, args=[v_cache, squeezed_value, relax.PrimValue(self.sliding_window)], sinfo_args=[relax.ObjectStructInfo()], ) @@ -458,11 +452,11 @@ def forward( ) # concat current kv with cached kv (unrotating the cache) - cache_len = cache_len_shape.struct_info.values[0] + rolling_cache_len = cache_len_shape.struct_info.values[0] kv_seq_len = kv_seq_len_shape.struct_info.values[0] cache_offset = 
(all_seq_len - q_len) % self.sliding_window key, value, updated_key_value = self.interleave_kv( - key_cur, value_cur, kv_seq_len, cache_len, cache_offset, past_key_value + key_cur, value_cur, kv_seq_len, rolling_cache_len, cache_offset, past_key_value ) if self.num_key_value_heads != self.num_query_heads: @@ -791,7 +785,7 @@ def create_encoding_func( bsz = 1 seq_len = tvm.tir.Var("n", "int64") # number of tokens for the input all_seq_len = tvm.tir.Var("m", "int64") # total_seq_len in `llm_chat.cc` (including seq_len) - cache_len = tvm.tir.Var("c", "int64") # cache_len captures number of elements in the cache + rolling_cache_len = tvm.tir.Var("c", "int64") # rolling_cache_len captures number of elements in the cache kv_seq_len = tvm.tir.Var( "k", "int64" ) # kv_seq_len captures number of elements in cache + seq_len @@ -807,7 +801,7 @@ def create_encoding_func( else nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") ) all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) - cache_len_shape = relax.Var("cache_len", relax.ShapeStructInfo((cache_len,))) + cache_len_shape = relax.Var("rolling_cache_len", relax.ShapeStructInfo((rolling_cache_len,))) kv_seq_len_shape = relax.Var("kv_seq_len", relax.ShapeStructInfo((kv_seq_len,))) past_key_values = relax.Var( "kv_cache", @@ -848,7 +842,7 @@ def create_decoding_func( bsz = 1 all_seq_len = tvm.tir.Var("m", "int64") - cache_len = tvm.tir.Var("c", "int64") # cache_len captures number of elements in the cache + rolling_cache_len = tvm.tir.Var("c", "int64") # rolling_cache_len captures number of elements in the cache kv_seq_len = tvm.tir.Var( "k", "int64" ) # kv_seq_len captures number of elements in cache + seq_len @@ -859,7 +853,7 @@ def create_decoding_func( input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) - cache_len_shape = relax.Var("cache_len", relax.ShapeStructInfo((cache_len,))) + cache_len_shape = relax.Var("rolling_cache_len", relax.ShapeStructInfo((rolling_cache_len,))) kv_seq_len_shape = relax.Var("kv_seq_len", relax.ShapeStructInfo((kv_seq_len,))) past_key_values = relax.Var( "kv_cache", @@ -939,6 +933,8 @@ def get_model(args, hf_config): if args.sliding_window != -1: hf_config["sliding_window"] = args.sliding_window + if args.max_seq_len != -1: + hf_config["max_sequence_length"] = args.max_seq_len config = MistralConfig( **hf_config, @@ -946,11 +942,13 @@ def get_model(args, hf_config): combine_matmul=True, num_shards=args.num_shards, build_model_only=args.build_model_only, - sliding_window_chunk_size=args.sliding_window_chunk_size, ) assert config.sliding_window != -1 - assert config.sliding_window_chunk_size != -1 + + # prefill chunk size same as sliding window by default + if args.prefill_chunk_size < 1: + args.prefill_chunk_size = config.sliding_window param_manager = ParamManager() bb = relax.BlockBuilder() @@ -966,7 +964,7 @@ def get_model(args, hf_config): stop_tokens=[2], add_prefix_space=False, sliding_window=config.sliding_window, - sliding_window_chunk_size=config.sliding_window_chunk_size, + prefill_chunk_size=args.prefill_chunk_size, ) mod = bb.get() @@ -976,9 +974,9 @@ def get_model(args, hf_config): mod[gv] = func.with_attr( "tir_var_upper_bound", { - "n": config.sliding_window_chunk_size, + "n": args.prefill_chunk_size, "c": config.sliding_window, - "k": config.sliding_window + config.sliding_window_chunk_size, + "k": config.sliding_window + args.prefill_chunk_size, }, ) diff 
--git a/mlc_llm/relax_model/stablelm_3b.py b/mlc_llm/relax_model/stablelm_3b.py index 89c15a7955..d641fcc54a 100644 --- a/mlc_llm/relax_model/stablelm_3b.py +++ b/mlc_llm/relax_model/stablelm_3b.py @@ -12,9 +12,9 @@ from ..quantization import ParamQuantKind, QuantizationScheme from .commons import create_metadata_func +from .llama import Embedding, Linear from .modules import ModuleList, RotaryEmbedding from .param_manager import ParamManager -from .llama import Embedding, Linear @dataclass @@ -40,7 +40,7 @@ def __init__( combine_matmul=True, num_shards=1, build_model_only=False, - convert_weight_only=False, + convert_weights_only=False, **kwargs, ): self.dtype = dtype @@ -376,17 +376,21 @@ def forward( all_seq_len_shape=all_seq_len_shape, ) if self.self_attn.num_shards > 1: - residual = nn.emit(residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype)) + residual = nn.emit( + residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype) + ) hidden_states = nn.emit(residual + hidden_states) if self.self_attn.num_shards > 1: hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) - + # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) if self.mlp.num_shards > 1: - residual = nn.emit(residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype)) + residual = nn.emit( + residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) + ) hidden_states = nn.emit(residual + hidden_states) if self.mlp.num_shards > 1: hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) @@ -444,7 +448,9 @@ def forward(self, input_ids: relax.Expr): class StableLM3bModell(nn.Module): - def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): + def __init__( + self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False + ): rotary_embedding = RotaryEmbedding( hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, @@ -461,7 +467,10 @@ def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_em self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) self.layers = ModuleList( - [StableLM3bDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)] + [ + StableLM3bDecoderLayer(config, rotary_embedding) + for _ in range(config.num_hidden_layers) + ] ) self.norm = LayerNorm(config.hidden_size, dtype=config.dtype, eps=config.norm_eps) @@ -530,7 +539,9 @@ def forward( class StableLM3bForCausalLM(nn.Module): - def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): + def __init__( + self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False + ): self.model = StableLM3bModell(config, vocab_size_var, sep_embed) self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) @@ -582,7 +593,7 @@ def create_embed_func( func_name = "embed" bsz = 1 - seq_len = tvm.tir.Var("n", "int64") + seq_len = tvm.tir.Var("m", "int64") with bb.function(func_name): model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64")) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) @@ -654,7 +665,7 @@ def create_decoding_func( func_name = "decode" bsz = 1 - all_seq_len = tvm.tir.Var("n", "int64") + all_seq_len = tvm.tir.Var("m", "int64") with bb.function(func_name): model = 
StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64")) @@ -779,7 +790,7 @@ def get_model(args, hf_config): combine_matmul=True, num_shards=args.num_shards, build_model_only=args.build_model_only, - convert_weight_only=args.convert_weight_only, + convert_weights_only=args.convert_weights_only, ) if max_seq_len != -1: config.max_sequence_length = max_seq_len @@ -800,19 +811,20 @@ def get_model(args, hf_config): max_window_size=config.max_sequence_length, stop_tokens=[2], add_prefix_space=False, + prefill_chunk_size=args.prefill_chunk_size, ) mod = bb.get() + + tir_bound_map = dict() + tir_bound_map["n"] = ( + args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length + ) + tir_bound_map["m"] = config.max_sequence_length for gv in mod.functions: func = mod[gv] if isinstance(func, relax.Function): - mod[gv] = func.with_attr( - "tir_var_upper_bound", - { - "n": config.max_sequence_length, - "m": config.max_sequence_length, - }, - ) + mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) if args.build_model_only: return mod, param_manager, None, config diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 7c13b0ba30..767138f673 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -120,7 +120,7 @@ def argparse_postproc_common(args: argparse.Namespace) -> None: if args.quantization not in quantization_schemes: raise ValueError(f'Quantization "{args.quantization}" is not supported.') - use_ft_quant = args.quantization in ["q4f16_ft", "q8f16_ft"] + use_ft_quant = args.quantization in ["q4f16_ft", "q8f16_ft", "q4f16_ft_group", "q8f16_ft_group"] args.quantization = quantization_schemes[args.quantization] if use_ft_quant and args.num_shards > 1: @@ -290,7 +290,6 @@ def save_params(params: List[tvm.nd.NDArray], artifact_path: str, num_presharded meta_data = {} param_dict = {} meta_data["ParamSize"] = len(params) - for i, nd in enumerate(params): if num_presharded == 1: param_name = f"param_{i}" diff --git a/pyproject.toml b/pyproject.toml index b1f082240c..1ffd135abf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,6 @@ show_error_context = true follow_imports = "skip" ignore_errors = false strict_optional = false -install_types = true [tool.pylint.messages_control] max-line-length = 100 diff --git a/python/mlc_chat/__main__.py b/python/mlc_chat/__main__.py new file mode 100644 index 0000000000..4e51dcb15d --- /dev/null +++ b/python/mlc_chat/__main__.py @@ -0,0 +1,44 @@ +"""Entrypoint of all CLI commands from MLC LLM""" +import logging +import sys + +from mlc_chat.support.argparse import ArgumentParser + +logging.basicConfig( + level=logging.INFO, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", +) + + +def main(): + """Entrypoint of all CLI commands from MLC LLM""" + parser = ArgumentParser("MLC LLM Command Line Interface.") + parser.add_argument( + "subcommand", + type=str, + choices=["compile", "convert_weight", "gen_mlc_chat_config"], + help="Subcommand to to run. 
(choices: %(choices)s)", + ) + parsed = parser.parse_args(sys.argv[1:2]) + # pylint: disable=import-outside-toplevel + if parsed.subcommand == "compile": + from mlc_chat.cli import compile as cli + + cli.main(sys.argv[2:]) + elif parsed.subcommand == "convert_weight": + from mlc_chat.cli import convert_weight as cli + + cli.main(sys.argv[2:]) + elif parsed.subcommand == "gen_mlc_chat_config": + from mlc_chat.cli import gen_mlc_chat_config as cli + + cli.main(sys.argv[2:]) + else: + raise ValueError(f"Unknown subcommand {parsed.subcommand}") + # pylint: enable=import-outside-toplevel + + +if __name__ == "__main__": + main() diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index bcadaa84ba..215a5a46d1 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -13,6 +13,8 @@ import tvm from tvm.runtime import disco # pylint: disable=unused-import +from mlc_chat.support.auto_device import detect_device + from . import base # pylint: disable=unused-import if TYPE_CHECKING: @@ -62,6 +64,8 @@ class ConvConfig: # pylint: disable=too-many-instance-attributes When the ``stop_str`` is encountered, the model will stop generating output. stop_tokens : Optional[List[int]] A list of token IDs that act as stop tokens. + prefix_tokens : Optional[List[int]] + Token list prefixing the conversation. add_bos : Optional[bool] Determines whether a beginning-of-string (bos) token should be added before the input tokens. @@ -78,6 +82,7 @@ class ConvConfig: # pylint: disable=too-many-instance-attributes role_empty_sep: Optional[str] = None stop_str: Optional[str] = None stop_tokens: Optional[List[int]] = None + prefix_tokens: Optional[List[int]] = None add_bos: Optional[bool] = None def __post_init__(self): @@ -588,89 +593,6 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio return json.dumps(asdict(generation_config)) -def _parse_device_str(device: str) -> Tuple[tvm.runtime.Device, str]: - """Parse the input device identifier into device name and id. - - Parameters - ---------- - device : str - The device identifier to parse. - It can be "device_name" (e.g., "cuda") or - "device_name:device_id" (e.g., "cuda:1"). - - Returns - ------- - dev : tvm.runtime.Device - The device. - - device_name : str - The name of the device. - """ - device_err_msg = ( - f"Invalid device name: {device}. Please enter the device in the form " - "'device_name:device_id' or 'device_name', where 'device_name' needs to be " - "one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'." - ) - device_args = device.split(":") - if len(device_args) == 1: - device_name, device_id = device_args[0], 0 - elif len(device_args) == 2: - device_name, device_id = device_args[0], int(device_args[1]) - elif len(device_args) > 2: - raise ValueError(device_err_msg) - - if device_name == "cuda": - device = tvm.cuda(device_id) - elif device_name == "metal": - device = tvm.metal(device_id) - elif device_name == "vulkan": - device = tvm.vulkan(device_id) - elif device_name == "rocm": - device = tvm.rocm(device_id) - elif device_name == "opencl": - device = tvm.opencl(device_id) - elif device_name == "auto": - device, device_name = _detect_local_device(device_id) - logging.info("System automatically detected device: %s", device_name) - else: - raise ValueError(device_err_msg) - - return device, device_name - - -def _detect_local_device(device_id: int = 0) -> Tuple[tvm.runtime.Device, str]: - """Automatically detect the local device if user does not specify. 
- - Parameters - ---------- - device_id : int - The local device id. - - Returns - ------ - dev : tvm.runtime.Device - The local device. - - device_name : str - The name of the device. - """ - if tvm.metal().exist: - return tvm.metal(device_id), "metal" - if tvm.rocm().exist: - return tvm.rocm(device_id), "rocm" - if tvm.cuda().exist: - return tvm.cuda(device_id), "cuda" - if tvm.vulkan().exist: - return tvm.vulkan(device_id), "vulkan" - if tvm.opencl().exist: - return tvm.opencl(device_id), "opencl" - logging.info( - "None of the following device is detected: metal, rocm, cuda, vulkan, opencl. " - "Switch to llvm instead." - ) - return tvm.cpu(device_id), "llvm" - - class ChatModule: # pylint: disable=too-many-instance-attributes r"""The ChatModule for MLC LLM. @@ -735,7 +657,7 @@ def __init__( ): # 0. Get device: # Retrieve device_name and device_id (if any, default 0) from device arg - self.device, device_name = _parse_device_str(device) + self.device = detect_device(device) device_type = self.device.device_type device_id = self.device.device_id @@ -777,7 +699,7 @@ def __init__( self.model_path, self.chat_config, model_lib_path, - device_name, + self.device.MASK2STR[self.device.device_type], self.config_file_path, ) diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py index c340119c98..5df1164c19 100644 --- a/python/mlc_chat/cli/compile.py +++ b/python/mlc_chat/cli/compile.py @@ -1,37 +1,25 @@ """Command line entrypoint of compilation.""" import argparse -import logging import re from pathlib import Path from typing import Union from mlc_chat.compiler import ( # pylint: disable=redefined-builtin + HELP, MODELS, QUANTIZATION, OptimizationFlags, compile, ) +from ..support.argparse import ArgumentParser from ..support.auto_config import detect_config, detect_model_type from ..support.auto_target import detect_target_and_host -logging.basicConfig( - level=logging.INFO, - style="{", - datefmt="%Y-%m-%d %H:%M:%S", - format="[{asctime}] {levelname} {filename}:{lineno}: {message}", -) - -def main(): +def main(argv): """Parse command line argumennts and call `mlc_llm.compiler.compile`.""" - def _parse_config(path: Union[str, Path]) -> Path: - try: - return detect_config(path) - except ValueError as err: - raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}") - def _parse_output(path: Union[str, Path]) -> Path: path = Path(path) parent = path.parent @@ -48,98 +36,78 @@ def _check_prefix_symbols(prefix: str) -> str: "numbers (0-9), alphabets (A-Z, a-z) and underscore (_)." ) - parser = argparse.ArgumentParser("MLC LLM Compiler") + parser = ArgumentParser("MLC LLM Compiler") parser.add_argument( - "--config", - type=_parse_config, + "--model", + type=detect_config, required=True, - help="Path to config.json file or to the directory that contains config.json, which is " - "a HuggingFace standard that defines model architecture, for example, " - "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json", + dest="config", + help=HELP["model"] + " (required)", ) parser.add_argument( "--quantization", type=str, required=True, choices=list(QUANTIZATION.keys()), - help="Quantization format.", + help=HELP["quantization"] + " (required, choices: %(choices)s)", ) parser.add_argument( "--model-type", type=str, default="auto", choices=["auto"] + list(MODELS.keys()), - help="Model architecture, for example, llama. If not set, it is inferred " - "from the config.json file. 
" - "(default: %(default)s)", + help=HELP["model_type"] + ' (default: "%(default)s")', ) parser.add_argument( "--device", type=str, default="auto", - help="The GPU device to compile the model to. If not set, it is inferred from locally " - "available GPUs. " - "(default: %(default)s)", + help=HELP["device_compile"] + ' (default: "%(default)s")', ) parser.add_argument( "--host", type=str, default="auto", - choices=[ - "auto", - "arm", - "arm64", - "aarch64", - "x86-64", - ], - help="The host CPU ISA to compile the model to. If not set, it is inferred from the " - "local CPU. " - "(default: %(default)s)", + help=HELP["host"] + ' (default: "%(default)s")', ) parser.add_argument( "--opt", type=OptimizationFlags.from_str, default="O2", - help="Optimization flags. MLC LLM maintains a predefined set of optimization flags, " - "denoted as O0, O1, O2, O3, where O0 means no optimization, O2 means majority of them, " - "and O3 represents extreme optimization that could potentially break the system. " - "Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. " - '--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0. ' - "(default: %(default)s)", + help=HELP["opt"] + ' (default: "%(default)s")', ) parser.add_argument( "--prefix-symbols", type=str, default="", - help='Adding a prefix to all symbols exported. Similar to "objcopy --prefix-symbols". ' - "This is useful when compiling multiple models into a single library to avoid symbol " - "conflicts. Differet from objcopy, this takes no effect for shared library. " - '(default: "")', + help=HELP["prefix_symbols"] + ' (default: "%(default)s")', ) parser.add_argument( - "--max-sequence-length", + "--context-window-size", type=int, default=None, - help="Option to override the maximum sequence length supported by the model. " - "An LLM is usually trained with a fixed maximum sequence length, which is usually " - "explicitly specified in model spec. By default, if this option is not set explicitly, " - "the maximum sequence length is determined by `max_sequence_length` or " - "`max_position_embeddings` in config.json, which can be inaccuate for some models.", + help=HELP["context_window_size"] + ' (default: "%(default)s")', ) parser.add_argument( "--output", "-o", type=_parse_output, required=True, - help="The name of the output file. The suffix determines if the output file is a " - "shared library or objects. 
Available suffixes: " - "1) Linux: .so (shared), .tar (objects); " - "2) macOS: .dylib (shared), .tar (objects); " - "3) Windows: .dll (shared), .tar (objects); " - "4) Android, iOS: .tar (objects); " - "5) Web: .wasm (web assembly)", + help=HELP["output_compile"] + " (required)", + ) + parser.add_argument( + "--sliding-window", + type=int, + default=None, + help=HELP["sliding_window"] + ' (default: "%(default)s")', ) - parsed = parser.parse_args() + parser.add_argument( + "--prefill-chunk-size", + type=int, + default=None, + help=HELP["prefill_chunk_size"] + ' (default: "%(default)s")', + ) + parsed = parser.parse_args(argv) target, build_func = detect_target_and_host(parsed.device, parsed.host) parsed.model_type = detect_model_type(parsed.model_type, parsed.config) compile( @@ -151,9 +119,7 @@ def _check_prefix_symbols(prefix: str) -> str: build_func=build_func, prefix_symbols=parsed.prefix_symbols, output=parsed.output, - max_sequence_length=parsed.max_sequence_length, + context_window_size=parsed.context_window_size, + sliding_window=parsed.sliding_window, + prefill_chunk_size=parsed.prefill_chunk_size, ) - - -if __name__ == "__main__": - main() diff --git a/python/mlc_chat/cli/convert_weight.py b/python/mlc_chat/cli/convert_weight.py index cf4c205009..97d437aa49 100644 --- a/python/mlc_chat/cli/convert_weight.py +++ b/python/mlc_chat/cli/convert_weight.py @@ -1,38 +1,25 @@ """Command line entrypoint of weight conversion.""" import argparse -import logging from pathlib import Path from typing import Union -from mlc_chat.compiler import MODELS, QUANTIZATION, convert_weight +from mlc_chat.compiler import HELP, MODELS, QUANTIZATION, convert_weight +from ..support.argparse import ArgumentParser from ..support.auto_config import detect_config, detect_model_type -from ..support.auto_target import detect_device +from ..support.auto_device import detect_device from ..support.auto_weight import detect_weight -logging.basicConfig( - level=logging.INFO, - style="{", - datefmt="%Y-%m-%d %H:%M:%S", - format="[{asctime}] {levelname} {filename}:{lineno}: {message}", -) - -def main(): +def main(argv): """Parse command line argumennts and apply quantization.""" - def _parse_config(path: Union[str, Path]) -> Path: - try: - return detect_config(path) - except ValueError as err: - raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}") - def _parse_source(path: Union[str, Path], config_path: Path) -> Path: if path == "auto": return config_path.parent path = Path(path) - if not path.is_dir(): - raise argparse.ArgumentTypeError(f"Directory does not exist: {path}") + if not path.exists(): + raise argparse.ArgumentTypeError(f"Model source does not exist: {path}") return path def _parse_output(path: Union[str, Path]) -> Path: @@ -41,64 +28,56 @@ def _parse_output(path: Union[str, Path]) -> Path: path.mkdir(parents=True, exist_ok=True) return path - parser = argparse.ArgumentParser("MLC AutoLLM Quantization Framework") + parser = ArgumentParser("MLC AutoLLM Quantization Framework") parser.add_argument( - "--config", - type=_parse_config, + "--model", + type=detect_config, required=True, - help="Path to config.json file or to the directory that contains config.json, which is " - "a HuggingFace standard that defines model architecture, for example, " - "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json", - ) - parser.add_argument( - "--source", - type=str, - default="auto", - help="The path to original model weight, infer from `config` if missing. 
" - "(default: %(default)s)", - ) - parser.add_argument( - "--source-format", - type=str, - choices=["auto", "huggingface-torch", "huggingface-safetensor"], - default="auto", - help="The format of source model weight, infer from `config` if missing. " - "(default: %(default)s)", + dest="config", + help=HELP["model"] + " (required)", ) parser.add_argument( "--quantization", type=str, required=True, choices=list(QUANTIZATION.keys()), - help="Quantization format, for example `q4f16_1`.", + help=HELP["quantization"] + " (required, choices: %(choices)s)", ) parser.add_argument( "--model-type", type=str, default="auto", choices=["auto"] + list(MODELS.keys()), - help="Model architecture, for example, llama. If not set, it is inferred " - "from the config.json file. " - "(default: %(default)s)", + help=HELP["model_type"] + ' (default: "%(default)s")', ) parser.add_argument( "--device", default="auto", type=detect_device, - help="The device used to do quantization, for example, / `cuda:0`. " - "Detect from local environment if not specified. " - "(default: %(default)s)", + help=HELP["device_quantize"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--source", + type=str, + default="auto", + help=HELP["source"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--source-format", + type=str, + choices=["auto", "huggingface-torch", "huggingface-safetensor", "awq"], + default="auto", + help=HELP["source_format"] + ' (default: "%(default)s", choices: %(choices)s")', ) parser.add_argument( "--output", "-o", type=_parse_output, required=True, - help="The output directory to save the quantized model weight, " - "will contain `params_shard_*.bin` and `ndarray-cache.json`.", + help=HELP["output_quantize"] + " (required)", ) - parsed = parser.parse_args() + parsed = parser.parse_args(argv) parsed.source, parsed.source_format = detect_weight( weight_path=_parse_source(parsed.source, parsed.config), config_json_path=parsed.config, @@ -114,7 +93,3 @@ def _parse_output(path: Union[str, Path]) -> Path: source_format=parsed.source_format, output=parsed.output, ) - - -if __name__ == "__main__": - main() diff --git a/python/mlc_chat/cli/delivery.py b/python/mlc_chat/cli/delivery.py new file mode 100644 index 0000000000..3a0e3cc62e --- /dev/null +++ b/python/mlc_chat/cli/delivery.py @@ -0,0 +1,261 @@ +"""Continuous model delivery for MLC LLM models.""" +import argparse +import dataclasses +import json +import logging +import os +import shutil +import subprocess +import tempfile +from pathlib import Path +from typing import Any, Callable, Dict, List, Tuple, Union + +from huggingface_hub import HfApi # pylint: disable=import-error +from huggingface_hub.utils import HfHubHTTPError # pylint: disable=import-error + +from ..support.argparse import ArgumentParser +from ..support.download import git_clone +from ..support.style import bold, green, red + +logging.basicConfig( + level=logging.INFO, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", +) + +logger = logging.getLogger(__name__) +MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None) + + +@dataclasses.dataclass +class ModelInfo: + """Necessary information for the model delivery""" + + model_id: str + model: Path + conv_template: str + context_window_size: int + quantization: str + source_format: str = "auto" + + +class DeferredScope: + """A context manager that defers execution of functions until exiting the scope.""" + + def __init__(self): + self.deferred_functions = [] + + def add(self, 
func: Callable[[], None]): + """Add a function to be executed when exiting the scope.""" + self.deferred_functions.append(func) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + for func in reversed(self.deferred_functions): + func() + return False + + def create_temp_dir(self) -> Path: + """Create a temporary directory that will be deleted when exiting the scope.""" + temp_dir = tempfile.mkdtemp(dir=MLC_TEMP_DIR) + self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True)) + return Path(temp_dir) + + +def _clone_repo(model: Union[str, Path], deferred: DeferredScope) -> Path: + if isinstance(model, Path): + if not model.exists(): + raise ValueError(f"Invalid model source: {model}") + return model + if model.startswith("https://") or model.startswith("git://"): + result = deferred.create_temp_dir() / "repo" + git_clone(model, result, ignore_lfs=False) + return result + result = Path(model) + if result.exists(): + return result + raise ValueError(f"Invalid model source: {model}") + + +def _run_quantization( + model_info: ModelInfo, + repo: str, + api: HfApi, +) -> bool: + logger.info("[HF] Creating repo https://huggingface.co/%s", repo) + try: + api.create_repo(repo_id=repo, private=False) + except HfHubHTTPError as error: + if error.response.status_code != 409: + raise + logger.info("[HF] Repo already exists. Recreating...") + api.delete_repo(repo_id=repo) + api.create_repo(repo_id=repo, private=False) + logger.info("[HF] Repo recreated") + succeeded = True + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as output_dir: + log_path = Path(output_dir) / "logs.txt" + with log_path.open("a", encoding="utf-8") as log_file: + assert isinstance(model_info.model, Path) + logger.info("[MLC] Processing in directory: %s", output_dir) + cmd = [ + "mlc_chat", + "gen_mlc_chat_config", + "--model", + str(model_info.model), + "--quantization", + model_info.quantization, + "--conv-template", + model_info.conv_template, + "--context-window-size", + str(model_info.context_window_size), + "--output", + output_dir, + ] + print(" ".join(cmd), file=log_file, flush=True) + subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT) + cmd = [ + "mlc_chat", + "convert_weight", + "--model", + str(model_info.model), + "--quantization", + model_info.quantization, + "--source-format", + model_info.source_format, + "--output", + output_dir, + ] + print(" ".join(cmd), file=log_file, flush=True) + subprocess.run(cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT) + logger.info("[MLC] Complete!") + if not (Path(output_dir) / "ndarray-cache.json").exists(): + logger.error( + "[%s] Model %s. Quantization %s. No weights metadata found.", + red("FAILED"), + model_info.model_id, + model_info.quantization, + ) + succeeded = False + logger.info("[HF] Uploading to: https://huggingface.co/%s", repo) + for _retry in range(10): + try: + api.upload_folder( + folder_path=output_dir, + repo_id=repo, + commit_message="Initial commit", + ) + except Exception as exc: # pylint: disable=broad-except + logger.error("[%s] %s. 
Retrying...", red("FAILED"), exc) + else: + break + else: + raise RuntimeError("Failed to upload to HuggingFace Hub with 10 retries") + return succeeded + + +def _main( # pylint: disable=too-many-locals + username: str, + api: HfApi, + spec: Dict[str, Any], +): + failed_cases: List[Tuple[str, str]] = [] + for task_index, task in enumerate(spec["tasks"], 1): + with DeferredScope() as deferred: + logger.info( + bold("[{task_index}/{total_tasks}] Processing model: ").format( + task_index=task_index, + total_tasks=len(spec["tasks"]), + ) + + green(task["model_id"]) + ) + model = _clone_repo(task["model"], deferred) + for quantization in spec["default_quantization"] + task.get("quantization", []): + model_info = { + "model_id": task["model_id"], + "model": model, + "context_window_size": task["context_window_size"], + "conv_template": task["conv_template"], + } + if isinstance(quantization, str): + model_info["quantization"] = quantization + else: + model_info["quantization"] = quantization.pop("format") + model_info.update(quantization) + repo = spec.get("destination", "{username}/{model_id}-{quantization}").format( + username=username, + model_id=model_info["model_id"], + quantization=model_info["quantization"], + ) + logger.info( + "%s%s. %s%s. %s%s", + bold("Model: "), + green(task["model_id"]), + bold("Quantization: "), + green(model_info["quantization"]), + bold("Repo: "), + green(f"https://huggingface.co/{repo}"), + ) + with DeferredScope() as inner_deferred: + model_info["model"] = _clone_repo(model_info["model"], inner_deferred) + result = _run_quantization( + ModelInfo(**model_info), + repo=spec["destination"].format( + username=username, + model_id=model_info["model_id"], + quantization=model_info["quantization"], + ), + api=api, + ) + if not result: + failed_cases.append( + (task["model_id"], model_info["quantization"]), + ) + if failed_cases: + logger.info("Total %s %s:", len(failed_cases), red("failures")) + for model_id, quantization in failed_cases: + logger.info(" Model %s. 
Quantization %s.", model_id, quantization) + + +def main(): + """Entry point.""" + + def _load_spec(path_spec: str) -> Dict[str, Any]: + path = Path(path_spec) + if not path.exists(): + raise argparse.ArgumentTypeError(f"Spec file does not exist: {path}") + with path.open("r", encoding="utf-8") as i_f: + return json.load(i_f) + + parser = ArgumentParser("MLC LLM continuous model delivery") + parser.add_argument( + "--username", + type=str, + required=True, + help="HuggingFace username", + ) + parser.add_argument( + "--token", + type=str, + required=True, + help="HuggingFace access token, obtained under https://huggingface.co/settings/tokens", + ) + parser.add_argument( + "--spec", + type=_load_spec, + required=True, + help="Path to the spec file", + ) + parsed = parser.parse_args() + _main( + parsed.username, + spec=parsed.spec, + api=HfApi(token=parsed.token), + ) + + +if __name__ == "__main__": + main() diff --git a/python/mlc_chat/cli/gen_mlc_chat_config.py b/python/mlc_chat/cli/gen_mlc_chat_config.py new file mode 100644 index 0000000000..c1b2baad4a --- /dev/null +++ b/python/mlc_chat/cli/gen_mlc_chat_config.py @@ -0,0 +1,71 @@ +"""Command line entrypoint of configuration generation.""" +from pathlib import Path +from typing import Union + +from mlc_chat.compiler import CONV_TEMPLATES, HELP, MODELS, QUANTIZATION, gen_config + +from ..support.argparse import ArgumentParser +from ..support.auto_config import detect_config, detect_model_type + + +def main(argv): + """Parse command line argumennts and call `mlc_llm.compiler.gen_config`.""" + parser = ArgumentParser("MLC LLM Configuration Generator") + + def _parse_output(path: Union[str, Path]) -> Path: + path = Path(path) + if not path.is_dir(): + path.mkdir(parents=True, exist_ok=True) + return path + + parser.add_argument( + "--model", + type=detect_config, + required=True, + dest="config", + help=HELP["model"] + " (required)", + ) + parser.add_argument( + "--quantization", + type=str, + required=True, + choices=list(QUANTIZATION.keys()), + help=HELP["quantization"] + " (required, choices: %(choices)s)", + ) + parser.add_argument( + "--model-type", + type=str, + default="auto", + choices=["auto"] + list(MODELS.keys()), + help=HELP["model_type"] + ' (default: "%(default)s", choices: %(choices)s)', + ) + parser.add_argument( + "--conv-template", + type=str, + required=True, + choices=list(CONV_TEMPLATES), + help=HELP["conv_template"] + " (required, choices: %(choices)s)", + ) + parser.add_argument( + "--context-window-size", + type=int, + default=None, + help=HELP["context_window_size"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--output", + "-o", + type=_parse_output, + required=True, + help=HELP["output_gen_mlc_chat_config"] + " (required)", + ) + parsed = parser.parse_args(argv) + model = detect_model_type(parsed.model_type, parsed.config) + gen_config( + config=parsed.config, + model=model, + quantization=QUANTIZATION[parsed.quantization], + conv_template=parsed.conv_template, + context_window_size=parsed.context_window_size, + output=parsed.output, + ) diff --git a/python/mlc_chat/compiler/__init__.py b/python/mlc_chat/compiler/__init__.py index 6d0c8c223d..b34757fad8 100644 --- a/python/mlc_chat/compiler/__init__.py +++ b/python/mlc_chat/compiler/__init__.py @@ -7,6 +7,8 @@ from .convert_weight import ConversionArgs, convert_weight from .flags_model_config_override import ModelConfigOverride from .flags_optimization import OptimizationFlags +from .gen_mlc_chat_config import CONV_TEMPLATES, gen_config +from .help 
import HELP from .loader import LOADER, ExternMapping, HuggingFaceLoader, QuantizeMapping from .model import MODEL_PRESETS, MODELS, Model from .quantization import QUANTIZATION diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index ade62309c5..121d5b5aab 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -83,15 +83,18 @@ def _emit_metadata(metadata): def _attach_variable_bounds(mod, model_config): + tir_bound_map = {} + tir_bound_map["seq_len"] = model_config.prefill_chunk_size + + if hasattr(model_config, "sliding_window"): + tir_bound_map["rolling_cache_len"] = model_config.sliding_window + tir_bound_map["kv_seq_len"] = model_config.sliding_window + model_config.prefill_chunk_size + else: + tir_bound_map["total_seq_len"] = model_config.context_window_size + for g_var, func in mod.functions_items(): if isinstance(func, relax.Function): - mod[g_var] = func.with_attr( - "tir_var_upper_bound", - { - "seq_len": model_config.max_sequence_length, - "total_seq_len": model_config.max_sequence_length, - }, - ) + mod[g_var] = func.with_attr("tir_var_upper_bound", tir_bound_map) def _compile(args: CompileArgs): @@ -122,7 +125,9 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin build_func: Callable[[IRModule, CompileArgs], None], prefix_symbols: str, output: Path, - max_sequence_length: Optional[int], + context_window_size: Optional[int], + sliding_window: Optional[int], + prefill_chunk_size: Optional[int], ): """Compile a model given its configuration and quantization format to a specific target.""" args = CompileArgs( @@ -134,7 +139,11 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin build_func, prefix_symbols, output, - ModelConfigOverride(max_sequence_length=max_sequence_length), + ModelConfigOverride( + context_window_size=context_window_size, + sliding_window=sliding_window, + prefill_chunk_size=prefill_chunk_size, + ), ) args.display() _compile(args) diff --git a/python/mlc_chat/compiler/convert_weight.py b/python/mlc_chat/compiler/convert_weight.py index 7b1f4576b9..cd762de08a 100644 --- a/python/mlc_chat/compiler/convert_weight.py +++ b/python/mlc_chat/compiler/convert_weight.py @@ -7,6 +7,7 @@ import numpy as np from tvm.contrib import tvmjs +from tvm.relax.frontend import nn from tvm.runtime import Device, NDArray from tvm.runtime import cpu as cpu_device from tvm.target import Target @@ -51,6 +52,14 @@ def _device_to_str(device: Device) -> str: print(out.getvalue().rstrip()) +def _calc_total_params(model: nn.Module) -> int: + _, named_params = model.export_tvm(spec=model.get_default_spec()) # type: ignore + total_params = 0 + for _, param in named_params: + total_params += math.prod(param.shape) + return total_params + + def _convert_args(args: ConversionArgs) -> None: # pylint: disable=too-many-locals # model config & quantization config model_config = args.model.config.from_file(args.config) @@ -82,8 +91,8 @@ def _check_param(name: str, param: NDArray): # load and quantize param_dict = {} + total_params = _calc_total_params(args.model.model(model_config)) total_bytes = 0.0 - total_params = 0 with Target.from_device(args.device), tqdm.redirect(): for name, param in LOADER[args.source_format]( path=args.source, @@ -94,20 +103,32 @@ def _check_param(name: str, param: NDArray): param = param.copyto(cpu_device()) param_dict[name] = param total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize - total_params += math.prod(param.shape) if named_params: raise 
ValueError(f"Parameter not found in source: {', '.join(named_params.keys())}") + # Log necessary statistics + logger.info( + "%s after quantization: %.3f GB", + green("Parameter size"), + total_bytes / (1024**3), + ) + logger.info(f"%s: {total_params:,}", green("Total parameters")) + logger.info( + "%s: %.3f", + green("Bits per parameter"), + total_bytes * 8.0 / total_params, + ) # dump to output directory tvmjs.dump_ndarray_cache( param_dict, str(args.output), - meta_data={"ParamSize": len(param_dict)}, + meta_data={ + "ParamSize": len(param_dict), + "ParamBytes": total_bytes, + "BitsPerParam": total_bytes * 8.0 / total_params, + }, encode_format="raw", ) - logger.info("%s to %s", green("Saved"), bold(str(args.output))) - logger.info("%s: %.3f GB", green("Total parameter size"), total_bytes / (1024**3)) - logger.info("%s: %d", green("Total number of parameter tensors"), len(param_dict)) - logger.info(f"%s: {total_params:,}", green("Total number of parameters")) + logger.info("Saved to directory: %s", bold(str(args.output))) def convert_weight( # pylint: disable=too-many-arguments diff --git a/python/mlc_chat/compiler/flags_model_config_override.py b/python/mlc_chat/compiler/flags_model_config_override.py index f1a25346a7..d3655b7e65 100644 --- a/python/mlc_chat/compiler/flags_model_config_override.py +++ b/python/mlc_chat/compiler/flags_model_config_override.py @@ -12,21 +12,49 @@ class ModelConfigOverride: """Flags for overriding model config.""" - max_sequence_length: Optional[int] = None + context_window_size: Optional[int] = None max_batch_size: Optional[int] = None num_shards: Optional[int] = None + sliding_window: Optional[int] = None + prefill_chunk_size: Optional[int] = None def apply(self, model_config): """Apply the overrides to the given model config.""" - if self.max_sequence_length is not None: + if self.context_window_size is not None: logger.info( "Overriding %s from %d to %d", - bold("max_sequence_length"), - model_config.max_sequence_length, - self.max_sequence_length, + bold("context_window_size"), + model_config.context_window_size, + self.context_window_size, ) - model_config.max_sequence_length = self.max_sequence_length + model_config.context_window_size = self.context_window_size if self.max_batch_size is not None: model_config.max_batch_size = self.max_batch_size if self.num_shards is not None: model_config.num_shards = self.num_shards + + # Handle sliding window and sliding window chunk size + if self.sliding_window is not None: + logger.info( + "Overriding %s from %d to %d", + bold("sliding_window"), + model_config.sliding_window, + self.sliding_window, + ) + model_config.sliding_window = self.sliding_window + if self.prefill_chunk_size is None: + logger.info( + "Provided %s but did not provide %s, setting both to %d", + bold("sliding_window"), + bold("prefill_chunk_size"), + model_config.sliding_window, + ) + model_config.prefill_chunk_size = model_config.sliding_window + if self.prefill_chunk_size is not None: + logger.info( + "Overriding %s from %d to %d", + bold("prefill_chunk_size"), + model_config.prefill_chunk_size, + self.prefill_chunk_size, + ) + model_config.prefill_chunk_size = self.prefill_chunk_size diff --git a/python/mlc_chat/compiler/gen_mlc_chat_config.py b/python/mlc_chat/compiler/gen_mlc_chat_config.py new file mode 100644 index 0000000000..2298d55d85 --- /dev/null +++ b/python/mlc_chat/compiler/gen_mlc_chat_config.py @@ -0,0 +1,163 @@ +"""Generator of mlc-chat-config.json and tokenizer configuration.""" +import dataclasses +import json +import
logging +import shutil +from pathlib import Path +from typing import Any, Dict, List, Optional + +from ..support.style import bold, green, red +from .flags_model_config_override import ModelConfigOverride +from .model import Model +from .quantization import Quantization + +logger = logging.getLogger(__name__) + +FOUND = green("Found") +NOT_FOUND = red("Not found") +VERSION = "0.1.0" + + +@dataclasses.dataclass +class MLCChatConfig: # pylint: disable=too-many-instance-attributes + """Arguments for `mlc_chat.compiler.gen_config`.""" + + version: str = VERSION + + model_type: str = None + quantization: str = None + model_config: Dict[str, Any] = None + vocab_size: int = None + max_window_size: int = None + + temperature: float = None + repetition_penalty: float = None + top_p: float = None + + mean_gen_len: int = None + max_gen_len: int = None + shift_fill_factor: float = None + + # Conversation template + conv_template: str = None + pad_token_id: int = None + bos_token_id: int = None + eos_token_id: int = None + tokenizer_files: List[str] = dataclasses.field(default_factory=list) + + +def gen_config( # pylint: disable=too-many-locals,too-many-arguments + config: Path, + model: Model, + quantization: Quantization, + conv_template: str, + context_window_size: Optional[int], + output: Path, +): + """Entrypoint of MLC Chat configuration generation.""" + with config.open("r", encoding="utf-8") as in_file: + model_config_json = json.load(in_file) + model_config = model.config.from_dict(model_config_json) + ModelConfigOverride( + context_window_size=context_window_size, + ).apply(model_config) + + mlc_chat_config = MLCChatConfig( + model_type=model.name, + quantization=quantization.name, + model_config=model_config_json, + vocab_size=model_config.vocab_size, + conv_template=conv_template, + max_window_size=model_config.context_window_size, + ) + # Step 1. Load `config.json` + for key, value in model_config_json.items(): + if hasattr(mlc_chat_config, key) and getattr(mlc_chat_config, key) is None: + setattr(mlc_chat_config, key, value) + logger.info("[config.json] Setting %s: %s", bold(key), value) + # Step 2. Load `generation_config.json` + generation_config = config.parent / "generation_config.json" + if generation_config.exists(): + logger.info("%s generation_config.json: %s", FOUND, generation_config) + with generation_config.open("r", encoding="utf-8") as in_file: + generation_config_json = json.load(in_file) + for key, value in generation_config_json.items(): + if hasattr(mlc_chat_config, key) and getattr(mlc_chat_config, key) is None: + setattr(mlc_chat_config, key, value) + logger.info("[generation_config.json] Setting %s: %s", bold(key), value) + else: + logger.info("%s generation_config.json: %s", NOT_FOUND, generation_config) + # Step 3. Copy tokenizer configuration + for filename in TOKENIZER_FILES: + file = config.parent / filename + if file.exists(): + mlc_chat_config.tokenizer_files.append(filename) + dest = output / filename + shutil.copy(file, dest) + logger.info("%s tokenizer config: %s. Copying to %s", FOUND, file, bold(str(dest))) + else: + logger.info("%s tokenizer config: %s", NOT_FOUND, file) + # Step 4. 
Load system default value + for key, value in DEFAULT_CONFIGS.items(): + if getattr(mlc_chat_config, key) is None: + setattr(mlc_chat_config, key, value) + logger.info("[System default] Setting %s: %s", bold(key), value) + # Dump the configuration file to output directory + out = output / "mlc-chat-config.json" + with out.open("w", encoding="utf-8") as out_file: + json.dump(dataclasses.asdict(mlc_chat_config), out_file, indent=2) + logger.info("Dumping configuration file to: %s", bold(str(out))) + + +DEFAULT_CONFIGS = { + # Conversation + "pad_token_id": 0, + "bos_token_id": 1, + "eos_token_id": 2, + # Configuration of text generation + "temperature": 0.7, + "repetition_penalty": 1.0, + "top_p": 0.95, + # Control the behavior of the runtime + "mean_gen_len": 128, + "max_gen_len": 512, + "shift_fill_factor": 0.3, +} + +TOKENIZER_FILES = [ + "tokenizer.model", + "tokenizer.json", + "vocab.json", + "merges.txt", + "added_tokens.json", + "tokenizer_config.json", +] + +CONV_TEMPLATES = { + "chatml", + "llama_default", + "llama-2", + "mistral_default", + "codellama_completion", + "codellama_instruct", + "vicuna_v1.1", + "conv_one_shot", + "redpajama_chat", + "rwkv_world", + "rwkv", + "gorilla", + "guanaco", + "dolly", + "oasst", + "stablelm", + "stablecode_completion", + "stablecode_instruct", + "minigpt", + "moss", + "LM", + "stablelm-3b", + "gpt_bigcode", + "wizardlm_7b", + "wizard_coder_or_math", + "glm", +} diff --git a/python/mlc_chat/compiler/help.py b/python/mlc_chat/compiler/help.py new file mode 100644 index 0000000000..9032959d8e --- /dev/null +++ b/python/mlc_chat/compiler/help.py @@ -0,0 +1,103 @@ +"""Help message for CLI arguments.""" +from .model import MODEL_PRESETS + +HELP = { + "model": ( + """ +1) Path to a HuggingFace model directory that contains a `config.json` or +2) Path to `config.json` in HuggingFace format, or +3) The name of a pre-defined model architecture. + +A `config.json` file in HuggingFace format defines the model architecture, including the vocabulary +size, the number of layers, the hidden size, number of attention heads, etc. +Example: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json. + +A HuggingFace directory often contains a `config.json` which defines the model architecture, +the non-quantized model weights in PyTorch or SafeTensor format, tokenizer configurations, +as well as an optional `generation_config.json` provides additional default configuration for +text generation. +Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main. + +Pre-defined model architectures include """ + + ", ".join(f'"{preset}"' for preset in MODEL_PRESETS) + + "." + ).strip(), + "quantization": """ +Quantization format. +""".strip(), + "model_type": """ +Model architecture such as "llama". If not set, it is inferred from `config.json`. +""".strip(), + "device_compile": """ +The GPU device to compile the model to. If not set, it is inferred from GPUs available locally. +""".strip(), + "device_quantize": """ +The device used to do quantization such as "cuda" or "cuda:0". Will detect from local available GPUs +if not specified. +""".strip(), + "host": """ +The host LLVM triple to compile the model to. If not set, it is inferred from the local CPU and OS. +Examples of the LLVM triple: +1) iPhones: arm64-apple-ios; +2) ARM64 Android phones: aarch64-linux-android; +3) WebAssembly: wasm32-unknown-unknown-wasm; +4) Windows: x86_64-pc-windows-msvc; +5) ARM macOS: arm64-apple-darwin. +""".strip(), + "opt": """ +Optimization flags. 
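# Illustrative only: a hedged, abridged sketch of the `mlc-chat-config.json` that
# `gen_config` above would assemble for a hypothetical Llama-2 7B model compiled with the
# "q0f16" scheme, combining values from `config.json` with the system defaults listed
# above. The field values are examples rather than the output of a real run, and the
# nested `model_config` entry is omitted for brevity.
ILLUSTRATIVE_MLC_CHAT_CONFIG = {
    "version": "0.1.0",
    "model_type": "llama",
    "quantization": "q0f16",
    "vocab_size": 32000,
    "max_window_size": 4096,
    "conv_template": "llama-2",
    "temperature": 0.7,
    "repetition_penalty": 1.0,
    "top_p": 0.95,
    "mean_gen_len": 128,
    "max_gen_len": 512,
    "shift_fill_factor": 0.3,
    "pad_token_id": 0,
    "bos_token_id": 1,
    "eos_token_id": 2,
    "tokenizer_files": ["tokenizer.model", "tokenizer.json", "tokenizer_config.json"],
}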
MLC LLM maintains a predefined set of optimization flags, +denoted as O0, O1, O2, O3, where O0 means no optimization, O2 means majority of them, +and O3 represents extreme optimization that could potentially break the system. +Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. +--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0". +""".strip(), + "prefix_symbols": """ +Adding a prefix to all symbols exported. Similar to "objcopy --prefix-symbols". +This is useful when compiling multiple models into a single library to avoid symbol +conflicts. Differet from objcopy, this takes no effect for shared library. +""".strip(), + "context_window_size": """ +Option to provide the maximum sequence length supported by the model. +This is usually explictly shown as context length or context window in the model card. +If this option is not set explicitly, by default, +it will be determined by `context_window_size` or `max_position_embeddings` in `config.json`, +and the latter is usually inaccurate for some models. +""".strip(), + "output_compile": """ +The name of the output file. The suffix determines if the output file is a shared library or +objects. Available suffixes: +1) Linux: .so (shared), .tar (objects); +2) macOS: .dylib (shared), .tar (objects); +3) Windows: .dll (shared), .tar (objects); +4) Android, iOS: .tar (objects); +5) Web: .wasm (web assembly). +""".strip(), + "source": """ +The path to original model weight, infer from `config` if missing. +""".strip(), + "source_format": """ +The format of source model weight, infer from `config` if missing. +""".strip(), + "output_quantize": """ +The output directory to save the quantized model weight. Will create `params_shard_*.bin` and +`ndarray-cache.json` in this directory. +""".strip(), + "conv_template": """ +Conversation template. It depends on how the model is tuned. Use "LM" for vanilla base model +""".strip(), + "output_gen_mlc_chat_config": """ +The output directory for generated configurations, including `mlc-chat-config.json` and tokenizer +configuration. +""".strip(), + "sliding_window": """ +(Experimental) The sliding window size in sliding window attention (SWA). +This optional field overrides the `sliding_window` in config.json for +those models that use SWA. Currently only useful when compiling Mistral. +This flag subjects to future refactoring. +""".strip(), + "prefill_chunk_size": """ +(Experimental) The chunk size during prefilling. By default, +the chunk size is the same as sliding window or max sequence length. +This flag subjects to future refactoring. +""".strip(), +} diff --git a/python/mlc_chat/compiler/loader/huggingface_loader.py b/python/mlc_chat/compiler/loader/huggingface_loader.py index 651c43b21f..9611cda87f 100644 --- a/python/mlc_chat/compiler/loader/huggingface_loader.py +++ b/python/mlc_chat/compiler/loader/huggingface_loader.py @@ -77,7 +77,7 @@ def __init__( The quantization mapping from MLC to quantized MLC parameters, default to None, which means no quantization. 
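# A hedged sketch of how a command-line front end could surface the HELP strings defined
# earlier in this diff. The argparse wiring and flag names below are illustrative
# assumptions, not the actual `mlc_chat` CLI definition.
import argparse

from mlc_chat.compiler.help import HELP  # module added in this diff


def build_compile_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="illustrative compile front end")
    parser.add_argument("--quantization", help=HELP["quantization"])
    parser.add_argument("--context-window-size", type=int, default=None,
                        help=HELP["context_window_size"])
    parser.add_argument("--sliding-window", type=int, default=None,
                        help=HELP["sliding_window"])
    parser.add_argument("--prefill-chunk-size", type=int, default=None,
                        help=HELP["prefill_chunk_size"])
    return parser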
""" - assert path.is_file() + assert path.is_file(), f"Path {path} is not a file" self.stats = Stats() self.extern_param_map = extern_param_map self.cached_files = {} diff --git a/python/mlc_chat/compiler/loader/loader.py b/python/mlc_chat/compiler/loader/loader.py index 267ece72ab..e4c397c5ab 100644 --- a/python/mlc_chat/compiler/loader/loader.py +++ b/python/mlc_chat/compiler/loader/loader.py @@ -8,4 +8,5 @@ LOADER: Dict[str, Any] = { "huggingface-torch": HuggingFaceLoader, "huggingface-safetensor": HuggingFaceLoader, + "awq": HuggingFaceLoader, } diff --git a/python/mlc_chat/compiler/loader/stats.py b/python/mlc_chat/compiler/loader/stats.py index d12cd2f257..23db57df13 100644 --- a/python/mlc_chat/compiler/loader/stats.py +++ b/python/mlc_chat/compiler/loader/stats.py @@ -4,6 +4,8 @@ import time from contextlib import contextmanager +from mlc_chat.support.style import green + logger = logging.getLogger(__name__) @@ -67,10 +69,11 @@ def mem_rm(self, nbytes: int): def log_time_info(self, weight_format: str): """Log the time used in loading, pre-quantization and quantization.""" logger.info( - "Time usage: " + "%s: " "%s loading: %.3f sec; " "Pre-quantization mapping: %.3f sec; " "Quantization: %.3f sec", + green("Time usage"), weight_format, self.load_time_sec, self.map_time_sec, @@ -80,7 +83,8 @@ def log_time_info(self, weight_format: str): def log_mem_usage(self): """Log the Memory usage information.""" logger.info( - "RAM usage: Peak RAM: %.3f GB. Total bytes loaded from disk: %.3f GB", - self.total_memory_gb, + "%s: Peak RAM: %.3f GB. Total bytes loaded from disk: %.3f GB", + green("RAM usage"), self.max_memory_gb, + self.total_memory_gb, ) diff --git a/python/mlc_chat/compiler/model/__init__.py b/python/mlc_chat/compiler/model/__init__.py index a42dda9a09..87dcd49097 100644 --- a/python/mlc_chat/compiler/model/__init__.py +++ b/python/mlc_chat/compiler/model/__init__.py @@ -1,2 +1,3 @@ """Model definition for the compiler.""" +from . import llama, mistral from .model import MODEL_PRESETS, MODELS, Model diff --git a/python/mlc_chat/compiler/model/llama/__init__.py b/python/mlc_chat/compiler/model/llama/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_chat/compiler/model/llama/llama_loader.py b/python/mlc_chat/compiler/model/llama/llama_loader.py new file mode 100644 index 0000000000..7118340d21 --- /dev/null +++ b/python/mlc_chat/compiler/model/llama/llama_loader.py @@ -0,0 +1,158 @@ +""" +This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +import numpy as np + +from ...loader import ExternMapping +from ...quantization import Quantization +from .llama_model import LlamaConfig, LlamaForCasualLM +from .llama_quantization import awq_quant + + +def huggingface(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : LlamaConfig + The configuration of the Llama model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = LlamaForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params = model.export_tvm(spec=model.get_default_spec()) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping + + +def awq(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of AWQ parameters. + Parameters + ---------- + model_config : LlamaConfig + The configuration of the Llama model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to AWQ. 
+ """ + model, _ = awq_quant(model_config, quantization) + _, _named_params = model.export_tvm(spec=model.get_default_spec()) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{attn}.qkv_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{quantize_suffix}", + f"{attn}.k_proj.{quantize_suffix}", + f"{attn}.v_proj.{quantize_suffix}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Concat gate and up in MLP + mlp = f"model.layers.{i}.mlp" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.{quantize_suffix}", + f"{mlp}.up_proj.{quantize_suffix}", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype), + ) + return mapping diff --git a/python/mlc_chat/compiler/model/llama_model.py b/python/mlc_chat/compiler/model/llama/llama_model.py similarity index 78% rename from python/mlc_chat/compiler/model/llama_model.py rename to python/mlc_chat/compiler/model/llama/llama_model.py index 27d7db0825..8650c7ed2a 100644 --- a/python/mlc_chat/compiler/model/llama_model.py +++ b/python/mlc_chat/compiler/model/llama/llama_model.py @@ -11,8 +11,8 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from ...support.config import ConfigBase -from ...support.style import bold +from ....support.config import ConfigBase +from ....support.style import bold logger = logging.getLogger(__name__) @@ -28,26 +28,29 @@ class LlamaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes rms_norm_eps: float vocab_size: int position_embedding_base: int = 0 - max_sequence_length: int = 0 + context_window_size: int = 0 num_key_value_heads: int = 0 head_dim: int = 0 + prefill_chunk_size: int = 0 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): - if self.max_sequence_length == 0: - if "max_position_embeddings" in self.kwargs: - self.max_sequence_length = self.kwargs.pop("max_position_embeddings") - logger.info( - "%s not found in config.json. Falling back to %s (%d)", - bold("max_sequence_length"), - bold("max_position_embeddings"), - self.max_sequence_length, - ) + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. 
Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break else: raise ValueError( - "Unable to determine the maxmimum sequence length, because neither " - "`max_sequence_length` nor `max_position_embeddings` is provided " - "in `config.json`." + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." ) if self.position_embedding_base == 0: if "rope_theta" in self.kwargs: @@ -61,71 +64,12 @@ def __post_init__(self): assert self.num_attention_heads % self.num_key_value_heads == 0 assert self.head_dim * self.num_attention_heads == self.hidden_size + if self.prefill_chunk_size == 0: + # chunk size same as context window size by default + self.prefill_chunk_size = self.context_window_size -# pylint: disable=invalid-name,missing-docstring - - -class RMSNorm(nn.Module): - """ - Module for rms norm layer. - """ - def __init__( # pylint: disable=too-many-arguments - self, - hidden_size: int, - axes, # pylint: disable=unused-argument - epsilon: float = 1e-5, - bias: bool = True, - dtype: Optional[str] = None, - ): - super().__init__() - self.epsilon = epsilon - self.weight = nn.Parameter((hidden_size,), dtype=dtype) - if bias: - self.bias = nn.Parameter((hidden_size,), dtype=dtype) - else: - self.bias = None - - def forward(self, x: Tensor): - """ - Forward method for rms norm layer. - - Parameters - ---------- - x : Tensor - The input tensor. - - Returns - ------- - ret : Tensor - The output tensor for the rms norm layer. - """ - - def f_square(x): - x = x.astype("float32") - return x * x - - def f_div_mult(x, square_sum, weight, *indices): - *i, k = indices - s = tir.sqrt(square_sum[*i] / x.shape[-1] + self.epsilon) - s = x[*i, k].astype("float32") / s - s = (weight[k] * s).astype(x.dtype) - return s - - def te_op(x: te.Tensor, weight: te.Tensor): - k = te.reduce_axis((0, x.shape[-1]), name="k") - square_sum = te.compute( - x.shape[:-1], - lambda *i: te.sum(f_square(x[*i, k]), axis=k), - name=x.op.name + "red_temp", - ) - return te.compute( - x.shape, - lambda *i: f_div_mult(x, square_sum, weight, *i), - name="rms_norm", - ) - - return op.tensor_expr_op(te_op, "rms_norm", args=[x, self.weight]) +# pylint: disable=invalid-name,missing-docstring class RotaryEmbedding(nn.Module): @@ -193,8 +137,8 @@ def __init__(self, config: LlamaConfig, rotary_embedding: RotaryEmbedding): bias=False, ) self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) - self.k_cache = nn.KVCache(config.max_sequence_length, [self.num_kv_heads, self.head_dim]) - self.v_cache = nn.KVCache(config.max_sequence_length, [self.num_kv_heads, self.head_dim]) + self.k_cache = nn.KVCache(config.context_window_size, [self.num_kv_heads, self.head_dim]) + self.v_cache = nn.KVCache(config.context_window_size, [self.num_kv_heads, self.head_dim]) def forward( # pylint: disable=too-many-locals self, @@ -241,8 +185,8 @@ def __init__(self, config: LlamaConfig, rotary_embedding: RotaryEmbedding): rms_norm_eps = config.rms_norm_eps self.self_attn = LlamaAttention(config, rotary_embedding) self.mlp = LlamaFFN(config) - self.input_layernorm = RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) - self.post_attention_layernorm = RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, 
bias=False) def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): hidden_states = ( @@ -261,7 +205,7 @@ def __init__(self, config: LlamaConfig): self.layers = nn.ModuleList( [LlamaDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)] ) - self.norm = RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): hidden_states = self.embed_tokens(inputs) diff --git a/python/mlc_chat/compiler/model/llama/llama_quantization.py b/python/mlc_chat/compiler/model/llama/llama_quantization.py new file mode 100644 index 0000000000..598c7be3fb --- /dev/null +++ b/python/mlc_chat/compiler/model/llama/llama_quantization.py @@ -0,0 +1,52 @@ +"""This file specifies how MLC's Llama parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from ...loader import QuantizeMapping +from ...quantization import AWQQuantize, GroupQuantize, NoQuantize +from .llama_model import LlamaConfig, LlamaForCasualLM + + +def group_quant( + model_config: LlamaConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama-architecture model using group quantization.""" + model: nn.Module = LlamaForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: LlamaConfig, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama-architecture model using Activation-aware Weight Quantization(AWQ).""" + model: nn.Module = LlamaForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: LlamaConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama2 model without quantization.""" + model: nn.Module = LlamaForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/compiler/model/mistral/__init__.py b/python/mlc_chat/compiler/model/mistral/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_chat/compiler/model/mistral/mistral_loader.py b/python/mlc_chat/compiler/model/mistral/mistral_loader.py new file mode 100644 index 0000000000..6c79e5b8e3 --- /dev/null +++ b/python/mlc_chat/compiler/model/mistral/mistral_loader.py @@ -0,0 +1,158 @@ +""" +This file specifies how MLC's Mistral parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +import numpy as np + +from ...loader import ExternMapping +from ...quantization import Quantization +from .mistral_model import MistralConfig, MistralForCasualLM +from .mistral_quantization import awq_quant + + +def huggingface(model_config: MistralConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : MistralConfig + The configuration of the Mistral model. 
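# A hedged, end-to-end sketch of how the registries added in this diff compose: a preset
# config is parsed by the model's config class, a quantization scheme is looked up in
# QUANTIZATION, and the per-model `quantize` table dispatches to the `group_quant` helpers
# defined in this diff. Treat the wiring as illustrative of the data flow rather than the
# full compiler pipeline (running it requires a TVM Unity build of mlc_chat).
from mlc_chat.compiler.model import MODEL_PRESETS, MODELS
from mlc_chat.compiler.quantization import QUANTIZATION

mistral = MODELS["mistral"]
config = mistral.config.from_dict(MODEL_PRESETS["mistral_7b"])
scheme = QUANTIZATION["q3f16_1"]  # a GroupQuantize scheme
quantized_model, quant_map = mistral.quantize["group-quant"](config, scheme)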
+ + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = MistralForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params = model.export_tvm(spec=model.get_default_spec()) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping + + +def awq(model_config: MistralConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of AWQ parameters. + Parameters + ---------- + model_config : MistralConfig + The configuration of the Mistral model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to AWQ. 
+ """ + model, _ = awq_quant(model_config, quantization) + _, _named_params = model.export_tvm(spec=model.get_default_spec()) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{attn}.qkv_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{quantize_suffix}", + f"{attn}.k_proj.{quantize_suffix}", + f"{attn}.v_proj.{quantize_suffix}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Concat gate and up in MLP + mlp = f"model.layers.{i}.mlp" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.{quantize_suffix}", + f"{mlp}.up_proj.{quantize_suffix}", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype), + ) + return mapping diff --git a/python/mlc_chat/compiler/model/mistral/mistral_model.py b/python/mlc_chat/compiler/model/mistral/mistral_model.py new file mode 100644 index 0000000000..17942bf7f9 --- /dev/null +++ b/python/mlc_chat/compiler/model/mistral/mistral_model.py @@ -0,0 +1,489 @@ +""" +Implementation for Mistral architecture. +""" +import dataclasses +import logging +import math +from typing import Any, Dict, Optional + +from tvm import relax as rx +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from ....support.config import ConfigBase +from ....support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class MistralConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Mistral model.""" + + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_hidden_layers: int + rms_norm_eps: float + vocab_size: int + position_embedding_base: int = 0 + context_window_size: int = 0 + num_key_value_heads: int = 0 + head_dim: int = 0 + sliding_window: int = 4096 + prefill_chunk_size: int = 0 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." 
+ ) + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 + if self.num_key_value_heads == 0: + self.num_key_value_heads = self.num_attention_heads + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.num_attention_heads % self.num_key_value_heads == 0 + assert self.head_dim * self.num_attention_heads == self.hidden_size + + if self.prefill_chunk_size == 0: + # chunk size same as sliding window by default + self.prefill_chunk_size = self.sliding_window + self.context_window_size = -1 + logger.info( + "Using sliding window attention, setting %s to -1", + bold("context_window_size"), + ) + + +# pylint: disable=invalid-name,missing-docstring + + +class RotaryEmbedding(nn.Module): + """Same as in Llama architecture.""" + + def __init__(self, config: MistralConfig): + super().__init__() + self.head_dim = config.head_dim + self.position_embedding_base = config.position_embedding_base + + def forward(self, q: Tensor, k: Tensor, offset: tir.Var): + def te_op(x: te.Tensor, offset: tir.Var): + dtype = x.dtype + + def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var): + head_dim = tir.const(self.head_dim, "int32") + position_embedding_base = tir.const(self.position_embedding_base, "float32") + freq = tir.power( + position_embedding_base, + (d * 2 % head_dim).astype("float32") / head_dim, + ) + freq = (offset + s) / freq + cos = tir.cos(freq).astype(dtype) * x[b, s, h, d] + sin = tir.sin(freq).astype(dtype) * tir.if_then_else( + d < head_dim // 2, + -x[b, s, h, d + head_dim // 2], + x[b, s, h, d - head_dim // 2], + ) + return cos + sin + + return te.compute(x.shape, compute, name="rotary") + + q_embed = op.tensor_expr_op(te_op, "rotary_embedding", args=[q, offset]) + k_embed = op.tensor_expr_op(te_op, "rotary_embedding", args=[k, offset]) + return q_embed, k_embed + + +class MistralMLP(nn.Module): + """Same as in Llama architecture (LlamaFFN).""" + + def __init__(self, config: MistralConfig): + super().__init__() + self.gate_up_proj = nn.MultiLinear( + in_features=config.hidden_size, + out_features=[config.intermediate_size, config.intermediate_size], + bias=False, + ) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x: Tensor): + x1, x2 = self.gate_up_proj(x) + return self.down_proj(op.silu(x1) * x2) + + +class MistralAttention(nn.Module): # pylint: disable=too-many-instance-attributes + """Same as LlamaAttention, but with sliding window attention using a rolling buffer cache.""" + + def __init__(self, config: MistralConfig, rotary_embedding: RotaryEmbedding): + self.rotary_embedding = rotary_embedding + self.hidden_size = config.hidden_size + self.head_dim = config.head_dim + self.num_q_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.sliding_window = config.sliding_window + self.qkv_proj = nn.MultiLinear( + in_features=config.hidden_size, + out_features=[ + self.num_q_heads * self.head_dim, + self.num_kv_heads * self.head_dim, + self.num_kv_heads * self.head_dim, + ], + bias=False, + ) + self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.k_cache = RollingKVCache(self.sliding_window, [self.num_kv_heads, self.head_dim]) + self.v_cache = RollingKVCache(self.sliding_window, [self.num_kv_heads, self.head_dim]) + + def interleave_kv( # pylint: disable=too-many-arguments,too-many-locals + 
self, + k_cur: Tensor, + v_cur: Tensor, + total_seq_len: tir.Var, + kv_seq_len: tir.Var, + rolling_cache_len: tir.Var, + ): + """Unrotate and concatenate currunt and cached k and v""" + d, _, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads + t, kv_s, c = total_seq_len, kv_seq_len, rolling_cache_len + b, s, _, _ = k_cur.shape + cache_offset = (t - s) % self.sliding_window + + k_cached = op.reshape(self.k_cache.view(c), (b, c, h_kv, d)) + v_cached = op.reshape(self.v_cache.view(c), (b, c, h_kv, d)) + + def _unrotate_concat(x_cur, x_cached, cache_offset, rolling_cache_len): + return te.compute( + (b, kv_s, h_kv, d), + lambda xb, xs, xh, xd: te.if_then_else( + xs < rolling_cache_len - cache_offset, + x_cached[xb, cache_offset + xs, xh, xd], + te.if_then_else( + xs < rolling_cache_len, + x_cached[xb, xs + cache_offset - rolling_cache_len, xh, xd], + x_cur[xb, xs - rolling_cache_len, xh, xd], + ), + ), + name="unrotate_concat_te", + ) + + k = op.tensor_expr_op( + _unrotate_concat, + name_hint="te_unrotate_concat_key", + args=[k_cur, k_cached, cache_offset, c], + ) + v = op.tensor_expr_op( + _unrotate_concat, + name_hint="te_unrotate_concat_value", + args=[v_cur, v_cached, cache_offset, c], + ) + + self.k_cache.override(op.squeeze(k_cur, axis=0), self.sliding_window) + self.v_cache.override(op.squeeze(v_cur, axis=0), self.sliding_window) + + return k, v + + def forward( # pylint: disable=too-many-arguments, too-many-locals + self, + hidden_states: Tensor, + attention_mask: Tensor, + total_seq_len: tir.Var, # Number of already-processed tokens plus ``seq_len``. + rolling_cache_len: tir.Var, # Number of elements currently in the cache. + kv_seq_len: tir.Var, # Equals to ``seq_len + rolling_cache_len``. + ): + """Forward pass of MistralAttention, performing QKV.""" + d, h_q, h_kv, t = self.head_dim, self.num_q_heads, self.num_kv_heads, total_seq_len + b, s, _ = hidden_states.shape + assert b == 1, "Only support batch size 1 at this moment." + + q, k_cur, v_cur = self.qkv_proj(hidden_states) + q = op.reshape(q, (b, s, h_q, d)) + k_cur = op.reshape(k_cur, (b, s, h_kv, d)) + v_cur = op.reshape(v_cur, (b, s, h_kv, d)) + q, k_cur = self.rotary_embedding(q, k_cur, t - s) + + k, v = self.interleave_kv(k_cur, v_cur, total_seq_len, kv_seq_len, rolling_cache_len) + + if h_kv != h_q: + k = k.repeat(h_q // h_kv, axis=2) + v = v.repeat(h_q // h_kv, axis=2) + q = q.permute_dims([0, 2, 1, 3]) # [b, h, s, d] + k = k.permute_dims([0, 2, 1, 3]) # [b, h, t, d] + v = v.permute_dims([0, 2, 1, 3]) # [b, h, t, d] + attn_weights = op.matmul( + q, k.permute_dims([0, 1, 3, 2]) # [b, h, s, d] x [b, h, d, t] = [b, h, s, t] + ) / math.sqrt(d) + dtype = attn_weights.dtype + attn_weights = attn_weights.maximum(tir.min_value(dtype)).minimum(attention_mask) + if dtype == "float32": + attn_weights = op.softmax(attn_weights, axis=-1) + else: + attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype) + # [b, h, s, t] x [b, h, t, d] => [b, h, s, d] => [b, s, h, d] + output = op.matmul(attn_weights, v) + return self.o_proj(output.permute_dims([0, 2, 1, 3]).reshape((b, s, h_q * d))) + + +class RollingKVCache(nn.KVCache): + """ + Rolling buffer cache implementation. + """ + + cache: Optional[rx.Var] + + def override(self, new_element: Tensor, max_cache_size: int) -> None: + """ + Override cache elements in RollingKVCache. + + Parameters + ---------- + new_element : Tensor + The new tensor to append. + + max_cache_size : int + Max size of the cache. 
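# A NumPy sketch of the "unrotate and concatenate" step performed in `interleave_kv`
# above: token `i` lives at slot `i % sliding_window` of the rolling buffer, so reading
# the cache back in chronological order means splitting it at `cache_offset` before
# appending the current keys/values. Toy 1-D values stand in for [seq, head, dim] tensors.
import numpy as np

def unrotate_concat(cached: np.ndarray, current: np.ndarray, cache_offset: int) -> np.ndarray:
    in_order = np.concatenate([cached[cache_offset:], cached[:cache_offset]], axis=0)
    return np.concatenate([in_order, current], axis=0)

# 7 tokens (0..6) already processed with sliding_window = 4: the buffer holds [4, 5, 6, 3]
# and cache_offset = 7 % 4 = 3; tokens 7 and 8 arrive in the current step.
cached, current = np.array([4, 5, 6, 3]), np.array([7, 8])
assert unrotate_concat(cached, current, cache_offset=3).tolist() == [3, 4, 5, 6, 7, 8]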
+ """ + if new_element.dtype != self.dtype: + raise TypeError( + f'RollingKVCache has been set to use dtype "{self.dtype}", ' + f'but got "{new_element.dtype}"' + ) + self.cache = rx.BlockBuilder.current().emit( + rx.Call( + rx.extern("vm.builtin.attention_kv_cache_window_override"), + args=[ + self.cache, + new_element._expr, # pylint: disable=protected-access + rx.PrimValue(max_cache_size), + ], + sinfo_args=[rx.ObjectStructInfo()], + ) + ) + + +class MistralDecoderLayer(nn.Module): + """Exact same as LlamaDecoderLayer.""" + + def __init__(self, config: MistralConfig, rotary_embedding: RotaryEmbedding): + rms_norm_eps = config.rms_norm_eps + self.self_attn = MistralAttention(config, rotary_embedding) + self.mlp = MistralMLP(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + + def forward( # pylint: disable=too-many-arguments + self, + hidden_states: Tensor, + attention_mask: Tensor, + total_seq_len: tir.Var, + rolling_cache_len: tir.Var, + kv_seq_len: tir.Var, + ): + """Forward pass of a decoder layer; calculate attention, and add an residual connection.""" + hidden_states = ( + self.self_attn( + self.input_layernorm(hidden_states), + attention_mask, + total_seq_len, + rolling_cache_len, + kv_seq_len, + ) + + hidden_states + ) + hidden_states = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states + return hidden_states + + +class MistralModel(nn.Module): + """Exact same as LlamaModel.""" + + def __init__(self, config: MistralConfig): + assert config.hidden_size % config.num_attention_heads == 0 + rotary_embedding = RotaryEmbedding(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [MistralDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + + def forward( # pylint: disable=too-many-arguments + self, + inputs: Tensor, + total_seq_len: tir.Var, + rolling_cache_len: tir.Var, + kv_seq_len: tir.Var, + attention_mask: Tensor, + ): + """Forward pass of the model, passing through all decoder layers.""" + hidden_states = self.embed_tokens(inputs) + for layer in self.layers: + hidden_states = layer( + hidden_states, attention_mask, total_seq_len, rolling_cache_len, kv_seq_len + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MistralForCasualLM(nn.Module): + """Same as LlamaForCausalLM, except for the use of sliding window attention.""" + + def __init__(self, config: MistralConfig): + self.model = MistralModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vocab_size = config.vocab_size + self.sliding_window = config.sliding_window + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def forward( # pylint: disable=too-many-arguments + self, + inputs: Tensor, + total_seq_len: tir.Var, + rolling_cache_len: tir.Var, + kv_seq_len: tir.Var, + attention_mask: Tensor, + ): + """Forward pass.""" + + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.model( + inputs, total_seq_len, rolling_cache_len, kv_seq_len, attention_mask + ) + hidden_states = op.tensor_expr_op(_index, name_hint="index", 
args=[hidden_states]) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def prefill( + self, + inputs: Tensor, + total_seq_len: tir.Var, + rolling_cache_len: tir.Var, + kv_seq_len: tir.Var, + ): + """ + Prefilling the prompt. + + Parameters + ---------- + inputs: Tensor + Input tokens, having ``seq_len`` number of tokens. + + total_seq_len: tir.Var + Number of already-processed tokens plus ``seq_len``. + + rolling_cache_len: tir.Var + Number of elements currently in the cache. + + kv_seq_len: tir.Var + Equals to ``seq_len + rolling_cache_len``. + """ + + def _sliding_window_attention_mask( + batch_size, seq_len, rolling_cache_len, kv_seq_len, sliding_window + ): + # See `tests/legacy-python/test_sliding_window_mask.py` for its behavior + return te.compute( + (batch_size, 1, seq_len, kv_seq_len), + lambda b, _, i, j: tir.Select( + tir.all(i + rolling_cache_len >= j, i + rolling_cache_len - j < sliding_window), + tir.max_value(self.dtype), + tir.min_value(self.dtype), + ), + name="sliding_window_attention_mask_prefill", + ) + + batch_size, seq_len = inputs.shape + attention_mask = op.tensor_expr_op( + _sliding_window_attention_mask, + name_hint="sliding_window_attention_mask_prefill", + args=[ + batch_size, + seq_len, + rolling_cache_len, + kv_seq_len, + self.sliding_window, + ], + ) + return self.forward(inputs, total_seq_len, rolling_cache_len, kv_seq_len, attention_mask) + + def decode( + self, + inputs: Tensor, + total_seq_len: tir.Var, + rolling_cache_len: tir.Var, + kv_seq_len: tir.Var, + ): + """Decoding step.""" + batch_size, seq_len = inputs.shape + attention_mask = op.full( + shape=[batch_size, 1, seq_len, kv_seq_len], + fill_value=tir.max_value(self.dtype), + dtype=self.dtype, + ) + return self.forward(inputs, total_seq_len, rolling_cache_len, kv_seq_len, attention_mask) + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + """Softmax.""" + return op.softmax(logits / temperature, axis=-1) + + def get_default_spec(self): + """Needed for ``export_tvm()``.""" + batch_size = 1 + mod_spec = { + "prefill": { + "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), + "total_seq_len": int, + "rolling_cache_len": int, + "kv_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "decode": { + "inputs": nn.spec.Tensor([batch_size, 1], "int32"), + "total_seq_len": int, + "rolling_cache_len": int, + "kv_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor([], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/compiler/model/mistral/mistral_quantization.py b/python/mlc_chat/compiler/model/mistral/mistral_quantization.py new file mode 100644 index 0000000000..eecff1a63a --- /dev/null +++ b/python/mlc_chat/compiler/model/mistral/mistral_quantization.py @@ -0,0 +1,41 @@ +"""This file specifies how MLC's Mistral parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from ...loader import QuantizeMapping +from ...quantization import AWQQuantize, GroupQuantize +from .mistral_model import MistralConfig, MistralForCasualLM + + +def group_quant( + model_config: MistralConfig, + quantization: GroupQuantize, +) -> 
Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mistral-architecture model using group quantization.""" + model: nn.Module = MistralForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: MistralConfig, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mistral-architecture model using Activation-aware Weight Quantization(AWQ).""" + model: nn.Module = MistralForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map diff --git a/python/mlc_chat/compiler/model/model.py b/python/mlc_chat/compiler/model/model.py index a5e40818de..1323656811 100644 --- a/python/mlc_chat/compiler/model/model.py +++ b/python/mlc_chat/compiler/model/model.py @@ -6,7 +6,8 @@ from ..loader import ExternMapping, QuantizeMapping from ..quantization.quantization import Quantization -from . import llama_loader, llama_model, llama_quantization +from .llama import llama_loader, llama_model, llama_quantization +from .mistral import mistral_loader, mistral_model, mistral_quantization ModelConfig = Any """A ModelConfig is an object that represents a model architecture. It is required to have @@ -61,9 +62,24 @@ class Model: "awq": llama_loader.awq, }, quantize={ + "no-quant": llama_quantization.no_quant, "group-quant": llama_quantization.group_quant, + "awq": llama_quantization.awq_quant, }, - ) + ), + "mistral": Model( + name="mistral", + model=mistral_model.MistralForCasualLM, + config=mistral_model.MistralConfig, + source={ + "huggingface-torch": mistral_loader.huggingface, + "huggingface-safetensor": mistral_loader.huggingface, + "awq": mistral_loader.awq, + }, + quantize={ + "group-quant": mistral_quantization.group_quant, + }, + ), } MODEL_PRESETS: Dict[str, Any] = { @@ -76,6 +92,7 @@ class Model: "initializer_range": 0.02, "intermediate_size": 11008, "max_position_embeddings": 2048, + "context_window_size": 4096, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, @@ -100,6 +117,7 @@ class Model: "initializer_range": 0.02, "intermediate_size": 13824, "max_position_embeddings": 2048, + "context_window_size": 4096, "model_type": "llama", "num_attention_heads": 40, "num_hidden_layers": 40, @@ -123,6 +141,7 @@ class Model: "initializer_range": 0.02, "intermediate_size": 28672, "max_position_embeddings": 2048, + "context_window_size": 4096, "model_type": "llama", "num_attention_heads": 64, "num_hidden_layers": 80, @@ -205,4 +224,26 @@ class Model: "use_cache": True, "vocab_size": 32016, }, + "mistral_7b": { + "architectures": ["MistralForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "mistral", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 10000.0, + "sliding_window": 4096, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.34.0.dev0", + "use_cache": True, + "vocab_size": 32000, + }, } diff --git a/python/mlc_chat/compiler/quantization/__init__.py b/python/mlc_chat/compiler/quantization/__init__.py index 74950df832..7a449eb96c 100644 --- 
a/python/mlc_chat/compiler/quantization/__init__.py +++ b/python/mlc_chat/compiler/quantization/__init__.py @@ -1,4 +1,5 @@ """A subpackage for quantization and dequantization algorithms""" from .awq_quantization import AWQQuantize from .group_quantization import GroupQuantize +from .no_quantization import NoQuantize from .quantization import QUANTIZATION, Quantization diff --git a/python/mlc_chat/compiler/quantization/awq_quantization.py b/python/mlc_chat/compiler/quantization/awq_quantization.py index 944ded0ba0..29a2b165be 100644 --- a/python/mlc_chat/compiler/quantization/awq_quantization.py +++ b/python/mlc_chat/compiler/quantization/awq_quantization.py @@ -159,7 +159,7 @@ def _dequantize( tir.subtract(float_weight[i, j], float_zeros[i, j // self.group_size]), scale[i, j // self.group_size], ), - name="decode", + name="dequantize", ) @@ -250,7 +250,7 @@ def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name scale, [tir.IntImm("int64", self.out_features), tir.IntImm("int64", self.in_features)], ), - name_hint="decode", + name_hint="dequantize", args=[self.qweight, self.qzeros, self.scales], ) w = nn.op.permute_dims(w) # pylint: disable=invalid-name @@ -356,7 +356,7 @@ def forward(self, x: nn.Tensor) -> Sequence[nn.Tensor]: # pylint: disable=inval tir.IntImm("int64", self.in_features), ], ), - name_hint="decode", + name_hint="dequantize", args=[self.qweight, self.qzeros, self.scales], ) w = nn.op.permute_dims(w) # pylint: disable=invalid-name diff --git a/python/mlc_chat/compiler/quantization/group_quantization.py b/python/mlc_chat/compiler/quantization/group_quantization.py index ecce18d3c0..9fffbf10fd 100644 --- a/python/mlc_chat/compiler/quantization/group_quantization.py +++ b/python/mlc_chat/compiler/quantization/group_quantization.py @@ -283,13 +283,11 @@ def __init__( # pylint: disable=too-many-arguments self.out_features = out_features self.out_dtype = out_dtype self.config = config + num_group = tir.ceildiv(in_features, config.group_size) self.q_weight = nn.Parameter( - (out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), - config.storage_dtype, - ) - self.q_scale = nn.Parameter( - (out_features, tir.ceildiv(in_features, config.group_size)), config.model_dtype + (out_features, config.num_storage_per_group * num_group), config.storage_dtype ) + self.q_scale = nn.Parameter((out_features, num_group), config.model_dtype) if bias: self.bias = nn.Parameter((out_features,), config.model_dtype) else: @@ -370,14 +368,12 @@ def __init__( # pylint: disable=too-many-arguments self.out_features = out_features self.out_dtype = out_dtype self.config = config + num_group = tir.ceildiv(in_features, config.group_size) self.q_weight = nn.Parameter( - (self.total_out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), + (self.total_out_features, config.num_storage_per_group * num_group), config.storage_dtype, ) - self.q_scale = nn.Parameter( - (self.total_out_features, tir.ceildiv(in_features, config.group_size)), - config.model_dtype, - ) + self.q_scale = nn.Parameter((self.total_out_features, num_group), config.model_dtype) if bias: self.bias = nn.Parameter((self.total_out_features,), config.model_dtype) else: @@ -456,14 +452,11 @@ def __init__(self, num: int, dim: int, config: GroupQuantize): self.num = num self.dim = dim self.config = config + num_group = tir.ceildiv(dim, config.group_size) self.q_weight = nn.Parameter( - (num, tir.ceildiv(dim, config.num_elem_per_storage)), - config.storage_dtype, - ) - self.q_scale = nn.Parameter( - (num, 
tir.ceildiv(dim, config.group_size)), - config.model_dtype, + (num, config.num_storage_per_group * num_group), config.storage_dtype ) + self.q_scale = nn.Parameter((num, num_group), config.model_dtype) @staticmethod def from_embedding(embedding: nn.Embedding, config: GroupQuantize) -> "GroupQuantizeEmbedding": diff --git a/python/mlc_chat/compiler/quantization/no_quantization.py b/python/mlc_chat/compiler/quantization/no_quantization.py new file mode 100644 index 0000000000..a1e4a436aa --- /dev/null +++ b/python/mlc_chat/compiler/quantization/no_quantization.py @@ -0,0 +1,17 @@ +"""The no quantization config""" +import logging +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class NoQuantize: # pylint: disable=too-many-instance-attributes + """Configuration for no quantization""" + + name: str + kind: str + model_dtype: str # "float16", "float32" + + def __post_init__(self): + assert self.kind == "no-quant" diff --git a/python/mlc_chat/compiler/quantization/quantization.py b/python/mlc_chat/compiler/quantization/quantization.py index bae8d07aec..a4d915fc7d 100644 --- a/python/mlc_chat/compiler/quantization/quantization.py +++ b/python/mlc_chat/compiler/quantization/quantization.py @@ -3,6 +3,7 @@ from .awq_quantization import AWQQuantize from .group_quantization import GroupQuantize +from .no_quantization import NoQuantize Quantization = Any """Quantization is an object that represents an quantization algorithm. It is required to @@ -24,6 +25,16 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr """ QUANTIZATION: Dict[str, Quantization] = { + "q0f16": NoQuantize( + name="q0f16", + kind="no-quant", + model_dtype="float16", + ), + "q0f32": NoQuantize( + name="q0f32", + kind="no-quant", + model_dtype="float32", + ), "q3f16_1": GroupQuantize( name="q3f16_1", kind="group-quant", diff --git a/python/mlc_chat/support/argparse.py b/python/mlc_chat/support/argparse.py new file mode 100644 index 0000000000..81211e8e07 --- /dev/null +++ b/python/mlc_chat/support/argparse.py @@ -0,0 +1,15 @@ +"""An enhanced argument parser for mlc-chat.""" +import argparse +import sys + + +class ArgumentParser(argparse.ArgumentParser): + """An enhanced argument parser for mlc-chat.""" + + def error(self, message): + """Overrides the behavior when erroring out""" + print("-" * 25 + " Usage " + "-" * 25) + self.print_help() + print("-" * 25 + " Error " + "-" * 25) + print(message, file=sys.stderr) + sys.exit(2) diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_chat/support/auto_config.py index 708b675513..d266403355 100644 --- a/python/mlc_chat/support/auto_config.py +++ b/python/mlc_chat/support/auto_config.py @@ -36,7 +36,8 @@ def detect_config(config: Union[str, Path]) -> Path: if isinstance(config, str) and config in MODEL_PRESETS: logger.info("%s preset model: %s", FOUND, config) - content = MODEL_PRESETS[config] + content = MODEL_PRESETS[config].copy() + content["model_preset_tag"] = config temp_file = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with suffix=".json", delete=False, diff --git a/python/mlc_chat/support/auto_device.py b/python/mlc_chat/support/auto_device.py new file mode 100644 index 0000000000..f61e64430f --- /dev/null +++ b/python/mlc_chat/support/auto_device.py @@ -0,0 +1,41 @@ +"""Automatic detection of the device available on the local machine.""" +import logging + +import tvm +from tvm.runtime import Device + +from .style import bold, green, red + +FOUND = green("Found") +NOT_FOUND = 
red("Not found") +AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan", "opencl"] + + +logger = logging.getLogger(__name__) + + +def detect_device(device_hint: str) -> Device: + """Detect locally available device from string hint.""" + if device_hint == "auto": + device = None + for device_type in AUTO_DETECT_DEVICES: + cur_device = tvm.device(dev_type=device_type, dev_id=0) + if cur_device.exist: + logger.info("%s device: %s:0", FOUND, device_type) + if device is None: + device = cur_device + else: + logger.info("%s device: %s:0", NOT_FOUND, device_type) + if device is None: + logger.info("%s: No available device detected. Falling back to CPU", NOT_FOUND) + return tvm.device("cpu:0") + device_str = f"{tvm.runtime.Device.MASK2STR[device.device_type]}:{device.device_id}" + logger.info("Using device: %s. Use `--device` to override.", bold(device_str)) + return device + try: + device = tvm.device(device_hint) + except Exception as err: + raise ValueError(f"Invalid device name: {device_hint}") from err + if not device.exist: + raise ValueError(f"Device is not found on your local environment: {device_hint}") + return device diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_chat/support/auto_target.py index 4c25380e8f..00b291175d 100644 --- a/python/mlc_chat/support/auto_target.py +++ b/python/mlc_chat/support/auto_target.py @@ -3,13 +3,12 @@ import os from typing import TYPE_CHECKING, Callable, Optional, Tuple -import tvm from tvm import IRModule, relax from tvm._ffi import get_global_func, register_func from tvm.contrib import tar, xcode -from tvm.runtime import Device from tvm.target import Target +from .auto_device import AUTO_DETECT_DEVICES from .style import bold, green, red if TYPE_CHECKING: @@ -19,7 +18,6 @@ logger = logging.getLogger(__name__) # TODO: add help message on how to specify the target manually # pylint: disable=fixme -# TODO: include host detection logic below after the new TVM build is done. # pylint: disable=fixme HELP_MSG = """TBD""" FOUND = green("Found") NOT_FOUND = red("Not found") @@ -46,33 +44,6 @@ def detect_target_and_host(target_hint: str, host_hint: str = "auto") -> Tuple[T return target, build_func -def detect_device(device_hint: str) -> Device: - """Detect locally available device from string hint.""" - if device_hint == "auto": - device = None - for device_type in AUTO_DETECT_DEVICES: - cur_device = tvm.device(dev_type=device_type, dev_id=0) - if cur_device.exist: - logger.info("%s device: %s:0", FOUND, device_type) - if device is None: - device = cur_device - else: - logger.info("%s device: %s:0", NOT_FOUND, device_type) - if device is None: - logger.info("%s: No available device detected. Falling back to CPU", NOT_FOUND) - return tvm.device("cpu:0") - device_str = f"{tvm.runtime.Device.MASK2STR[device.device_type]}:{device.device_id}" - logger.info("Using device: %s. 
Use `--device` to override.", bold(device_str)) - return device - try: - device = tvm.device(device_hint) - except Exception as err: - raise ValueError(f"Invalid device name: {device_hint}") from err - if not device.exist: - raise ValueError(f"Device is not found on your local environment: {device_hint}") - return device - - def _detect_target_gpu(hint: str) -> Tuple[Target, BuildFunc]: if hint in ["iphone", "android", "webgpu", "mali", "opencl"]: hint += ":generic" @@ -115,7 +86,10 @@ def _detect_target_host(hint: str) -> Target: """Detect the host CPU architecture.""" if hint == "auto": target_triple = get_global_func("tvm.codegen.llvm.GetDefaultTargetTriple")() - logger.info("%s host CPU architecture: %s", FOUND, bold(target_triple)) + logger.info("%s host LLVM triple: %s", FOUND, bold(target_triple)) + else: + target_triple = hint + logger.info("Using LLVM triple specified by --host: %s", bold(target_triple)) return Target({"kind": "llvm", "mtriple": target_triple}) @@ -285,8 +259,6 @@ def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument return ptx -AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan"] - PRESET = { "iphone:generic": { "target": { diff --git a/python/mlc_chat/support/auto_weight.py b/python/mlc_chat/support/auto_weight.py index 959e795169..9d446e4d33 100644 --- a/python/mlc_chat/support/auto_weight.py +++ b/python/mlc_chat/support/auto_weight.py @@ -84,6 +84,8 @@ def detect_weight( weight_config_path = check_func(weight_path) if not weight_config_path: raise ValueError(f"The weight is not in {weight_format} format.") + else: + weight_config_path = weight_path return weight_config_path, weight_format @@ -143,5 +145,5 @@ def _check_safetensor(weight_path: Path) -> Optional[Path]: "huggingface-safetensor": _check_safetensor, } -# "awq", "ggml", "gguf" are not supported yet. -AVAILABLE_WEIGHT_FORMAT = ["huggingface-torch", "huggingface-safetensor"] +# "ggml", "gguf" are not supported yet. +AVAILABLE_WEIGHT_FORMAT = ["huggingface-torch", "huggingface-safetensor", "awq"] diff --git a/python/mlc_chat/support/download.py b/python/mlc_chat/support/download.py new file mode 100644 index 0000000000..9130774d98 --- /dev/null +++ b/python/mlc_chat/support/download.py @@ -0,0 +1,170 @@ +"""Common utilities for downloading files from HuggingFace or other URLs online.""" +import concurrent.futures as cf +import hashlib +import json +import logging +import os +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Optional, Tuple + +import requests # pylint: disable=import-error + +from . import tqdm + +logger = logging.getLogger(__name__) + +MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None) + + +def get_cache_dir() -> Path: + """Return the path to the cache directory.""" + + if os.getenv("MLC_CACHE_DIR"): + result = Path(os.getenv("MLC_CACHE_DIR")) + elif sys.platform == "win32": + result = Path(os.environ["LOCALAPPDATA"]) + result = result / "mlc_chat" + elif os.getenv("XDG_CACHE_HOME", None) is not None: + result = Path(os.getenv("XDG_CACHE_HOME")) + result = result / "mlc_chat" + else: + result = Path(os.path.expanduser("~/.cache")) + result = result / "mlc_chat" + result.mkdir(parents=True, exist_ok=True) + if not result.is_dir(): + raise ValueError( + f"The default cache directory is not a directory: {result}. " + "Use environment variable MLC_CACHE_DIR to specify a valid cache directory." 
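# A hedged usage sketch of the `get_cache_dir` helper in this diff, assuming a POSIX
# system: MLC_CACHE_DIR is honored as-is when set; otherwise the platform default is used
# (%LOCALAPPDATA% on Windows, $XDG_CACHE_HOME if set, else ~/.cache), each suffixed with
# `mlc_chat`. The call creates the directory if it does not exist.
import os
from pathlib import Path

from mlc_chat.support.download import get_cache_dir  # module added in this diff

os.environ["MLC_CACHE_DIR"] = "/tmp/mlc-cache-demo"  # illustrative location
assert get_cache_dir() == Path("/tmp/mlc-cache-demo")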
+ ) + return result + + +def _ensure_directory_not_exist(path: Path, force_redo: bool) -> None: + if path.exists(): + if force_redo: + logger.info("Deleting existing directory: %s", path) + shutil.rmtree(path) + else: + raise ValueError(f"Directory already exists: {path}") + else: + path.parent.mkdir(parents=True, exist_ok=True) + + +def git_clone(url: str, destination: Path, ignore_lfs: bool) -> None: + """Clone a git repository into a directory.""" + repo_name = ".tmp" + command = ["git", "clone", url, repo_name] + _ensure_directory_not_exist(destination, force_redo=False) + try: + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir: + logger.info("[Git] Cloning %s to %s", url, destination) + subprocess.run( + command, + env={"GIT_LFS_SKIP_SMUDGE": "1"}, + cwd=tmp_dir, + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + git_dir = os.path.join(tmp_dir, repo_name) + if not ignore_lfs: + git_lfs_pull(Path(git_dir)) + shutil.move(git_dir, str(destination)) + except subprocess.CalledProcessError as error: + raise ValueError( + f"Git clone failed with return code {error.returncode}: {error.stderr}. " + f"The command was: {command}" + ) from error + + +def git_lfs_pull(repo_dir: Path) -> None: + """Pull files with Git LFS.""" + filenames = ( + subprocess.check_output( + ["git", "-C", str(repo_dir), "lfs", "ls-files", "-n"], + stderr=subprocess.STDOUT, + ) + .decode("utf-8") + .splitlines() + ) + logger.info("[Git LFS] Downloading %d files with Git LFS: %s", len(filenames), filenames) + with tqdm.redirect(): + for file in tqdm.tqdm(filenames): + logger.info("[Git LFS] Downloading %s", file) + subprocess.check_output( + ["git", "-C", str(repo_dir), "lfs", "pull", "--include", file], + stderr=subprocess.STDOUT, + ) + + +def download_file( + url: str, + destination: Path, + md5sum: Optional[str], +) -> Tuple[str, Path]: + """Download a file from a URL to a destination file.""" + with requests.get(url, stream=True, timeout=30) as response: + response.raise_for_status() + with destination.open("wb") as file: + for chunk in response.iter_content(chunk_size=8192): + file.write(chunk) + if md5sum is not None: + hash_md5 = hashlib.md5() + with destination.open("rb") as file: + for chunk in iter(lambda: file.read(8192), b""): + hash_md5.update(chunk) + file_md5 = hash_md5.hexdigest() + if file_md5 != md5sum: + raise ValueError( + f"MD5 checksum mismatch for downloaded file: {destination}. 
" + f"Expected {md5sum}, got {file_md5}" + ) + return url, destination + + +def download_mlc_weights( # pylint: disable=too-many-locals + model_url: str, + num_processes: int = 4, + force_redo: bool = False, +) -> None: + """Download weights for a model from the HuggingFace Git LFS repo.""" + mlc_prefix = "HF://" + git_url_template = "https://huggingface.co/{user}/{repo}.git" + bin_url_template = "https://huggingface.co/{user}/{repo}/resolve/main/{record_name}" + + if model_url.count("/") != 1 + mlc_prefix.count("/") or not model_url.startswith(mlc_prefix): + raise ValueError(f"Invalid model URL: {model_url}") + assert model_url.startswith(mlc_prefix) + user, repo = model_url[len(mlc_prefix) :].split("/") + git_dir = get_cache_dir() / "model_weights" / repo + try: + _ensure_directory_not_exist(git_dir, force_redo=force_redo) + except ValueError: + logger.info("Weights already downloaded: %s", git_dir) + return + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir_prefix: + tmp_dir = Path(tmp_dir_prefix) / "tmp" + git_url = git_url_template.format(user=user, repo=repo) + git_clone(git_url, tmp_dir, ignore_lfs=True) + shutil.rmtree(tmp_dir / ".git", ignore_errors=True) + with (tmp_dir / "ndarray-cache.json").open(encoding="utf-8") as in_file: + param_metadata = json.load(in_file)["records"] + with cf.ProcessPoolExecutor(max_workers=num_processes) as executor: + futures = [] + for record in param_metadata: + record_name = record["dataPath"] + file_url = bin_url_template.format(user=user, repo=repo, record_name=record_name) + file_dest = tmp_dir / record_name + file_md5 = record.get("md5sum", None) + futures.append(executor.submit(download_file, file_url, file_dest, file_md5)) + with tqdm.redirect(): + for future in tqdm.tqdm(cf.as_completed(futures), total=len(futures)): + file_url, file_dest = future.result() + logger.info("Downloaded %s to %s", file_url, file_dest) + logger.info("Moving %s to %s", tmp_dir, git_dir) + shutil.move(str(tmp_dir), str(git_dir)) + shutil.move(str(tmp_dir), str(git_dir)) diff --git a/python/setup.py b/python/setup.py index af471c19c0..e701a08e5d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -89,6 +89,11 @@ def main(): keywords="machine learning", zip_safe=False, packages=find_packages(), + entry_points={ + "console_scripts": [ + "mlc_chat = mlc_chat.__main__:main", + ], + }, package_dir={"mlc_chat": "mlc_chat"}, install_requires=["fastapi", "uvicorn", "shortuuid"], distclass=BinaryDistribution, diff --git a/rust/examples/mlc_chat.rs b/rust/examples/mlc_chat.rs index 2e87d56946..fa7132b052 100644 --- a/rust/examples/mlc_chat.rs +++ b/rust/examples/mlc_chat.rs @@ -1,10 +1,39 @@ extern crate mlc_llm; -use mlc_llm::chat_module::ChatModule; +use mlc_llm::chat_module::{ChatMessage, ChatModule}; fn main() { + // Single prompt example let cm = ChatModule::new("/path/to/Llama2-13B-q8f16_1", "rocm", None).unwrap(); let output = cm.generate("what is the meaning of life?", None).unwrap(); println!("resp: {:?}", output); println!("stats: {:?}", cm.stats(false)); + + // Multiple prompts example + let message1 = ChatMessage { + role: "user".to_owned(), + content: "suppose we already have projects llama, alpaca and vicuna, what do you think would be a great name for the next project?".to_string(), + }; + let message2 = ChatMessage { + role: "assistant".to_owned(), + content: "based on the previous projects, a possible name for the next project could be \"cervidae\" which is the scientific name for deer family. 
this name reflects the collaboration and teamwork involved in the development of the project, and also nods to the previous projects that have been developed by the team.".to_string(), + }; + let message3 = ChatMessage { + role: "user".to_owned(), + content: "I like cervidae, but the name is too long!".to_string(), + }; + let message4 = ChatMessage { + role: "assistant".to_owned(), + content: "In that case, a shorter and catchier name for the next project could be \"DeerRun\" which plays on the idea of the project being fast and efficient, just like a deer running through the woods. This name is memorable and easy to pronounce, making it a good choice for a project name.".to_string(), + }; + let message5 = ChatMessage { + role: "user".to_owned(), + content: "Summarize our conversations.".to_string(), + }; + + let messages = vec![message1, message2, message3, message4, message5]; + + let output = cm.generate(messages, None).unwrap(); + println!("resp: {:?}", output); + println!("stats: {:?}", cm.stats(false)); } diff --git a/rust/src/chat_module.rs b/rust/src/chat_module.rs index 831905eee8..e1882a3fd6 100644 --- a/rust/src/chat_module.rs +++ b/rust/src/chat_module.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::fs; use std::path::{Path, PathBuf}; use std::result; @@ -22,27 +23,34 @@ impl From for ChatModuleError { pub type Result = result::Result; -/// The ChatModule for MLC LLM. -/// -/// # Examples -/// -/// ``` -/// use mlc_llm::chat_module::ChatModule; -/// -/// // Create a ChatModule instance -/// let cm = ChatModule::new("Llama-2-7b-chat-hf-q4f16_1", "cuda", None, None).unwrap(); -/// -/// // Generate a response for a given prompt -/// let output = cm.generate("What is the meaning of life?", None).unwrap(); -/// -/// // Print prefill and decode performance statistics -/// println!("Statistics: {:?}\n", cm.stats(false).unwrap()); -/// -/// let output = cm.generate("What is Rust?", None).unwrap(); -/// ``` -pub struct ChatModule { - chat_module: Module, - chat_config: ChatConfig, +#[derive(Debug, Clone)] +pub struct ChatMessage { + pub role: String, + pub content: String, +} + +#[derive(Debug, Clone)] +pub enum Prompt { + String(String), + MessageList(Vec), +} + +impl From<&str> for Prompt { + fn from(s: &str) -> Self { + Prompt::String(s.to_owned()) + } +} + +impl From for Prompt { + fn from(s: String) -> Self { + Prompt::String(s) + } +} + +impl From> for Prompt { + fn from(messages: Vec) -> Self { + Prompt::MessageList(messages) + } } #[derive(Debug, Copy, Clone)] @@ -265,6 +273,29 @@ fn get_lib_module_path( } } +/// The ChatModule for MLC LLM. +/// +/// # Examples +/// +/// ``` +/// use mlc_llm::chat_module::ChatModule; +/// +/// // Create a ChatModule instance +/// let cm = ChatModule::new("Llama-2-7b-chat-hf-q4f16_1", "cuda", None, None).unwrap(); +/// +/// // Generate a response for a given prompt +/// let output = cm.generate("what is the meaning of life?", None).unwrap(); +/// +/// // Print prefill and decode performance statistics +/// println!("Statistics: {:?}\n", cm.stats(false).unwrap()); +/// +/// let output = cm.generate("what is Rust?", None).unwrap(); +/// ``` +pub struct ChatModule { + chat_module: Module, + chat_config: ChatConfig, +} + impl ChatModule { pub fn new(model: &str, device: &str, model_lib_path: Option<&str>) -> Result { let device_err_msg = format!( @@ -278,7 +309,7 @@ impl ChatModule { // 1. 
Get device name and id let device_type = match device_name { - "cude" => 2, + "cuda" => 2, "opencl" => 4, "vulkan" => 7, "metal" => 8, @@ -312,7 +343,7 @@ impl ChatModule { let chat_mod = Self { chat_module: m, - chat_config: chat_config, + chat_config, }; let model_lib_str = model_lib_path.as_path().display().to_string(); let model_path_str = model_path.as_path().display().to_string(); @@ -350,7 +381,7 @@ impl ChatModule { } let f = self.chat_module.get_function("runtime_stats_text", false)?; let res: String = f.invoke(vec![])?.try_into().expect("call should succeed"); - return Ok(res); + Ok(res) } /// Check if the stop condition is met for the current round. @@ -382,12 +413,40 @@ impl ChatModule { Ok(()) } + /// Load JSON config and override existing configurations for the chat module. + fn load_json_override(&self, config_str: &str, partial_update: bool) -> Result<()> { + let f = self.chat_module.get_function("load_json_override", false)?; + f.invoke(vec![config_str.into(), (&partial_update).into()])?; + Ok(()) + } + + /// Get the configuration of the chat module in a single json string. + fn get_config_json(&self) -> Result { + let f = self.chat_module.get_function("get_config_json", false)?; + let res: String = f.invoke(vec![])?.try_into().expect("call should succeed"); + Ok(res) + } + + /// Get the name of role 0 in the conversation. + fn get_role_0(&self) -> Result { + let f = self.chat_module.get_function("get_role0", false)?; + let res: String = f.invoke(vec![])?.try_into().expect("call should succeed"); + Ok(res) + } + + /// Get the name of role 0 in the conversation. + fn get_role_1(&self) -> Result { + let f = self.chat_module.get_function("get_role1", false)?; + let res: String = f.invoke(vec![])?.try_into().expect("call should succeed"); + Ok(res) + } + /// A high-level method that returns the full response from the chat module given a user /// prompt. User can optionally specify which callback method to use upon receiving the /// response. pub fn generate( &self, - prompt: &str, + prompt: impl Into, generation_config: Option<&GenerationConfig>, ) -> Result> { // TODO: add progress_callback @@ -400,9 +459,10 @@ impl ChatModule { } } + let prompt = prompt.into(); for _ in 0..num_return_sequences { self.reset_chat().unwrap(); - self.prefill(prompt, true, PlaceInPrompt::All, generation_config) + self.prefill(&prompt, true, PlaceInPrompt::All, generation_config) .unwrap(); while !self.stopped().unwrap() { @@ -419,7 +479,7 @@ impl ChatModule { /// User can decide where to place the input in the prompt. 
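+    ///
+    /// A `Prompt::String` is passed to the runtime unchanged. A `Prompt::MessageList` with a
+    /// single message is reduced to that message's content; for a longer list, every message
+    /// except the last is installed as conversation history via a `conv_config` JSON override,
+    /// and the final message (which must come from the user) becomes the prefill input.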
fn prefill( &self, - input: &str, + input: &Prompt, decode_next_token: bool, place_in_promt: PlaceInPrompt, generation_config: Option<&GenerationConfig>, @@ -432,9 +492,54 @@ impl ChatModule { } }; + let input_string = match input { + Prompt::String(inp) => inp.clone(), + Prompt::MessageList(chat_msgs) => { + let mut chat_msgs = chat_msgs.clone(); + if chat_msgs.len() == 1 { + chat_msgs.remove(0).content + } else { + let chat_config = ChatConfig::from_json(&(self.get_config_json()?)).unwrap(); + let mut conv_config = chat_config + .conv_config + .unwrap_or_else(|| ConvConfigBuilder::default().build().unwrap()); + + let role0 = self.get_role_0()?; + let role1 = self.get_role_1()?; + + let last_msg = chat_msgs + .last() + .expect("No last message in the vector") + .clone(); + if last_msg.role != "user" { + panic!("Last message should be from user."); + } + + let mut messages = Vec::new(); + let msg_len = chat_msgs.len(); + for msg in chat_msgs.into_iter().take(msg_len - 1) { + match msg.role.as_str() { + "user" => messages.push(vec![role0.clone(), msg.content]), + "assistant" => messages.push(vec![role1.clone(), msg.content]), + _ => panic!("Only user and assistant roles are supported."), + } + } + + conv_config.messages = Some(messages); + conv_config.offset = Some(0); + + let mut map = HashMap::new(); + map.insert("conv_config", conv_config); + self.load_json_override(&serde_json::to_string(&map).unwrap(), true)?; + + last_msg.content + } + } + }; + let f = self.chat_module.get_function("prefill", false)?; f.invoke(vec![ - input.into(), + input_string.into(), (&decode_next_token).into(), place_in_promt.to_value().into(), generation_config_str.into(), @@ -442,4 +547,3 @@ impl ChatModule { Ok(()) } } - diff --git a/rust/src/config.rs b/rust/src/config.rs index 61371d197a..52cea1746e 100644 --- a/rust/src/config.rs +++ b/rust/src/config.rs @@ -11,6 +11,9 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Default, Builder, Debug, Serialize, Deserialize)] #[builder(default)] pub struct ConvConfig { + /// Token list prefixing the conversation. + prefix_tokens: Option>, + /// Name of the conversation. name: Option, @@ -21,10 +24,10 @@ pub struct ConvConfig { roles: Option>, /// The chat history represented as an array of string pairs. - messages: Option>>, + pub messages: Option>>, /// The offset used to begin the chat from the chat history. - offset: Option, + pub offset: Option, /// Specifies whether we are in chat-bot mode (`0`) or pure LM prompt mode (`1`). separator_style: Option, @@ -102,7 +105,7 @@ pub struct ChatConfig { tokenizer_files: Option>, /// Partial overriding configuration for conversation template. - conv_config: Option, + pub conv_config: Option, /// The category of the model's architecture (e.g. `llama`, `gpt_neox`, `rwkv`). 
model_category: Option, diff --git a/tests/legacy-python/test_build_model_from_args.py b/tests/legacy-python/test_build_model_from_args.py index c7990d63df..b342e035bb 100644 --- a/tests/legacy-python/test_build_model_from_args.py +++ b/tests/legacy-python/test_build_model_from_args.py @@ -27,7 +27,7 @@ def setUp(self): self.mock_args.sep_embed = False self.mock_args.build_model_only = True self.mock_args.use_safetensors = False - self.mock_args.convert_weight_only = False + self.mock_args.convert_weights_only = False self.mock_args.no_cutlass_attn = True self.mock_args.no_cutlass_norm = True self.mock_args.reuse_lib = True diff --git a/tests/python/api/test_python.py b/tests/python/api/test_python.py new file mode 100644 index 0000000000..ceba066a13 --- /dev/null +++ b/tests/python/api/test_python.py @@ -0,0 +1,45 @@ +# pylint: disable=missing-docstring +import pytest + +from mlc_chat import ChatModule, GenerationConfig +from mlc_chat.callback import StreamToStdout + +MODELS = ["Llama-2-7b-chat-hf-q4f16_1"] + + +@pytest.mark.parametrize("model", MODELS) +def test_chat_module_creation_and_generate(model: str): + chat_module = ChatModule(model=model) + _ = chat_module.generate( + prompt="How to make a cake?", + ) + print(f"Statistics: {chat_module.stats()}\n") + + +@pytest.mark.parametrize("model", MODELS) +def test_chat_module_creation_and_generate_with_stream(model: str): + chat_module = ChatModule(model=model) + _ = chat_module.generate( + prompt="How to make a cake?", + progress_callback=StreamToStdout(callback_interval=2), + ) + print(f"Statistics: {chat_module.stats()}\n") + + +@pytest.mark.parametrize( + "generation_config", + [ + GenerationConfig(temperature=0.7, presence_penalty=0.1, frequency_penalty=0.5, top_p=0.9), + GenerationConfig(stop=["cake", "make"], n=3), + GenerationConfig(max_gen_len=40, repetition_penalty=1.2), + ], +) +@pytest.mark.parametrize("model", MODELS) +def test_chat_module_generation_config(generation_config: GenerationConfig, model: str): + chat_module = ChatModule(model=model) + output = chat_module.generate( + prompt="How to make a cake?", + generation_config=generation_config, + ) + print(output) + print(f"Statistics: {chat_module.stats()}\n") diff --git a/tests/python/api/test_rest.py b/tests/python/api/test_rest.py new file mode 100644 index 0000000000..de6e2bb793 --- /dev/null +++ b/tests/python/api/test_rest.py @@ -0,0 +1,71 @@ +# pylint: disable=missing-docstring +import json +import os +import signal +import subprocess +import time + +import pytest +import requests + +MODELS = ["Llama-2-7b-chat-hf-q4f16_1"] + + +@pytest.fixture +def run_rest_server(model): + cmd = f"python -m mlc_chat.rest --model {model}" + print(cmd) + os.environ["PYTHONPATH"] = "./python" + with subprocess.Popen(cmd.split()) as server_proc: + # wait for server to start + while True: + try: + _ = requests.get("http://localhost:8000/stats", timeout=5) + break + except requests.exceptions.ConnectionError: + time.sleep(1) + yield + server_proc.send_signal(signal.SIGINT) + server_proc.wait() + + +@pytest.mark.usefixtures("run_rest_server") +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.parametrize("model", MODELS) +def test_rest_api(model, stream): + payload = { + "model": model, + "messages": [ + { + "role": "user", + "content": "Hello, I am Bob", + }, + { + "role": "assistant", + "content": "Hello, I am a chatbot.", + }, + { + "role": "user", + "content": "What is my name?", + }, + ], + "stream": stream, + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + 
"temperature": 1.0, + "top_p": 0.95, + } + if stream: + with requests.post( + "http://127.0.0.1:8000/v1/chat/completions", json=payload, stream=True, timeout=120 + ) as model_response: + print("With streaming:") + for chunk in model_response: + content = json.loads(chunk[6:-2])["choices"][0]["delta"].get("content", "") + print(f"{content}", end="", flush=True) + print("\n") + else: + model_response = requests.post( + "http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=120 + ) + print(f"\n{model_response.json()['choices'][0]['message']['content']}\n") diff --git a/tests/python/model/test_llama_quantization.py b/tests/python/model/test_llama_quantization.py index 5bf3b2dd08..79f8074560 100644 --- a/tests/python/model/test_llama_quantization.py +++ b/tests/python/model/test_llama_quantization.py @@ -10,12 +10,12 @@ @pytest.mark.parametrize( - "model_name, quant_name", - [ - ("llama2_7b", "q4f16_1"), - ("llama2_13b", "q4f16_1"), - ("llama2_70b", "q4f16_1"), - ], + "model_name", + ["llama2_7b", "llama2_13b", "llama2_70b"], +) +@pytest.mark.parametrize( + "quant_name", + ["q3f16_1", "q4f16_1", "q4f32_1"], ) def test_llama2_group_quantization(model_name: str, quant_name: str): model_info = MODELS["llama"] @@ -51,6 +51,22 @@ def test_llama2_group_quantization(model_name: str, quant_name: str): ) +@pytest.mark.parametrize( + "model_name", + ["llama2_7b", "llama2_13b", "llama2_70b"], +) +@pytest.mark.parametrize( + "quant_name", + ["q0f16", "q0f32"], +) +def test_llama2_no_quantization(model_name: str, quant_name: str): + model_info = MODELS["llama"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + _, quant_map = model_info.quantize["no-quant"](config, QUANTIZATION[quant_name]) + assert len(quant_map.param_map) == 0 + assert len(quant_map.map_func) == 0 + + if __name__ == "__main__": test_llama2_group_quantization("llama2_7b", "q4f16_1") test_llama2_group_quantization("llama2_13b", "q4f16_1") diff --git a/tests/python/model/test_mistral.py b/tests/python/model/test_mistral.py new file mode 100644 index 0000000000..6c6361b1c8 --- /dev/null +++ b/tests/python/model/test_mistral.py @@ -0,0 +1,21 @@ +# pylint: disable=invalid-name,missing-docstring +import pytest + +from mlc_chat.compiler import MODEL_PRESETS, MODELS + + +@pytest.mark.parametrize("model_name", ["mistral_7b"]) +def test_llama2_creation(model_name: str): + model_info = MODELS["mistral"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + model = model_info.model(config) + mod, named_params = model.export_tvm( + spec=model.get_default_spec(), # type: ignore + ) + mod.show(black_format=False) + for name, param in named_params: + print(name, param.shape, param.dtype) + + +if __name__ == "__main__": + test_llama2_creation("mistral_7b") diff --git a/tests/python/quantization/test_group_quantization.py b/tests/python/quantization/test_group_quantization.py index 04b23e91d3..52eb63a5ac 100644 --- a/tests/python/quantization/test_group_quantization.py +++ b/tests/python/quantization/test_group_quantization.py @@ -14,6 +14,7 @@ GroupQuantize, GroupQuantizeEmbedding, GroupQuantizeLinear, + GroupQuantizeMultiLinear, ) @@ -128,14 +129,13 @@ def forward(self, x: nn.Tensor): config = QUANTIZATION[quant_name] assert isinstance(config, GroupQuantize) + num_group = -(shape[1] // -config.group_size) weight_np = np.random.randint( np.iinfo(config.storage_dtype).min, np.iinfo(config.storage_dtype).max, - (shape[0], -(shape[1] // -config.num_elem_per_storage)), + (shape[0], config.num_storage_per_group * 
num_group), ).astype(config.storage_dtype) - scale_np = np.random.random((shape[0], -(shape[1] // -config.group_size))).astype( - config.model_dtype - ) + scale_np = np.random.random((shape[0], num_group)).astype(config.model_dtype) mod = config.quantize_model(Test(), QuantizeMapping({}, {}), "") mod.linear.q_weight.data = weight_np mod.linear.q_scale.data = scale_np @@ -160,6 +160,7 @@ class Test(nn.Module): def __init__(self) -> None: super().__init__() self.linear = nn.Linear(shape[0], shape[1], dtype=dtype) + self.multilinear = nn.MultiLinear(shape[0], [shape[1], shape[1]], dtype=dtype) self.embedding = nn.Embedding(shape[0], shape[1], dtype=dtype) def forward(self, x: nn.Tensor): @@ -175,6 +176,12 @@ def forward(self, x: nn.Tensor): ] assert quant_map.map_func["model.linear.weight"] == config.quantize_weight assert isinstance(mod.linear, GroupQuantizeLinear) + assert quant_map.param_map["model.multilinear.weight"] == [ + "model.multilinear.q_weight", + "model.multilinear.q_scale", + ] + assert quant_map.map_func["model.multilinear.weight"] == config.quantize_weight + assert isinstance(mod.multilinear, GroupQuantizeMultiLinear) assert quant_map.param_map["model.embedding.weight"] == [ "model.embedding.q_weight", "model.embedding.q_scale",