diff --git a/.github/workflows/python_lint.yml b/.github/workflows/python_lint.yml
index 1a255785ac..dab28b9261 100644
--- a/.github/workflows/python_lint.yml
+++ b/.github/workflows/python_lint.yml
@@ -1,9 +1,7 @@
 name: Python Lint
-
 on: [push, pull_request]
-
 env:
-  IMAGE: 'mlcaidev/ci-cpu:8a87699'
+  IMAGE: 'mlcaidev/ci-cpu:2c03e7f'
 
 jobs:
   isort:
@@ -35,3 +33,33 @@ jobs:
       - name: Lint
        run: |
          ./ci/bash.sh $IMAGE bash ./ci/task/black.sh
+
+  mypy:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          submodules: 'recursive'
+      - name: Version
+        run: |
+          wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
+          chmod u+x ./ci/bash.sh
+          ./ci/bash.sh $IMAGE "conda env export --name ci-lint"
+      - name: Lint
+        run: |
+          ./ci/bash.sh $IMAGE bash ./ci/task/mypy.sh
+
+  pylint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          submodules: 'recursive'
+      - name: Version
+        run: |
+          wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
+          chmod u+x ./ci/bash.sh
+          ./ci/bash.sh $IMAGE "conda env export --name ci-lint"
+      - name: Lint
+        run: |
+          ./ci/bash.sh $IMAGE bash ./ci/task/pylint.sh
diff --git a/ci/task/black.sh b/ci/task/black.sh
index 0e8555cf63..9e17a4c37a 100755
--- a/ci/task/black.sh
+++ b/ci/task/black.sh
@@ -3,7 +3,8 @@ set -eo pipefail
 
 source ~/.bashrc
 micromamba activate ci-lint
-NUM_THREADS=$(nproc)
+export NUM_THREADS=$(nproc)
+export PYTHONPATH="./python:$PYTHONPATH"
 
 black --check --workers $NUM_THREADS ./python/
 black --check --workers $NUM_THREADS ./tests/python
diff --git a/ci/task/isort.sh b/ci/task/isort.sh
index cdeb030cc6..0cf5ef9144 100755
--- a/ci/task/isort.sh
+++ b/ci/task/isort.sh
@@ -3,7 +3,8 @@ set -eo pipefail
 
 source ~/.bashrc
 micromamba activate ci-lint
-NUM_THREADS=$(nproc)
+export NUM_THREADS=$(nproc)
+export PYTHONPATH="./python:$PYTHONPATH"
 
 isort --check-only -j $NUM_THREADS --profile black ./python/
 isort --check-only -j $NUM_THREADS --profile black ./tests/python/
diff --git a/ci/task/mypy.sh b/ci/task/mypy.sh
new file mode 100755
index 0000000000..68713ac1ae
--- /dev/null
+++ b/ci/task/mypy.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+set -eo pipefail
+
+source ~/.bashrc
+micromamba activate ci-lint
+export NUM_THREADS=$(nproc)
+export PYTHONPATH="./python:$PYTHONPATH"
+
+mypy ./python/mlc_chat/compiler ./python/mlc_chat/support
+mypy ./tests/python/model ./tests/python/parameter
diff --git a/ci/task/pylint.sh b/ci/task/pylint.sh
new file mode 100755
index 0000000000..c29f5ad44e
--- /dev/null
+++ b/ci/task/pylint.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+set -eo pipefail
+
+source ~/.bashrc
+micromamba activate ci-lint
+export NUM_THREADS=$(nproc)
+export PYTHONPATH="./python:$PYTHONPATH"
+
+# TVM Unity is a dependency to this testing
+pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly
+
+pylint --jobs $NUM_THREADS ./python/mlc_chat/compiler ./python/mlc_chat/support
+pylint --jobs $NUM_THREADS --recursive=y ./tests/python/model ./tests/python/parameter
diff --git a/pyproject.toml b/pyproject.toml
index 2310e9aa60..85ca20eb24 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,4 +19,19 @@ profile = "black"
 
 [tool.black]
 line-length = 100
-target-version = ['py310']
+
+[tool.mypy]
+ignore_missing_imports = true
+show_column_numbers = true
+show_error_context = true
+follow_imports = "skip"
+ignore_errors = false
+strict_optional = false
+install_types = true
+non_interactive = true
+
+[tool.pylint.messages_control]
+max-line-length = 100
+disable = """
+duplicate-code,
+"""
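Two of the `[tool.mypy]` settings above determine what the new CI job will tolerate: `ignore_missing_imports = true` silences errors for third-party packages that ship no type stubs, and `strict_optional = false` treats `None` as compatible with every type. Below is a minimal sketch (a hypothetical file, not part of the patch) of code this configuration accepts but mypy's strict defaults would reject:

```python
# sketch.py -- hypothetical illustration, not a file in this PR
from typing import Dict


def lookup(table: Dict[str, str], key: str, default: str = None) -> str:
    # Accepted with `strict_optional = false`; under strict defaults the
    # implicit-None default above would need to be Optional[str].
    return table.get(key, default)


print(lookup({"gpu": "cuda"}, "cpu", "llvm"))  # -> llvm
```

Relaxing `strict_optional` presumably keeps the first CI rollout tractable; it can be tightened later without touching the workflow.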
diff --git a/python/mlc_chat/compiler/model/llama_parameter.py b/python/mlc_chat/compiler/model/llama_parameter.py
index 39a8921a05..b0fa867130 100644
--- a/python/mlc_chat/compiler/model/llama_parameter.py
+++ b/python/mlc_chat/compiler/model/llama_parameter.py
@@ -2,6 +2,8 @@
 This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace
 PyTorch, HuggingFace safetensors.
 """
+from typing import Callable, Dict, List
+
 import numpy as np
 
 from ..parameter import ExternMapping
@@ -26,8 +28,8 @@ def hf_torch(model_config: LlamaConfig) -> ExternMapping:
     _, named_params = model.export_tvm(spec=model.get_default_spec())
     parameter_names = {name for name, _ in named_params}
 
-    param_map = {}
-    map_func = {}
+    param_map: Dict[str, List[str]] = {}
+    map_func: Dict[str, Callable] = {}
     unused_params = set()
 
     for i in range(model_config.num_hidden_layers):
@@ -35,24 +37,24 @@ def hf_torch(model_config: LlamaConfig) -> ExternMapping:
         attn = f"model.layers.{i}.self_attn"
         assert f"{attn}.qkv_proj.weight" in parameter_names
         map_func[f"{attn}.qkv_proj.weight"] = lambda q, k, v: np.concatenate([q, k, v], axis=0)
-        param_map[f"{attn}.qkv_proj.weight"] = (
+        param_map[f"{attn}.qkv_proj.weight"] = [
             f"{attn}.q_proj.weight",
             f"{attn}.k_proj.weight",
             f"{attn}.v_proj.weight",
-        )
+        ]
         # Add gates in MLP
         mlp = f"model.layers.{i}.mlp"
         assert f"{mlp}.gate_up_proj.weight" in parameter_names
         map_func[f"{mlp}.gate_up_proj.weight"] = lambda gate, up: np.concatenate([gate, up], axis=0)
-        param_map[f"{mlp}.gate_up_proj.weight"] = (
+        param_map[f"{mlp}.gate_up_proj.weight"] = [
             f"{mlp}.gate_proj.weight",
             f"{mlp}.up_proj.weight",
-        )
+        ]
         # inv_freq is not used in the model
         unused_params.add(f"{attn}.rotary_emb.inv_freq")
 
     for name in parameter_names:
         if name not in map_func:
             map_func[name] = lambda x: x
-            param_map[name] = (name,)
+            param_map[name] = [name]
     return ExternMapping(param_map, map_func, unused_params)
diff --git a/python/mlc_chat/compiler/parameter/mapping.py b/python/mlc_chat/compiler/parameter/mapping.py
index 3018c91ca3..6f63dce71a 100644
--- a/python/mlc_chat/compiler/parameter/mapping.py
+++ b/python/mlc_chat/compiler/parameter/mapping.py
@@ -1,10 +1,18 @@
 """Parameter mapping for converting different LLM implementations to MLC LLM."""
 import dataclasses
-from typing import Callable, Dict, List, Set
+from typing import Callable, Dict, List, Set, Union
 
 import numpy as np
 from tvm.runtime import NDArray
 
+MapFuncVariadic = Union[
+    Callable[[], np.ndarray],
+    Callable[[np.ndarray], np.ndarray],
+    Callable[[np.ndarray, np.ndarray], np.ndarray],
+    Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
+    Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray],
+]
+
 
 @dataclasses.dataclass
 class ExternMapping:
@@ -33,8 +41,8 @@ class ExternMapping:
     """
 
     param_map: Dict[str, List[str]]
-    map_func: Dict[str, Callable[[np.ndarray, ...], np.ndarray]]
-    unused_params: Set[str] = dataclasses.field(default_factory=dict)
+    map_func: Dict[str, MapFuncVariadic]
+    unused_params: Set[str] = dataclasses.field(default_factory=set)
 
 
 @dataclasses.dataclass
@@ -72,8 +80,8 @@ class QuantizeMapping:
         used to convert the quantized parameters into the desired form.
     """
 
-    param_map: Dict[str, Callable[str, List[str]]]
-    map_func: Dict[str, Callable[NDArray, List[NDArray]]]
+    param_map: Dict[str, Callable[[str], List[str]]]
+    map_func: Dict[str, Callable[[NDArray], List[NDArray]]]
 
 
 __all__ = ["ExternMapping", "QuantizeMapping"]
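The `param_map`/`map_func` pair above works per MLC parameter name: `param_map` lists the source-format tensors a parameter is assembled from, and `map_func` holds a callable of matching arity that combines them. That is why the malformed `Callable[[np.ndarray, ...], np.ndarray]` annotation (an ellipsis is not valid inside an argument list; the variadic spelling is `Callable[..., np.ndarray]`) is replaced by the explicit `MapFuncVariadic` union, and why `unused_params`, a `Set[str]`, now defaults to `set` rather than `dict`. A self-contained sketch of how a loader might consume such a mapping (the tensor names and shapes here are made up):

```python
from typing import Callable, Dict, List

import numpy as np

# One qkv_proj entry, in the style built by hf_torch() above.
param_map: Dict[str, List[str]] = {
    "qkv_proj.weight": ["q_proj.weight", "k_proj.weight", "v_proj.weight"],
}
map_func: Dict[str, Callable] = {
    "qkv_proj.weight": lambda q, k, v: np.concatenate([q, k, v], axis=0),
}

# Stand-in for tensors loaded from a HuggingFace-style checkpoint.
source = {name: np.ones((2, 4), "float32") for name in param_map["qkv_proj.weight"]}

for mlc_name, source_names in param_map.items():
    # Fetch every source tensor, then combine them with the paired callable.
    combined = map_func[mlc_name](*(source[n] for n in source_names))
    print(mlc_name, combined.shape)  # qkv_proj.weight (6, 4)
```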
""" - param_map: Dict[str, Callable[str, List[str]]] - map_func: Dict[str, Callable[NDArray, List[NDArray]]] + param_map: Dict[str, Callable[[str], List[str]]] + map_func: Dict[str, Callable[[NDArray], List[NDArray]]] __all__ = ["ExternMapping", "QuantizeMapping"] diff --git a/python/mlc_chat/support/config.py b/python/mlc_chat/support/config.py index 62270ffd9c..9e42b815bc 100644 --- a/python/mlc_chat/support/config.py +++ b/python/mlc_chat/support/config.py @@ -37,10 +37,10 @@ def from_dict(cls: Type[ConfigClass], source: Dict[str, Any]) -> ConfigClass: cfg : ConfigClass An instance of the config object. """ - field_names = [field.name for field in dataclasses.fields(cls)] + field_names = [field.name for field in dataclasses.fields(cls)] # type: ignore[arg-type] fields = {k: v for k, v in source.items() if k in field_names} kwargs = {k: v for k, v in source.items() if k not in field_names} - return cls(**fields, kwargs=kwargs) + return cls(**fields, kwargs=kwargs) # type: ignore[call-arg] @classmethod def from_file(cls: Type[ConfigClass], source: Path) -> ConfigClass: diff --git a/tests/python/parameter/test_hf_torch_loader.py b/tests/python/parameter/test_hf_torch_loader.py index 745773b209..9cc8d0ea6c 100644 --- a/tests/python/parameter/test_hf_torch_loader.py +++ b/tests/python/parameter/test_hf_torch_loader.py @@ -1,6 +1,7 @@ # pylint: disable=missing-docstring import logging from pathlib import Path +from typing import Union import pytest from mlc_chat.compiler.model.llama import LlamaConfig @@ -24,7 +25,7 @@ "./dist/models/Llama-2-70b-hf", ], ) -def test_load_llama(base_path: str): +def test_load_llama(base_path: Union[str, Path]): base_path = Path(base_path) path_config = base_path / "config.json" path_params = base_path / "pytorch_model.bin.index.json"