diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 3a1f423aa27..6d9ba750431 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -676,62 +676,47 @@ def _validate_args(args):
         )


-def _to_edge_and_lower_llama_xnnpack(
-    builder_exported,
-    modelname,
-    additional_passes,
-    pt2e_quant_params,
-    quantizers,
-    quant_dtype,
-    args,
-) -> LLMEdgeManager:  # noqa: C901
-    partitioners = []
-
-    # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
-    partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))
-
-    modelname = f"xnnpack_dq_{modelname}"
-
-    if args.xnnpack_extended_ops:
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
-        )
-        modelname = f"xnnpack_{modelname}"
-
-    logging.info("Lowering model using following partitioner(s): ")
-    for partitioner in partitioners:
-        logging.info(f"--> {partitioner.__class__.__name__}")
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(args)

-    # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
-    if args.generate_etrecord:
-        raise NotImplementedError(
-            "export_llama does not support XNNPack and generating ETRecord at the moment."
-        )
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

-    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
-        partitioners
-    )
-    if args.verbose:
-        print_delegation_info(builder.edge_manager.exported_program().graph_module)
+    # export_to_edge
+    builder_exported = _prepare_for_llama_export(args).export()

-    return builder.to_executorch(passes=additional_passes)
+    builder_exported.run_canonical_optimizations()

+    if args.export_only:
+        exit()

-def _to_edge_and_lower_llama(  # noqa: C901
-    builder_exported,
-    modelname,
-    additional_passes,
-    pt2e_quant_params,
-    quantizers,
-    quant_dtype,
-    args,
-):
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
     ).export_to_edge()

+    modelname = builder_exported_to_edge.modelname
+
     # to_backend
     partitioners = []
+
+    # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
+    if (
+        pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
+    ) or (args.xnnpack):
+        partitioners.append(
+            get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
+        )
+
+        # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        args.xnnpack = True
+        modelname = f"xnnpack_dq_{modelname}"
+
+    if args.xnnpack_extended_ops:
+        assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
+        partitioners.append(
+            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
+        )
+        modelname = f"xnnpack_{modelname}"
+
     if args.vulkan:
         partitioners.append(
             get_vulkan_partitioner(
@@ -746,6 +731,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
         modelname = f"vulkan_{modelname}"

         # Need to remove asserts from the graph to prevent graph breaks
+        # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
         remove_asserts(builder_exported_to_edge.edge_manager.exported_program())

     if args.mps:
@@ -774,11 +760,13 @@ def _to_edge_and_lower_llama(  # noqa: C901
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
         from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io

+        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())

         if args.num_sharding > 0:
             model_sharding.split_graph(
                 builder_exported_to_edge.edge_manager.exported_program(),
+                # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
                 builder_exported_to_edge.metadata["get_n_layers"],
                 shares=args.num_sharding,
             )
@@ -804,15 +792,19 @@ def _to_edge_and_lower_llama(  # noqa: C901
                     atten.head_dim,
                 )
             )
+        # pyre-ignore
         tag_quant_io(
             builder_exported_to_edge.edge_manager.exported_program().graph_module,
-            partial(get_custom_quant_ios_dtype, cache_shape),
+            partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
         )

     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")

+    additional_passes = []
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -826,6 +818,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program

+            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())

         builder = builder.to_executorch(
@@ -847,55 +840,11 @@ def _to_edge_and_lower_llama(  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program

+            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())

         builder = builder.to_executorch(passes=additional_passes)

-    return builder
-
-
-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
-
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
-
-    additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
-        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
-
-    # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
-    builder_exported.run_canonical_optimizations()
-    modelname = builder_exported.modelname
-
-    if args.export_only:
-        exit()
-
-    if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
-        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
-        args.xnnpack = True
-
-    if args.xnnpack:
-        builder = _to_edge_and_lower_llama_xnnpack(
-            builder_exported,
-            modelname,
-            additional_passes,
-            pt2e_quant_params,
-            quantizers,
-            quant_dtype,
-            args,
-        )
-    else:
-        builder = _to_edge_and_lower_llama(
-            builder_exported,
-            modelname,
-            additional_passes,
-            pt2e_quant_params,
-            quantizers,
-            quant_dtype,
-            args,
-        )
-
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")

@@ -917,6 +866,7 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901

     output_file = f"{builder.output_dir}/{modelname}.pte"
     builder.save_to_pte(output_file)
+
     return builder


diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index e6d228d5da9..8923ab1fdec 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -119,10 +119,11 @@ def quantize(  # noqa C901
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

-        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
-
-        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
+        model = Int8DynActInt4WeightQuantizer(
+            precision=torch_dtype, groupsize=group_size
+        ).quantize(model)

         if verbose:
             print("quantized model:", model)
@@ -662,7 +663,7 @@ def convert_for_runtime(self) -> nn.Module:
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict, assign=True)
+        self.mod.load_state_dict(model_updated_state_dict)
         return self.mod


diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index a5057e5e850..82c7aca09e0 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -67,6 +67,7 @@ def export(self) -> "LlavaEdgeManager":
             dynamic_shapes=dynamic_shape,
             strict=False,
         )
+        # pyre-ignore: Incompatible attribute type [8]: Attribute `pre_autograd_graph_module` declared in class `LLMEdgeManager` has type `Optional[GraphModule]` but is used as type `Module`.
         self.pre_autograd_graph_module = self.export_program.module()
         return self

diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index ec6cfa41ad8..88d2bc0cab9 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -21,7 +21,7 @@
     DuplicateDynamicQuantChainPass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
-from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
+from executorch.exir import EdgeProgramManager
 from executorch.exir.backend.partitioner import Partitioner
 from executorch.exir.backend.utils import format_delegated_graph

@@ -39,7 +39,7 @@
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
-from torch.export import export_for_training, ExportedProgram
+from torch.export import export_for_training
 from torch.nn.attention import SDPBackend

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -89,8 +89,8 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
-        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
+        # graph module returned from export()
+        self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
         self.dtype = dtype
@@ -218,8 +218,8 @@ def export(self) -> "LLMEdgeManager":
                 kwargs=self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
             )
+            # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
             #  `Module`.
-            self.pre_autograd_exported_program = exported_module
             self.pre_autograd_graph_module = exported_module.module()
             if hasattr(self.args, "export_only") and self.args.export_only:
                 torch.export.save(exported_module, self.args.output_name)
@@ -330,10 +330,7 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
                 assert (
                     self.pre_autograd_graph_module is not None
                 ), "Please run export() first"
-                m = prepare_pt2e(
-                    self.pre_autograd_graph_module,  # pyre-ignore[6]
-                    composed_quantizer,
-                )
+                m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
                 logging.info(
                     f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
                 )
@@ -433,19 +430,6 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag

         return self

-    def to_edge_transform_and_lower(
-        self, partitioners: Optional[List[Partitioner]]
-    ) -> "LLMEdgeManager":
-        if partitioners is None:
-            logging.info("No partitioner provided, skipping backend lowering...")
-        edge_config = self._get_edge_config()
-        self.edge_manager = to_edge_transform_and_lower(
-            self.pre_autograd_exported_program,
-            partitioner=partitioners,
-            compile_config=edge_config,
-        )
-        return self
-
     def to_executorch(
         self, passes: Optional[List[ExportPass]] = None
     ) -> "LLMEdgeManager":