Add Basic Pylint and Mypy Tooling #1100

Merged · 1 commit · Oct 21, 2023

This PR adds mypy and pylint jobs to the Python lint workflow (backed by new ci/task/mypy.sh and ci/task/pylint.sh scripts), configures both tools in pyproject.toml, and fixes the type issues they surfaced in the compiler, parameter-mapping, config, and test code.
34 changes: 31 additions & 3 deletions .github/workflows/python_lint.yml
@@ -1,9 +1,7 @@
name: Python Lint

on: [push, pull_request]

env:
-  IMAGE: 'mlcaidev/ci-cpu:8a87699'
+  IMAGE: 'mlcaidev/ci-cpu:2c03e7f'

jobs:
  isort:
@@ -35,3 +33,33 @@ jobs:
      - name: Lint
        run: |
          ./ci/bash.sh $IMAGE bash ./ci/task/black.sh

  mypy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          submodules: 'recursive'
      - name: Version
        run: |
          wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
          chmod u+x ./ci/bash.sh
          ./ci/bash.sh $IMAGE "conda env export --name ci-lint"
      - name: Lint
        run: |
          ./ci/bash.sh $IMAGE bash ./ci/task/mypy.sh

  pylint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          submodules: 'recursive'
      - name: Version
        run: |
          wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
          chmod u+x ./ci/bash.sh
          ./ci/bash.sh $IMAGE "conda env export --name ci-lint"
      - name: Lint
        run: |
          ./ci/bash.sh $IMAGE bash ./ci/task/pylint.sh
3 changes: 2 additions & 1 deletion ci/task/black.sh
@@ -3,7 +3,8 @@ set -eo pipefail

source ~/.bashrc
micromamba activate ci-lint
-NUM_THREADS=$(nproc)
+export NUM_THREADS=$(nproc)
+export PYTHONPATH="./python:$PYTHONPATH"

black --check --workers $NUM_THREADS ./python/
black --check --workers $NUM_THREADS ./tests/python
3 changes: 2 additions & 1 deletion ci/task/isort.sh
@@ -3,7 +3,8 @@ set -eo pipefail

source ~/.bashrc
micromamba activate ci-lint
-NUM_THREADS=$(nproc)
+export NUM_THREADS=$(nproc)
+export PYTHONPATH="./python:$PYTHONPATH"

isort --check-only -j $NUM_THREADS --profile black ./python/
isort --check-only -j $NUM_THREADS --profile black ./tests/python/
10 changes: 10 additions & 0 deletions ci/task/mypy.sh
@@ -0,0 +1,10 @@
#!/bin/bash
set -eo pipefail

source ~/.bashrc
micromamba activate ci-lint
export NUM_THREADS=$(nproc)
export PYTHONPATH="./python:$PYTHONPATH"

mypy ./python/mlc_chat/compiler ./python/mlc_chat/support
mypy ./tests/python/model ./tests/python/parameter
13 changes: 13 additions & 0 deletions ci/task/pylint.sh
@@ -0,0 +1,13 @@
#!/bin/bash
set -eo pipefail

source ~/.bashrc
micromamba activate ci-lint
export NUM_THREADS=$(nproc)
export PYTHONPATH="./python:$PYTHONPATH"

# TVM Unity is a dependency of these tests
pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly

pylint --jobs $NUM_THREADS ./python/mlc_chat/compiler ./python/mlc_chat/support
pylint --jobs $NUM_THREADS --recursive=y ./tests/python/model ./tests/python/parameter
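As orientation for readers new to pylint (an illustrative aside, not part of this PR's diff): pylint reports findings as coded messages, and the pyproject.toml change below disables only duplicate-code. A toy snippet showing typical findings:

# Toy example, illustration only; pylint would report here:
#   C0114 (missing-module-docstring), C0116 (missing-function-docstring),
#   W0612 (unused-variable)
def add(a, b):
    unused = a - b
    return a + b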
17 changes: 16 additions & 1 deletion pyproject.toml
@@ -19,4 +19,19 @@ profile = "black"

[tool.black]
line-length = 100
target-version = ['py310']

[tool.mypy]
ignore_missing_imports = true
show_column_numbers = true
show_error_context = true
follow_imports = "skip"
ignore_errors = false
strict_optional = false
install_types = true
non_interactive = true

[tool.pylint.messages_control]
max-line-length = 100
disable = """
duplicate-code,
"""
16 changes: 9 additions & 7 deletions python/mlc_chat/compiler/model/llama_parameter.py
@@ -2,6 +2,8 @@
This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""
from typing import Callable, Dict, List

import numpy as np

from ..parameter import ExternMapping
@@ -26,33 +28,33 @@ def hf_torch(model_config: LlamaConfig) -> ExternMapping:
    _, named_params = model.export_tvm(spec=model.get_default_spec())
    parameter_names = {name for name, _ in named_params}

-    param_map = {}
-    map_func = {}
+    param_map: Dict[str, List[str]] = {}
+    map_func: Dict[str, Callable] = {}
    unused_params = set()

    for i in range(model_config.num_hidden_layers):
        # Add QKV in self attention
        attn = f"model.layers.{i}.self_attn"
        assert f"{attn}.qkv_proj.weight" in parameter_names
        map_func[f"{attn}.qkv_proj.weight"] = lambda q, k, v: np.concatenate([q, k, v], axis=0)
-        param_map[f"{attn}.qkv_proj.weight"] = (
+        param_map[f"{attn}.qkv_proj.weight"] = [
            f"{attn}.q_proj.weight",
            f"{attn}.k_proj.weight",
            f"{attn}.v_proj.weight",
-        )
+        ]
        # Add gates in MLP
        mlp = f"model.layers.{i}.mlp"
        assert f"{mlp}.gate_up_proj.weight" in parameter_names
        map_func[f"{mlp}.gate_up_proj.weight"] = lambda gate, up: np.concatenate([gate, up], axis=0)
-        param_map[f"{mlp}.gate_up_proj.weight"] = (
+        param_map[f"{mlp}.gate_up_proj.weight"] = [
            f"{mlp}.gate_proj.weight",
            f"{mlp}.up_proj.weight",
-        )
+        ]
        # inv_freq is not used in the model
        unused_params.add(f"{attn}.rotary_emb.inv_freq")

    for name in parameter_names:
        if name not in map_func:
            map_func[name] = lambda x: x
-            param_map[name] = (name,)
+            param_map[name] = [name]
    return ExternMapping(param_map, map_func, unused_params)
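To make the mapping concrete, here is a minimal sketch (illustration only, not the project's actual loader; `load_hf_tensor` is an assumed callable returning an np.ndarray) of how `param_map` and `map_func` work together: each MLC parameter name maps to a list of source-tensor names, which are loaded and passed positionally to the matching map function.

# Hypothetical consumption of an ExternMapping.
def convert_params(mapping, load_hf_tensor):
    mlc_params = {}
    for mlc_name, hf_names in mapping.param_map.items():
        # e.g. qkv_proj.weight is built from the q/k/v projection weights
        sources = [load_hf_tensor(name) for name in hf_names]
        mlc_params[mlc_name] = mapping.map_func[mlc_name](*sources)
    return mlc_params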
18 changes: 13 additions & 5 deletions python/mlc_chat/compiler/parameter/mapping.py
@@ -1,10 +1,18 @@
"""Parameter mapping for converting different LLM implementations to MLC LLM."""
import dataclasses
-from typing import Callable, Dict, List, Set
+from typing import Callable, Dict, List, Set, Union

import numpy as np
from tvm.runtime import NDArray

MapFuncVariadic = Union[
    Callable[[], np.ndarray],
    Callable[[np.ndarray], np.ndarray],
    Callable[[np.ndarray, np.ndarray], np.ndarray],
    Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
    Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray],
]


@dataclasses.dataclass
class ExternMapping:
@@ -33,8 +41,8 @@ class ExternMapping:
"""

param_map: Dict[str, List[str]]
map_func: Dict[str, Callable[[np.ndarray, ...], np.ndarray]]
unused_params: Set[str] = dataclasses.field(default_factory=dict)
map_func: Dict[str, MapFuncVariadic]
unused_params: Set[str] = dataclasses.field(default_factory=set)


@dataclasses.dataclass
@@ -72,8 +80,8 @@ class QuantizeMapping:
    used to convert the quantized parameters into the desired form.
    """

-    param_map: Dict[str, Callable[str, List[str]]]
-    map_func: Dict[str, Callable[NDArray, List[NDArray]]]
+    param_map: Dict[str, Callable[[str], List[str]]]
+    map_func: Dict[str, Callable[[NDArray], List[NDArray]]]


__all__ = ["ExternMapping", "QuantizeMapping"]
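Two of these fixes are easy to miss: `Callable[str, List[str]]` is invalid typing syntax (argument types must be given as a list, hence `Callable[[str], List[str]]`), and `default_factory=dict` silently initialized `unused_params` as an empty dict rather than a set. A minimal sketch of the latter, assuming nothing beyond the standard library:

# Illustration only: why default_factory must match the annotation.
import dataclasses
from typing import Set


@dataclasses.dataclass
class Bad:
    unused: Set[str] = dataclasses.field(default_factory=dict)  # mypy error; dict at runtime


@dataclasses.dataclass
class Good:
    unused: Set[str] = dataclasses.field(default_factory=set)


print(type(Bad().unused))   # <class 'dict'> -- calling .add() would raise AttributeError
print(type(Good().unused))  # <class 'set'>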
4 changes: 2 additions & 2 deletions python/mlc_chat/support/config.py
@@ -37,10 +37,10 @@ def from_dict(cls: Type[ConfigClass], source: Dict[str, Any]) -> ConfigClass:
        cfg : ConfigClass
            An instance of the config object.
        """
-        field_names = [field.name for field in dataclasses.fields(cls)]
+        field_names = [field.name for field in dataclasses.fields(cls)]  # type: ignore[arg-type]
        fields = {k: v for k, v in source.items() if k in field_names}
        kwargs = {k: v for k, v in source.items() if k not in field_names}
-        return cls(**fields, kwargs=kwargs)
+        return cls(**fields, kwargs=kwargs)  # type: ignore[call-arg]

    @classmethod
    def from_file(cls: Type[ConfigClass], source: Path) -> ConfigClass:
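For context, a rough sketch of the `from_dict` pattern (illustration only, simplified from the diff; the `LlamaConfig` fields here are hypothetical): the incoming dictionary is split into declared dataclass fields and a catch-all `kwargs`, so unknown keys in a config.json do not crash construction.

# Hypothetical sketch of the from_dict pattern in mlc_chat.support.config.
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass
class LlamaConfig:
    hidden_size: int
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)


source = {"hidden_size": 4096, "rope_theta": 10000.0}
field_names = [f.name for f in dataclasses.fields(LlamaConfig)]
fields = {k: v for k, v in source.items() if k in field_names}
extras = {k: v for k, v in source.items() if k not in field_names}
cfg = LlamaConfig(**fields, kwargs=extras)
print(cfg.kwargs)  # {'rope_theta': 10000.0}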
3 changes: 2 additions & 1 deletion tests/python/parameter/test_hf_torch_loader.py
@@ -1,6 +1,7 @@
# pylint: disable=missing-docstring
import logging
from pathlib import Path
from typing import Union

import pytest
from mlc_chat.compiler.model.llama import LlamaConfig
@@ -24,7 +25,7 @@
"./dist/models/Llama-2-70b-hf",
],
)
def test_load_llama(base_path: str):
def test_load_llama(base_path: Union[str, Path]):
base_path = Path(base_path)
path_config = base_path / "config.json"
path_params = base_path / "pytorch_model.bin.index.json"
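The annotation change is needed because `base_path = Path(base_path)` rebinds the parameter to a different type: with `base_path: str`, mypy reports an incompatible assignment, while `Union[str, Path]` admits both the incoming string and the rebound Path. A minimal sketch (illustration only; `load` is a hypothetical stand-in for the test):

# Illustration only: why the parameter is annotated Union[str, Path].
from pathlib import Path
from typing import Union


def load(base_path: Union[str, Path]) -> Path:
    base_path = Path(base_path)  # OK: Path is a member of the declared Union
    return base_path / "config.json"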