[ParamManager] Cleanup creation of quantization IRModule (mlc-ai#1053)
This commit replaces the single-parameter
`relax_model.param_manager.create_quantize_func` function with a
method on the `ParamManager` class, `create_parameter_transformation`.
This avoids potential mix-ups between `param_manager` as the imported
Python module `mlc_llm.relax_model.param_manager` and `param_manager`
as an instance of the `ParamManager` class, and makes the
functionality easier to find.

The new method also takes an optional `optimize_parameter_order` flag,
defaulting to `True`, which applies the `ReorderTransformFunc` pass.
Since the `ReorderTransformFunc` is intended to be used with several
configuration objects owned by `ParamManager`, this simplifies the
common path of producing an optimally-ordered parameter transformation
module.
Lunderberg authored Oct 14, 2023
1 parent 481cd92 commit 8184431
Showing 3 changed files with 61 additions and 23 deletions.
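
As a rough sketch of the new usage described above (an editor's illustration, not part of the commit; the `args.model_path` and `args.use_safetensors` names follow the `mlc_llm/core.py` diff below, and `param_manager` is assumed to be a `ParamManager` instance):

    # Run the optional pre-quantization hook and record which binary file
    # each torch parameter is stored in.
    args.model_path = param_manager.run_pre_quantize(args.model_path)
    param_manager.init_torch_pname_to_bin_name(args.use_safetensors)

    # Build the parameter-transformation IRModule; the ReorderTransformFunc
    # pass is applied by default.
    mod_transform = param_manager.create_parameter_transformation()

    # Or keep the parameters in their default order:
    mod_transform = param_manager.create_parameter_transformation(
        optimize_parameter_order=False
    )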
4 changes: 4 additions & 0 deletions mlc_llm/core.py
@@ -638,6 +638,10 @@ def build_model_from_args(args: argparse.Namespace):
        qspec_updater.visit_module(mod)

    if not args.build_model_only:
        # Run pre-quantization if provided.
        args.model_path = param_manager.run_pre_quantize(args.model_path)
        param_manager.init_torch_pname_to_bin_name(args.use_safetensors)

        new_params = utils.convert_weights(param_manager, params, args)
        utils.save_params(new_params, args.artifact_path)
        if args.model_category != "minigpt":
55 changes: 54 additions & 1 deletion mlc_llm/relax_model/param_manager.py
@@ -13,6 +13,7 @@

from .. import quantization
from .modules import named_parameters
from ..transform import ReorderTransformFunc


def f_default_compute_relax_param(relax_pname: str, torch_params: List[Any]) -> Any:
@@ -274,6 +275,31 @@ def register_params(

        self.params_in_func[func_name].append(param)

    def run_pre_quantize(self, model_path: str):
        if self.f_run_prequantize is not None:
            model_path = self.f_run_prequantize(model_path)

        self.model_path = model_path
        return model_path

    def init_torch_pname_to_bin_name(self, use_safetensors: bool):
        assert hasattr(self, "model_path"), (
            "Must call either set_param_loading_func or run_pre_quantize "
            "before init_torch_pname_to_bin_name"
        )

        if self.pidx2pname:
            mapping = load_torch_pname2binname_map(
                self.model_path,
                use_safetensors,
                set(self.pidx2pname.values()),
                self.f_convert_pname_fwd,
            )
        else:
            mapping = {}

        self.torch_pname2binname = mapping

    def set_param_loading_func(
        self,
        model_path: str,
@@ -726,6 +752,33 @@ def _dequantize(
        # Apply the dequantization function.
        return bb.emit(f_dequantize(bb, qparams))

    def create_parameter_transformation(self, optimize_parameter_order: bool = True):
        """Produce an IRModule that can transform the parameters

        Parameters
        ----------
        optimize_parameter_order: bool

            If true, reorder the parameter transformations to
            prioritize operations that use a currently-open file. If
            false, transform the parameters in their default order.

        Returns
        -------
        tvm.IRModule

            The transformation module
        """
        mod = _create_quantize_func(self)
        if optimize_parameter_order:
            reorder_pass = ReorderTransformFunc(
                self.pidx2pname,
                self.torch_pname2binname,
                self.f_convert_pname_fwd,
            )
            mod = reorder_pass(mod)
        return mod


@mutator
class ParamReplacer(PyExprMutator):
@@ -868,7 +921,7 @@ def load_torch_pname2binname_map(
    return torch_pname2binname


def create_quantize_func(param_manager: ParamManager) -> tvm.IRModule:
def _create_quantize_func(param_manager: ParamManager) -> tvm.IRModule:
"""Construct the Relax function which computes quantization.
This method is called by `transform_module` below, and is not
directly invoked outside the class.
25 changes: 3 additions & 22 deletions mlc_llm/utils.py
@@ -10,7 +10,7 @@

from .quantization import quantization_schemes
from .relax_model import param_manager
from .transform import ReorderTransformFunc


supported_model_types = set(
["llama", "gpt_neox", "gpt_bigcode", "minigpt", "moss", "rwkv", "gptj", "chatglm", "mistral", "stablelm_epoch"]
@@ -192,31 +192,12 @@ def convert_weights(
    model_params: List[Optional[tvm.nd.NDArray]],
    args: argparse.Namespace,
):
    # Run pre-quantization if provided.
    if param_mgr.f_run_prequantize is not None:
        args.model_path = param_mgr.f_run_prequantize(args.model_path)
    param_mgr.model_path = args.model_path
    param_mgr.torch_pname2binname = (
        param_manager.load_torch_pname2binname_map(
            args.model_path,
            args.use_safetensors,
            set(param_mgr.pidx2pname.values()),
            param_mgr.f_convert_pname_fwd,
        )
        if len(param_mgr.pidx2pname) != 0
        else dict()
    )

    # Create the quantization function.
    # We first create an initial one, then reorder it according to each
    # weight's location in the binary files, in the purpose of reducing
    # memory usage when loading torch weights as well as acceleration.
    mod_transform = param_manager.create_quantize_func(param_mgr)
    mod_transform = ReorderTransformFunc(
        param_mgr.pidx2pname,
        param_mgr.torch_pname2binname,
        param_mgr.f_convert_pname_fwd,
    )(mod_transform)
    mod_transform = param_mgr.create_parameter_transformation()

    # Remove the dataflow block inside the param transform function,
    # so that the LazyTransformParams pass can be applied.
    mod_transform = relax.transform.ToNonDataflow()(mod_transform)
