From e42c965f860fa9d706c9531ea47f97ca879c7c6b Mon Sep 17 00:00:00 2001 From: Tarun Karuturi Date: Wed, 9 Apr 2025 17:14:08 -0700 Subject: [PATCH 1/2] Recipe and Input class definitions with e2e export Differential Revision: [D71946730](https://our.internmc.facebook.com/intern/diff/D71946730/) [ghstack-poisoned] --- export/TARGETS | 36 +++ export/__init__.py | 24 ++ export/_export.py | 357 +++++++++++++++++++++++++ export/_recipe.py | 83 ++++++ export/tests/TARGETS | 16 ++ export/tests/test_executorch_export.py | 33 +++ runtime/TARGETS | 3 +- 7 files changed, 550 insertions(+), 2 deletions(-) create mode 100644 export/TARGETS create mode 100644 export/__init__.py create mode 100644 export/_export.py create mode 100644 export/_recipe.py create mode 100644 export/tests/TARGETS create mode 100644 export/tests/test_executorch_export.py diff --git a/export/TARGETS b/export/TARGETS new file mode 100644 index 00000000000..cfa71923361 --- /dev/null +++ b/export/TARGETS @@ -0,0 +1,36 @@ +load("@fbcode_macros//build_defs:python_library.bzl", "python_library") + +oncall("executorch") + +python_library( + name = "recipe", + srcs = [ + "_recipe.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir/backend:backend_api", + "//executorch/exir:pass_manager", + ] +) + +python_library( + name = "export", + srcs = [ + "_export.py", + ], + deps = [ + ":recipe", + ] +) + +python_library( + name = "lib", + srcs = [ + "__init__.py", + ], + deps = [ + ":export", + ":recipe", + ], +) diff --git a/export/__init__.py b/export/__init__.py new file mode 100644 index 00000000000..14803e76ef9 --- /dev/null +++ b/export/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +ExecuTorch export module. 
+
+This module provides the tools and utilities for exporting PyTorch models
+to the ExecuTorch format, including configuration, quantization, and
+export management.
+"""
+
+# pyre-strict
+
+from ._export import ExportSession, export
+from ._recipe import ExportRecipe
+
+__all__ = [
+    "ExportRecipe",
+    "ExportSession",
+    "export",
+]
diff --git a/export/_export.py b/export/_export.py
new file mode 100644
index 00000000000..b0065cef002
--- /dev/null
+++ b/export/_export.py
@@ -0,0 +1,393 @@
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+from torch import nn
+from executorch.devtools.backend_debug import get_delegation_info
+from executorch.exir.program import (
+    EdgeProgramManager,
+    ExecutorchProgramManager,
+    to_edge_transform_and_lower,
+)
+from executorch.runtime import Runtime, Verification
+from tabulate import tabulate
+from torch.ao.quantization import allow_exported_model_train_eval
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.export import export_for_training, ExportedProgram
+
+from ._recipe import ExportRecipe
+
+
+def export(
+    model: Union[nn.Module, Dict[str, nn.Module]],
+    example_inputs: Union[List[tuple[torch.Tensor, ...]], Dict[str, List[tuple[torch.Tensor, ...]]]],
+    export_recipe: ExportRecipe,
+    name: Optional[str] = None,
+    dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
+    constant_methods: Optional[Dict[str, Callable]] = None,
+    artifact_dir: Optional[str] = None,
+    apply_quantization: bool = False,
+) -> "ExportSession":
+    """
+    Create an ExportSession and, if quantization is requested, run the export.
+
+    Constructing the session already captures every model for training. The
+    export pipeline itself only runs here when ``apply_quantization`` is True;
+    otherwise the caller must call ``export()`` on the returned session.
+
+    Args:
+        model: The PyTorch model(s) to export, either a single model or a dictionary
+            mapping method names to models
+        example_inputs: Example inputs for the model(s), either a list of input tuples
+            or a dictionary mapping method names to lists of input tuples
+        export_recipe: Contains the configuration for the export process
+        name: Optional name for the export
+        dynamic_shapes: Optional dynamic shape specifications
+        constant_methods: Optional dictionary of constant methods
+        artifact_dir: Optional directory to store artifacts
+        apply_quantization: When True, run the full export process with
+            quantization before returning, defaults to False
+
+    Returns:
+        The ExportSession, already exported only if apply_quantization is True
+    """
+    session = ExportSession(
+        model=model,
+        example_inputs=example_inputs,
+        export_recipe=export_recipe,
+        name=name,
+        dynamic_shapes=dynamic_shapes,
+        constant_methods=constant_methods,
+        artifact_dir=artifact_dir,
+    )
+
+    if apply_quantization:
+        session.export(apply_quantization=True)
+
+    return session
+
+
+class ExportSession:
+    """
+    Manages the export process for ExecuTorch models.
+
+    This class handles the three-stage export process:
+    1. Export PyTorch model to ExportedProgram
+    2. Transform and lower to EdgeProgramManager
+    3. Convert to ExecutorchProgramManager for final execution
+    """
+
+    def __init__(
+        self,
+        model: Union[nn.Module, Dict[str, nn.Module]],
+        example_inputs: Union[List[tuple[torch.Tensor, ...]], Dict[str, List[tuple[torch.Tensor, ...]]]],
+        export_recipe: ExportRecipe,
+        name: Optional[str] = None,
+        dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
+        constant_methods: Optional[Dict[str, Callable]] = None,
+        artifact_dir: Optional[str] = None,
+    ) -> None:
+        """
+        Initialize the ExportSession with model, inputs, and recipe.
+
+        Every model is captured with ``export_for_training`` using the first
+        example input of its method, so that it can be quantized later.
+
+        Args:
+            model: The PyTorch model(s) to export, either a single model or a dictionary
+                mapping method names to models
+            example_inputs: Example inputs for the model(s), either a list of input tuples
+                or a dictionary mapping method names to lists of input tuples
+            export_recipe: Contains the configuration for the export process
+            name: Optional name for the export
+            dynamic_shapes: Optional dynamic shape specifications, either one
+                specification for "forward" or a dictionary mapping method names
+                to specifications
+            constant_methods: Optional dictionary of constant methods
+            artifact_dir: Optional directory to store artifacts
+
+        Raises:
+            ValueError: If a method has no example inputs
+        """
+        # Standardize model to dictionary format
+        self._model = model if isinstance(model, dict) else {"forward": model}
+
+        # Standardize example_inputs to dictionary format
+        self._example_inputs = example_inputs if isinstance(example_inputs, dict) else {"forward": example_inputs}
+
+        # Standardize dynamic_shapes to dictionary format. torch.export also takes
+        # a dict keyed by argument name for a single model, so a dict only counts
+        # as a per-method mapping when every key is a method name.
+        self._dynamic_shapes = {}
+        if dynamic_shapes is not None:
+            if isinstance(dynamic_shapes, dict) and (
+                isinstance(model, dict) or set(dynamic_shapes) <= set(self._model)
+            ):
+                self._dynamic_shapes = dynamic_shapes
+            else:
+                self._dynamic_shapes = {"forward": dynamic_shapes}
+
+        self._name = name
+        self._constant_methods = constant_methods
+        self._artifact_dir = artifact_dir
+        self._export_recipe = export_recipe
+        self._exported_program: Dict[str, ExportedProgram] = {}
+        self._edge_program_manager: Optional[EdgeProgramManager] = None
+        self._executorch_program_manager: Optional[ExecutorchProgramManager] = None
+        self._delegation_info = None
+
+        # Export models for training to enable quantization
+        self._exported_models: Dict[str, nn.Module] = {}
+        for method_name, method_model in self._model.items():
+            # Fail with a clear message rather than a bare KeyError/IndexError
+            if not self._example_inputs.get(method_name):
+                raise ValueError(
+                    f"Example inputs for method {method_name} not found."
+                )
+            self._exported_models[method_name] = export_for_training(
+                method_model,
+                self._example_inputs[method_name][0],  # type: ignore
+                dynamic_shapes=self._dynamic_shapes.get(method_name, None),
+            ).module()
+
+    def quantize(self) -> None:
+        """
+        Perform post-training quantization on the model.
+
+        Each captured model is prepared with the recipe's quantizer, calibrated
+        by running every example input of its method, and converted. The result
+        replaces the method's entry in the session's model dictionary.
+
+        Note:
+            This should be called before the export process if quantization is desired.
+
+        Raises:
+            ValueError: If the export recipe has no quantizer
+        """
+        # Get the quantizer from the recipe
+        quantizer = self._export_recipe.get_quantizer()
+        if quantizer is None:
+            raise ValueError("Quantizer not specified in the export recipe")
+
+        for method_name, method_model in self._model.items():
+            # Set model to evaluation mode for quantization
+            method_model.eval()
+
+            # Use the pre-exported model from initialization
+            captured_model = self._exported_models[method_name]
+
+            # Prepare the model for quantization
+            prepared_model = prepare_pt2e(captured_model, quantizer)  # type: ignore
+
+            # Allow the exported model to switch between train and eval modes
+            allow_exported_model_train_eval(prepared_model)
+
+            # Calibrate the model with the provided calibration data
+            for calibration_input in self._example_inputs[method_name]:  # type: ignore
+                prepared_model(*calibration_input)
+
+            # Convert the prepared model to a quantized model
+            # Update the model in the model dictionary
+            self._model[method_name] = convert_pt2e(prepared_model)  # type: ignore
+
+    def export(self, apply_quantization: bool = False) -> None:
+        """
+        Execute the full export process.
+
+        This method orchestrates the export process with optional quantization:
+        1. (Optional) Apply quantization to the model
+        2. Export the PyTorch model to ExportedProgram
+        3. Transform and lower to EdgeProgramManager
+        4. Convert to ExecutorchProgramManager
+
+        Args:
+            apply_quantization: Whether to apply quantization before export, defaults
+                to False; has no effect when the recipe has no quantizer
+        """
+        if apply_quantization and self._export_recipe.quantizer is not None:
+            self.quantize()
+
+        self._export_stage()
+        self._to_edge_transform_and_lower_stage()
+        self._to_executorch_stage()
+
+    def _export_stage(self) -> None:
+        """
+        First stage: Export PyTorch model to ExportedProgram.
+
+        Exports each model in the input to an ExportedProgram and applies
+        any pre-edge transform passes if specified.
+
+        Raises:
+            ValueError: If a method has no example inputs
+        """
+        pre_edge_transform = self._export_recipe.pre_edge_transform_passes
+        with torch.no_grad():
+            for method_name, method_model in self._model.items():
+                # Check if method_name exists in example_inputs
+                if method_name not in self._example_inputs:
+                    raise ValueError(
+                        f"Example inputs for method {method_name} not found."
+                    )
+
+                # Export the model; methods without dynamic shapes pass None
+                exported_program = torch.export.export(
+                    method_model,
+                    self._example_inputs[method_name][0],
+                    dynamic_shapes=self._dynamic_shapes.get(method_name),
+                    strict=False,
+                )
+
+                # Apply pre-edge transform passes if available
+                if pre_edge_transform is not None:
+                    exported_program = pre_edge_transform(exported_program)
+
+                self._exported_program[method_name] = exported_program
+
+    def _to_edge_transform_and_lower_stage(self) -> None:
+        """
+        Second stage: Transform and lower to EdgeProgramManager.
+
+        Applies partitioning and transformation passes to convert the
+        ExportedProgram to an EdgeProgramManager, then records delegation info
+        for one method ("forward" when present, otherwise the first method).
+        """
+        self._edge_program_manager = to_edge_transform_and_lower(
+            self._exported_program,
+            partitioner=self._export_recipe.partitioners,
+            transform_passes=self._export_recipe.edge_transform_passes,
+            constant_methods=self._constant_methods,
+            compile_config=self._export_recipe.edge_compile_config,
+        )
+
+        # exported_program() looks up "forward" by default, which is missing when
+        # the session was built from a dict of differently named methods.
+        method_name = "forward"
+        if method_name not in self._exported_program:
+            method_name = next(iter(self._exported_program), method_name)
+        self._delegation_info = get_delegation_info(
+            self._edge_program_manager.exported_program(method_name).graph_module
+        )
+
+    def _to_executorch_stage(self) -> None:
+        """
+        Third stage: Convert to ExecutorchProgramManager.
+
+        Converts the EdgeProgramManager to an ExecutorchProgramManager
+        using the specified backend configuration.
+
+        Raises:
+            RuntimeError: If the edge program manager is not initialized
+        """
+        if self._edge_program_manager is None:
+            raise RuntimeError(
+                "Edge program manager is not initialized. Run _to_edge_transform_and_lower_stage first."
+            )
+        self._executorch_program_manager = self._edge_program_manager.to_executorch(
+            self._export_recipe.executorch_backend_config
+        )
+
+    def save_pte_file(self, path: str) -> None:
+        """
+        Save the exported program to a PTE file.
+
+        Args:
+            path: Path where the PTE file will be saved
+
+        Raises:
+            RuntimeError: If the executorch program manager is not initialized
+        """
+        if self._executorch_program_manager is None:
+            raise RuntimeError(
+                "Executorch program manager is not initialized. Run export() first."
+            )
+        self._executorch_program_manager.save(path)
+
+    def get_pte_buffer(self) -> bytes:
+        """
+        Get the PTE buffer as bytes.
+
+        Returns:
+            The PTE buffer as bytes
+
+        Raises:
+            RuntimeError: If the executorch program manager is not initialized
+        """
+        if self._executorch_program_manager is None:
+            raise RuntimeError(
+                "Executorch program manager is not initialized. Run export() first."
+            )
+        return self._executorch_program_manager.buffer
+
+    def get_example_input(
+        self, method_name: str = "forward"
+    ) -> Tuple[torch.Tensor, ...]:
+        """
+        Get the example input for a specific method.
+
+        Args:
+            method_name: Name of the method to get example input for, defaults to "forward"
+
+        Returns:
+            Tuple of tensors representing the example input
+
+        Raises:
+            KeyError: If the method name is not found in example inputs
+            ValueError: If the example inputs list is empty
+        """
+        if method_name not in self._example_inputs:
+            raise KeyError(f"Method name '{method_name}' not found in example inputs")
+
+        example_inputs_list = self._example_inputs[method_name]
+        if not example_inputs_list:
+            raise ValueError(f"Example inputs list for method {method_name} is empty")
+
+        # The first entry is the one used for export
+        return example_inputs_list[0]
+
+    def run_method(
+        self,
+        method_name: str = "forward",
+        example_inputs: Optional[Tuple[torch.Tensor, ...]] = None,
+    ) -> Sequence[Any]:
+        """
+        Run a specific method with the given inputs.
+
+        Args:
+            method_name: Name of the method to run, defaults to "forward"
+            example_inputs: Optional inputs to use, defaults to the example inputs
+
+        Returns:
+            The outputs of the method execution
+
+        Raises:
+            RuntimeError: If export() has not run or the method cannot be loaded
+        """
+        et_runtime = Runtime.get()
+        program = et_runtime.load_program(
+            self.get_pte_buffer(), verification=Verification.Minimal
+        )
+        method = program.load_method(method_name)
+
+        if method is None:
+            raise RuntimeError(
+                f"Failed to load method '{method_name}' from the program"
+            )
+        if example_inputs is None:
+            example_inputs = self.get_example_input(method_name)
+
+        return method.execute(example_inputs)
+
+    def print_delegation_info(self) -> None:
+        """
+        Print delegation information for the exported program.
+
+        Raises:
+            RuntimeError: If delegation info has not been collected yet
+        """
+        if self._delegation_info is None:
+            raise RuntimeError(
+                "Delegation info is not available. Run export() first."
+            )
+        print(self._delegation_info.get_summary())
+        df = self._delegation_info.get_operator_delegation_dataframe()
+        print(tabulate(df, headers="keys", tablefmt="fancy_grid"))
diff --git a/export/_recipe.py b/export/_recipe.py
new file mode 100644
index 00000000000..f45250237a3
--- /dev/null
+++ b/export/_recipe.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Export recipe definitions for ExecuTorch.
+
+This module provides the data structures needed to configure the export process
+for ExecuTorch models, including export configurations.
+"""
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import Callable, Optional, Sequence
+
+from executorch.exir.backend.partitioner import Partitioner
+from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
+from executorch.exir.pass_manager import PassType
+from torch.ao.quantization.quantizer import Quantizer
+from torch.export import ExportedProgram
+
+
+class Mode(str, Enum):
+    """
+    Export mode enumeration.
+
+    Attributes:
+        DEBUG: Debug mode with additional checks and information
+        RELEASE: Release mode optimized for performance
+    """
+
+    DEBUG = "debug"
+    RELEASE = "release"
+
+
+@dataclass
+class ExportRecipe:
+    """
+    Configuration recipe for the export process.
+
+    This class holds the configuration parameters for exporting a model,
+    including quantization, compilation, and transformation options.
+
+    Attributes:
+        name: Optional name for the recipe
+        quantizer: Optional quantizer for model quantization
+        edge_compile_config: Optional edge compilation configuration
+        pre_edge_transform_passes: Optional function to apply transformation passes
+            before edge lowering
+        edge_transform_passes: Optional sequence of transformation passes to apply
+            during edge lowering
+        transform_check_ir_validity: Whether to check IR validity during transformation
+        partitioners: Optional list of partitioners for model partitioning
+        executorch_backend_config: Optional backend configuration for ExecuTorch
+        mode: Export mode (debug or release)
+    """
+
+    name: Optional[str] = None
+    quantizer: Optional[Quantizer] = None
+    edge_compile_config: Optional[EdgeCompileConfig] = (  # pyre-ignore[11]: Type not defined
+        None
+    )
+    pre_edge_transform_passes: Optional[
+        Callable[[ExportedProgram], ExportedProgram]
+    ] = None
+    edge_transform_passes: Optional[Sequence[PassType]] = None
+    transform_check_ir_validity: bool = True
+    partitioners: Optional[list[Partitioner]] = None
+    executorch_backend_config: Optional[ExecutorchBackendConfig] = (  # pyre-ignore[11]: Type not defined
+        None
+    )
+    mode: Mode = Mode.RELEASE
+
+    def get_quantizer(self) -> Optional[Quantizer]:
+        """
+        Get the quantizer associated with this recipe.
+
+        Returns:
+            The quantizer if one is set, otherwise None
+        """
+        return self.quantizer
diff --git a/export/tests/TARGETS b/export/tests/TARGETS
new file mode 100644
index 00000000000..93556cb03dd
--- /dev/null
+++ b/export/tests/TARGETS
@@ -0,0 +1,16 @@
+load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+
+oncall("executorch")
+
+python_unittest(
+    name = "executorch_export",
+    srcs = [
+        "test_executorch_export.py",
+    ],
+    deps = [
+        "//executorch/exir:lib",
+        "//executorch/export:lib",
+        "//executorch/devtools/backend_debug:delegation_info",
+        "//executorch/runtime:runtime",
+    ]
+)
diff --git a/export/tests/test_executorch_export.py b/export/tests/test_executorch_export.py
new file mode 100644
index 00000000000..09998a84c81
--- /dev/null
+++ b/export/tests/test_executorch_export.py
@@ -0,0 +1,33 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-strict
+
+import unittest
+
+import torch
+from executorch.export import ExportRecipe, export
+
+class TestExecutorchExport(unittest.TestCase):
+    def test_basic_recipe(self) -> None:
+        class SimpleModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 10)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = SimpleModel()
+        example_inputs = [(torch.rand(1, 10),)]
+        export_recipe = ExportRecipe()
+
+        # Use the export API instead of creating ExportSession directly
+        export_session = export(
+            model=model,
+            example_inputs=example_inputs,
+            export_recipe=export_recipe
+        )
+
+        # The export function doesn't automatically call export() on the session
+        export_session.export()
+        self.assertTrue(len(export_session.get_pte_buffer())!= 0)
diff --git a/runtime/TARGETS b/runtime/TARGETS
index c341c042d03..7448523f5ff 100644
--- a/runtime/TARGETS
+++ b/runtime/TARGETS
@@ -9,8 +9,7 @@ runtime.python_library(
         "//executorch/extension/pybindings:portable_lib",
     ],
     visibility = [
-        "//executorch/runtime/...",
-        "//executorch/exir/emit/test/...",
+        "//executorch/...",
         "@EXECUTORCH_CLIENTS",
     ],
 )
From 910ac24ed9d64a92b093568ccd4d9638d53fa7fa Mon Sep 17 00:00:00 2001
From: Tarun Karuturi
Date: Wed, 9 Apr 2025 17:14:12 -0700
Subject: [PATCH 2/2] Add some basic xnnpack recipes

Differential Revision: [D72085170](https://our.internmc.facebook.com/intern/diff/D72085170/)

[ghstack-poisoned]
---
 .../duplicate_dynamic_quant_chain.py | 8 +++
 backends/xnnpack/TARGETS | 1 +
 backends/xnnpack/__init__.py | 3 +-
 backends/xnnpack/recipes/TARGETS | 18 +++++++
 backends/xnnpack/recipes/recipes.py | 51 +++++++++++++++++++
 backends/xnnpack/targets.bzl | 3 +-
 backends/xnnpack/test/TARGETS | 10 ++++
 .../test/recipes/test_xnnpack_recipes.py | 37 ++++++++++++++
 8 files changed, 129 insertions(+), 2 deletions(-)
 create mode 100644 backends/xnnpack/recipes/TARGETS
 create mode 100644 backends/xnnpack/recipes/recipes.py
 create mode 100644 backends/xnnpack/test/recipes/test_xnnpack_recipes.py

diff --git a/backends/transforms/duplicate_dynamic_quant_chain.py b/backends/transforms/duplicate_dynamic_quant_chain.py
index 2ca65eec45f..48d3ef289ec 100644
--- a/backends/transforms/duplicate_dynamic_quant_chain.py
+++ b/backends/transforms/duplicate_dynamic_quant_chain.py
@@ -8,6 +8,7 @@
 import operator
 
 import torch
+from executorch.exir.program._program import _update_exported_program_graph_module
 
 from torch.ao.quantization.pt2e.utils import (
     _filter_sym_size_users,
@@ -194,3 +195,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph_module.graph.eliminate_dead_code()
         graph_module.recompile()
         return PassResult(graph_module, True)
+
+def duplicate_dynamic_quant_chain_pass(
+    ep: torch.export.ExportedProgram,
+) -> torch.export.ExportedProgram:
+    res = DuplicateDynamicQuantChainPass()(ep.graph_module)
+    assert res is not None
+    return _update_exported_program_graph_module(ep, res.graph_module)
diff --git a/backends/xnnpack/TARGETS b/backends/xnnpack/TARGETS
index 4a3dfed7625..f3095fb3f15
100644
--- a/backends/xnnpack/TARGETS
+++ b/backends/xnnpack/TARGETS
@@ -37,5 +37,6 @@ runtime.python_library(
         ":xnnpack_preprocess",
         "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
         "//executorch/backends/xnnpack/utils:xnnpack_utils",
+        "//executorch/backends/xnnpack/recipes:xnnpack_recipes",
     ],
 )
diff --git a/backends/xnnpack/__init__.py b/backends/xnnpack/__init__.py
index 6f4aafa8348..910c9842924 100644
--- a/backends/xnnpack/__init__.py
+++ b/backends/xnnpack/__init__.py
@@ -22,13 +22,14 @@
 
 # XNNPACK Backend
 from .xnnpack_preprocess import XnnpackBackend
-
+from .recipes.recipes import get_xnnpack_recipe
 
 __all__ = [
     "XnnpackDynamicallyQuantizedPartitioner",
     "XnnpackPartitioner",
     "XnnpackBackend",
     "capture_graph_for_xnnpack",
+    "get_xnnpack_recipe",
     "get_xnnpack_capture_config",
     "get_xnnpack_edge_compile_config",
     "get_xnnpack_executorch_backend_config",
diff --git a/backends/xnnpack/recipes/TARGETS b/backends/xnnpack/recipes/TARGETS
new file mode 100644
index 00000000000..6d0fadbc352
--- /dev/null
+++ b/backends/xnnpack/recipes/TARGETS
@@ -0,0 +1,19 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+
+oncall("executorch")
+
+python_library(
+    name = "xnnpack_recipes",
+    srcs = [
+        "recipes.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/export:lib",
+        "//executorch/backends/transforms:duplicate_dynamic_quant_chain",
+        "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+    ],
+)
diff --git a/backends/xnnpack/recipes/recipes.py b/backends/xnnpack/recipes/recipes.py
new file mode 100644
index 00000000000..0e81f20a28b
--- /dev/null
+++ b/backends/xnnpack/recipes/recipes.py
@@ -0,0 +1,79 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-strict
+from typing import Any, Callable
+
+from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
+    duplicate_dynamic_quant_chain_pass,
+    DuplicateDynamicQuantChainPass,
+)
+
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from executorch.export import ExportRecipe
+
+
+def get_generic_fp32_cpu_recipe() -> ExportRecipe:
+    """
+    Build the recipe for running an unquantized model through XNNPACK.
+
+    Returns:
+        An ExportRecipe with the XNNPACK partitioner and no quantizer
+    """
+    return ExportRecipe(
+        name="fp32_recipe",
+        quantizer=None,
+        partitioners=[XnnpackPartitioner()],
+    )
+
+
+def get_dynamic_quant_recipe() -> ExportRecipe:
+    """
+    Build the recipe for dynamically quantized models on XNNPACK.
+
+    Returns:
+        An ExportRecipe with a symmetric, per-channel, dynamic XNNPACKQuantizer,
+        the XNNPACK partitioner, and duplicate_dynamic_quant_chain_pass applied
+        before edge lowering
+    """
+    quantizer = XNNPACKQuantizer()
+    operator_config = get_symmetric_quantization_config(
+        is_per_channel=True, is_dynamic=True
+    )
+    quantizer.set_global(operator_config)
+    return ExportRecipe(
+        name="dynamic_quant_recipe",
+        quantizer=quantizer,
+        partitioners=[XnnpackPartitioner()],
+        pre_edge_transform_passes=duplicate_dynamic_quant_chain_pass,
+    )
+
+
+# Maps public recipe names to the factories that build them
+RECIPE_MAP: dict[str, Callable[[], ExportRecipe]] = {
+    "FP32_CPU_ACCELERATED_RECIPE": get_generic_fp32_cpu_recipe,
+    "DYNAMIC_QUANT_CPU_ACCELERATED_RECIPE": get_dynamic_quant_recipe,
+}
+
+
+def get_xnnpack_recipe(recipe_name: str, **kwargs: Any) -> ExportRecipe:
+    """
+    Build the XNNPACK export recipe registered under the given name.
+
+    Args:
+        recipe_name: A key of RECIPE_MAP
+        **kwargs: Forwarded to the recipe factory; the factories registered here
+            take no arguments
+
+    Returns:
+        The ExportRecipe built by the selected factory
+
+    Raises:
+        AssertionError: If recipe_name is not in RECIPE_MAP
+    """
+    assert recipe_name in RECIPE_MAP, f"Recipe {recipe_name} not found."
+    return RECIPE_MAP[recipe_name](**kwargs)
diff --git a/backends/xnnpack/targets.bzl b/backends/xnnpack/targets.bzl
index e97f1941ff7..82f85722f88 100644
--- a/backends/xnnpack/targets.bzl
+++ b/backends/xnnpack/targets.bzl
@@ -64,7 +64,8 @@ def define_common_targets():
                 "//executorch/backends/xnnpack/serialization:xnnpack_flatbuffer_header",
                 "//executorch/extension/threadpool:threadpool",
                 "//executorch/runtime/core/exec_aten/util:tensor_util",
-                "//executorch/runtime/executor:pte_data_map"
+                "//executorch/runtime/executor:pte_data_map",
+                "//executorch/backends/xnnpack/recipes:xnnpack_recipes",
             ],
             # XnnpackBackend.cpp needs to compile with executor as whole
             # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
diff --git a/backends/xnnpack/test/TARGETS b/backends/xnnpack/test/TARGETS
index 9b2ce0a4e82..3370d94005e 100644
--- a/backends/xnnpack/test/TARGETS
+++ b/backends/xnnpack/test/TARGETS
@@ -93,3 +93,13 @@ runtime.python_test(
         "libtorch",
     ],
 )
+
+runtime.python_test(
+    name = "test_xnnpack_recipes",
+    srcs = glob([
+        "recipes/*.py",
+    ]),
+    deps = [
+        "//executorch/backends/xnnpack:xnnpack_delegate",
+    ],
+)
diff --git a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py
new file mode 100644
index 00000000000..8ccb39d1f11
--- /dev/null
+++ b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py
@@ -0,0 +1,37 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-strict
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack import get_xnnpack_recipe
+from executorch.export import export
+from torch.testing._internal.common_quantization import TestHelperModules
+
+
+class TestXnnpackRecipes(unittest.TestCase):
+    """End-to-end export smoke tests for the XNNPACK recipes."""
+
+    def test_basic_recipe(self) -> None:
+        m_eager = TestHelperModules.TwoLinearModule().eval()
+        example_inputs = [(torch.randn(9, 8),)]
+        export_session = export(
+            model=m_eager,
+            example_inputs=example_inputs,
+            export_recipe=get_xnnpack_recipe("FP32_CPU_ACCELERATED_RECIPE"),
+        )
+        export_session.export()
+        self.assertGreater(len(export_session.get_pte_buffer()), 0)
+
+    def test_dynamic_quant_recipe(self) -> None:
+        m_eager = TestHelperModules.TwoLinearModule().eval()
+        example_inputs = [(torch.randn(9, 8),)]
+        export_session = export(
+            model=m_eager,
+            example_inputs=example_inputs,
+            export_recipe=get_xnnpack_recipe("DYNAMIC_QUANT_CPU_ACCELERATED_RECIPE"),
+        )
+        # export() defaults to apply_quantization=False, which would skip the recipe's quantizer
+        export_session.export(apply_quantization=True)
+        self.assertGreater(len(export_session.get_pte_buffer()), 0)