diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 8c3a75c374..e720d19542 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -420,7 +420,8 @@ def mod_transform_before_build( if args.model.lower().startswith("rwkv-"): model_names += ["reset_kv_cache"] - mod = param_manager.transform_dequantize(mod) + mod = param_manager.transform_dequantize()(mod) + mod = relax.transform.BundleModelParams()(mod) use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"] mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod) diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py index 04f56a5152..7f0751b2a0 100644 --- a/mlc_llm/relax_model/param_manager.py +++ b/mlc_llm/relax_model/param_manager.py @@ -369,7 +369,7 @@ def set_param_loading_func( else: self.pidx2pname = dict() - def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule: + def transform_dequantize(self) -> tvm.ir.transform.Pass: """Apply dequantization to the input IRModule. Parameters @@ -386,38 +386,48 @@ def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule: The IRModule updated with the dequantization computation. """ - # For each Relax function in the input IRModule (e.g., "prefill"), - # we create its input relax.Var of all the quantized data, and - # store the mapping from function name to the var. - func2param_var: Dict[str, relax.Var] = {} - for gv, func in mod.functions.items(): - if not isinstance(func, relax.Function): - continue - if func.attrs is None or not "num_input" in func.attrs: - continue - func2param_var[gv.name_hint] = relax.Var( - "params", self.get_quantized_param_info(gv.name_hint) - ) + @tvm.ir.transform.module_pass(opt_level=0, name="ParamManager.transform_dequantize") + def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule: + # For each Relax function in the input IRModule (e.g., "prefill"), + # we create its input relax.Var of all the quantized data, and + # store the mapping from function name to the var. + func_name_to_quantized_params: Dict[str, List[relax.Var]] = {} - # Cache mapping to avoid duplicate dequantization. - dequantized_cache: Dict[relax.Var, relax.Var] = {} + for gv, func in mod.functions.items(): + if isinstance(func, relax.Function) and func.attrs and "num_input" in func.attrs: + quantized_param_info = self.get_quantized_param_info(gv.name_hint) + param_vars = [ + relax.Var(f"param_{i}", info) + for i, info in enumerate(quantized_param_info.fields) + ] + func_name_to_quantized_params[gv.name_hint] = param_vars - # Define a var replacement function for applying dequantization. - def f_replace(var: relax.Var, bb: relax.BlockBuilder, func_name: str) -> relax.Var: - if var in dequantized_cache: - return dequantized_cache[var] - assert var in self.func_raw_param_map - func_name, param = self.func_raw_param_map[var] - dequantized = self._dequantize(param, func2param_var[func_name], bb, func_name) - dequantized_cache[var] = dequantized - return dequantized + # Cache mapping to avoid duplicate dequantization. + dequantized_cache: Dict[relax.Var, relax.Var] = {} - # Create the function mutator for applying dequantization. - replacer = ParamReplacer(mod, func2param_var, f_replace) - # Update the input IRModule with dequantization. - mod = replacer.transform() + # Define a var replacement function for applying dequantization. + def f_replace(var: relax.Var, bb: relax.BlockBuilder) -> relax.Var: + if var in dequantized_cache: + return dequantized_cache[var] + assert var in self.func_raw_param_map - return mod + func_name, param = self.func_raw_param_map[var] + quantized_params = func_name_to_quantized_params[func_name] + relevant_quantized_params = [quantized_params[i] for i in self.param2qrange[param]] + + dequantized = self._dequantize(param, relevant_quantized_params, bb, func_name) + + dequantized_cache[var] = dequantized + return dequantized + + # Create the function mutator for applying dequantization. + replacer = ParamReplacer(mod, func_name_to_quantized_params, f_replace) + # Update the input IRModule with dequantization. + mod = replacer.transform() + + return mod + + return transform_func def get_quantized_param_info(self, func_name: str) -> List[relax.TensorStructInfo]: bb = relax.BlockBuilder() @@ -697,10 +707,9 @@ def _register_param( def _dequantize( self, param: Parameter, - quantized_tuple: relax.Var, + qparams: List[relax.Var], bb: relax.BlockBuilder, func_name: str, - qparams: List[relax.Var] = None, ) -> relax.Var: """Applying dequantization to the input parameter. This method is called by `transform_module` below, and is not @@ -711,30 +720,13 @@ def _dequantize( param : Parameter The parameter whose quantized tensors are to be dequantized. - quantized_tuple : relax.Var - The relax.Var of the quantized tensors of all parameters in the model. - - bb : relax.BlockBuilder - The Relax BlockBuilder used for inserting the dequantization computations. - - func_name : str - The name of the function which dequantization is applied to. - qparams : List[relax.Var] - The quantized parts of the parameter. - By default it is `None`, in which case we will get the quantized parts - from `quantized_tuple`. + The relax.Var of the quantized tensors of all parameters in the model. Returns ------- The dequantized parameter, in the form of a relax.Var. """ - if not qparams: - # Get the corresponding Relax vars of the quantized tensors of this parameter. - qparams: List[relax.Var] = [] - for qparam_idx in self.param2qrange[param]: - qparams.append(bb.emit(relax.TupleGetItem(quantized_tuple, qparam_idx))) - # Get the dequantization function of this parameter. f_dequantize = param.quant_spec.get_dequantize_func( param_info=param.param_info_dict[func_name], @@ -789,7 +781,7 @@ class ParamReplacer(PyExprMutator): mod : tvm.IRModule The IRModule of the model to be updated. - func2param_var : Dict[str, relax.Var] + func_name_to_quantized_params : Dict[str, List[relax.Var]] The mapping from each function name to its input var of quantized data tuple. f_replace : Callable[[relax.Var, relax.BlockBuilder], relax.Var] @@ -801,7 +793,7 @@ class ParamReplacer(PyExprMutator): """ mod: tvm.IRModule - func2param_var: Dict[str, relax.Var] + func_name_to_quantized_params: Dict[str, List[relax.Var]] f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var] param_set: Set[relax.Var] @@ -810,12 +802,12 @@ class ParamReplacer(PyExprMutator): def __init__( self, mod: tvm.IRModule, - func2param_var: Dict[str, relax.Var], + func_name_to_quantized_params: Dict[str, relax.Var], f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var], ): super().__init__(mod) self.mod = mod - self.func2param_var = func2param_var + self.func_name_to_quantized_params = func_name_to_quantized_params self.f_replace = f_replace self.cur_func_name = "" @@ -827,31 +819,31 @@ def transform(self) -> tvm.IRModule: continue assert ( - gv.name_hint in self.func2param_var - ), f"{gv.name_hint} not in {self.func2param_var}" - self.cur_func_name = gv.name_hint - updated_func = self.rewrite_func(func, self.func2param_var[gv.name_hint]) + gv.name_hint in self.func_name_to_quantized_params + ), f"{gv.name_hint} not in {self.func_name_to_quantized_params}" + updated_func = self.rewrite_func(func, self.func_name_to_quantized_params[gv.name_hint]) updated_func = remove_all_unused(updated_func) self.builder_.update_func(gv, updated_func) return self.builder_.get() - def rewrite_func(self, func: Function, param_var: relax.Var) -> relax.Function: + def rewrite_func(self, func: Function, quantized_params: List[relax.Var]) -> relax.Function: num_input = int(func.attrs["num_input"]) self.param_set = set(func.params[num_input:]) body = self.visit_expr(func.body) return relax.Function( - params=func.params[:num_input] + [param_var], + params=func.params[:num_input] + quantized_params, body=body, ret_struct_info=func.ret_struct_info, is_pure=func.is_pure, attrs=func.attrs, - ).without_attr("num_input") + ) def visit_var_(self, var: Var) -> Expr: - if var not in self.param_set: + if var in self.param_set: + return self.f_replace(var, self.builder_) + else: return super().visit_var_(var) - return self.f_replace(var, self.builder_, self.cur_func_name) ##################################################################