diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 7168c9d8b1..0787db7073 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -638,6 +638,10 @@ def build_model_from_args(args: argparse.Namespace):
         qspec_updater.visit_module(mod)
 
     if not args.build_model_only:
+        # Run pre-quantization if provided.
+        args.model_path = param_manager.run_pre_quantize(args.model_path)
+        param_manager.init_torch_pname_to_bin_name(args.use_safetensors)
+
         new_params = utils.convert_weights(param_manager, params, args)
         utils.save_params(new_params, args.artifact_path)
         if args.model_category != "minigpt":
diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py
index 138f04f769..590b60d76b 100644
--- a/mlc_llm/relax_model/param_manager.py
+++ b/mlc_llm/relax_model/param_manager.py
@@ -13,6 +13,7 @@
 
 from .. import quantization
 from .modules import named_parameters
+from ..transform import ReorderTransformFunc
 
 
 def f_default_compute_relax_param(relax_pname: str, torch_params: List[Any]) -> Any:
@@ -274,6 +275,31 @@ def register_params(
 
         self.params_in_func[func_name].append(param)
 
+    def run_pre_quantize(self, model_path: str):
+        if self.f_run_prequantize is not None:
+            model_path = self.f_run_prequantize(model_path)
+
+        self.model_path = model_path
+        return model_path
+
+    def init_torch_pname_to_bin_name(self, use_safetensors: bool):
+        assert hasattr(self, "model_path"), (
+            "Must call either set_param_loading_func or run_pre_quantize "
+            "before init_torch_pname_to_bin_name"
+        )
+
+        if self.pidx2pname:
+            mapping = load_torch_pname2binname_map(
+                self.model_path,
+                use_safetensors,
+                set(self.pidx2pname.values()),
+                self.f_convert_pname_fwd,
+            )
+        else:
+            mapping = {}
+
+        self.torch_pname2binname = mapping
+
     def set_param_loading_func(
         self,
         model_path: str,
@@ -726,6 +752,33 @@ def _dequantize(
         # Apply the dequantization function.
         return bb.emit(f_dequantize(bb, qparams))
 
+    def create_parameter_transformation(self, optimize_parameter_order: bool = True):
+        """Produce an IRModule that can transform the parameters
+
+        Parameters
+        ----------
+        optimize_parameter_order: bool
+
+            If true, reorder the parameter transformations to
+            prioritize operations that use a currently-open file.  If
+            false, transform the parameters in their default order.
+
+        Returns
+        -------
+        tvm.IRModule
+            The transformation module
+
+        """
+        mod = _create_quantize_func(self)
+        if optimize_parameter_order:
+            reorder_pass = ReorderTransformFunc(
+                self.pidx2pname,
+                self.torch_pname2binname,
+                self.f_convert_pname_fwd,
+            )
+            mod = reorder_pass(mod)
+        return mod
+
 
 @mutator
 class ParamReplacer(PyExprMutator):
@@ -868,7 +921,7 @@ def load_torch_pname2binname_map(
     return torch_pname2binname
 
 
-def create_quantize_func(param_manager: ParamManager) -> tvm.IRModule:
+def _create_quantize_func(param_manager: ParamManager) -> tvm.IRModule:
     """Construct the Relax function which computes quantization.
     This method is called by `transform_module` below, and is not
     directly invoked outside the class.
diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py
index 9d8751e5d6..f356874d1d 100644
--- a/mlc_llm/utils.py
+++ b/mlc_llm/utils.py
@@ -10,7 +10,7 @@
 from .quantization import quantization_schemes
 from .relax_model import param_manager
-from .transform import ReorderTransformFunc
+
 
 supported_model_types = set(
     ["llama", "gpt_neox", "gpt_bigcode", "minigpt", "moss", "rwkv", "gptj", "chatglm", "mistral", "stablelm_epoch"]
 )
@@ -192,31 +192,12 @@ def convert_weights(
     model_params: List[Optional[tvm.nd.NDArray]],
     args: argparse.Namespace,
 ):
-    # Run pre-quantization if provided.
-    if param_mgr.f_run_prequantize is not None:
-        args.model_path = param_mgr.f_run_prequantize(args.model_path)
-        param_mgr.model_path = args.model_path
-    param_mgr.torch_pname2binname = (
-        param_manager.load_torch_pname2binname_map(
-            args.model_path,
-            args.use_safetensors,
-            set(param_mgr.pidx2pname.values()),
-            param_mgr.f_convert_pname_fwd,
-        )
-        if len(param_mgr.pidx2pname) != 0
-        else dict()
-    )
-
     # Create the quantization function.
     # We first create an initial one, then reorder it according to each
     # weight's location in the binary files, in the purpose of reducing
     # memory usage when loading torch weights as well as acceleration.
-    mod_transform = param_manager.create_quantize_func(param_mgr)
-    mod_transform = ReorderTransformFunc(
-        param_mgr.pidx2pname,
-        param_mgr.torch_pname2binname,
-        param_mgr.f_convert_pname_fwd,
-    )(mod_transform)
+    mod_transform = param_mgr.create_parameter_transformation()
+
     # Remove the dataflow block inside the param transform function,
     # so that the LazyTransformParams pass can be applied.
     mod_transform = relax.transform.ToNonDataflow()(mod_transform)
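For context, a minimal sketch of the call sequence after this refactor. This is illustrative only, not part of the patch: it assumes a fully configured ParamManager named `param_manager` and an argparse-style `args` carrying `model_path` and `use_safetensors`, as in `build_model_from_args`:

    # Optional pre-quantization hook; also records model_path on the manager.
    args.model_path = param_manager.run_pre_quantize(args.model_path)

    # Build the torch-parameter-name -> weight-shard-file mapping up front.
    param_manager.init_torch_pname_to_bin_name(args.use_safetensors)

    # convert_weights() can now obtain the (optionally reordered)
    # parameter-transformation module with a single call; passing
    # optimize_parameter_order=False would skip the ReorderTransformFunc pass.
    mod_transform = param_manager.create_parameter_transformation()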