Skip to content

Commit 126e8e8

Browse files
committed
Add some basic xnnpack recipes
Pull Request resolved: #10035 ghstack-source-id: 283781375 @exported-using-ghexport Differential Revision: [D72085170](https://our.internmc.facebook.com/intern/diff/D72085170/)
1 parent 52a5178 commit 126e8e8

File tree

8 files changed

+217
-1
lines changed

8 files changed

+217
-1
lines changed

backends/transforms/duplicate_dynamic_quant_chain.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import operator
99

1010
import torch
11+
from executorch.exir.program._program import _update_exported_program_graph_module
1112

1213
from torch.ao.quantization.pt2e.utils import (
1314
_filter_sym_size_users,
@@ -194,3 +195,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
194195
graph_module.graph.eliminate_dead_code()
195196
graph_module.recompile()
196197
return PassResult(graph_module, True)
198+
199+
200+
def duplicate_dynamic_quant_chain_pass(
201+
ep: torch.export.ExportedProgram,
202+
) -> torch.export.ExportedProgram:
203+
res = DuplicateDynamicQuantChainPass()(ep.graph_module)
204+
assert res is not None
205+
return _update_exported_program_graph_module(ep, res.graph_module)

backends/xnnpack/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,6 @@ runtime.python_library(
3838
":xnnpack_preprocess",
3939
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
4040
"//executorch/backends/xnnpack/utils:xnnpack_utils",
41+
"//executorch/backends/xnnpack/recipes:xnnpack_recipes"
4142
],
4243
)

backends/xnnpack/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
XnnpackDynamicallyQuantizedPartitioner,
1010
XnnpackPartitioner,
1111
)
12+
from .recipes.recipes import get_xnnpack_recipe
1213

1314
# Exposed Configs in XNNPACK Package
1415
from .utils.configs import (
@@ -23,12 +24,12 @@
2324
# XNNPACK Backend
2425
from .xnnpack_preprocess import XnnpackBackend
2526

26-
2727
__all__ = [
2828
"XnnpackDynamicallyQuantizedPartitioner",
2929
"XnnpackPartitioner",
3030
"XnnpackBackend",
3131
"capture_graph_for_xnnpack",
32+
"get_xnnpack_recipe",
3233
"get_xnnpack_capture_config",
3334
"get_xnnpack_edge_compile_config",
3435
"get_xnnpack_executorch_backend_config",

backends/xnnpack/recipes/TARGETS

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
4+
oncall("executorch")
5+
6+
python_library(
7+
name = "xnnpack_recipes",
8+
srcs = [
9+
"recipes.py",
10+
],
11+
deps = [
12+
"//caffe2:torch",
13+
"//executorch/exir:lib",
14+
"//executorch/export:recipe",
15+
"//executorch/backends/transforms:duplicate_dynamic_quant_chain",
16+
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
17+
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
18+
],
19+
)
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
from typing import Any, Callable
9+
10+
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
11+
duplicate_dynamic_quant_chain_pass,
12+
)
13+
14+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
15+
16+
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
17+
get_symmetric_quantization_config,
18+
XNNPACKQuantizer,
19+
)
20+
from executorch.export.recipe import ExportRecipe, QuantizationRecipe
21+
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight
22+
23+
24+
def get_generic_fp32_cpu_recipe() -> ExportRecipe:
25+
return ExportRecipe(
26+
name="fp32_recipe",
27+
quantization_recipe=None,
28+
partitioners=[XnnpackPartitioner()],
29+
)
30+
31+
32+
def get_dynamic_quant_recipe() -> ExportRecipe:
33+
# Create quantizer
34+
quantizer = XNNPACKQuantizer()
35+
operator_config = get_symmetric_quantization_config(
36+
is_per_channel=True, is_dynamic=True
37+
)
38+
quantizer.set_global(operator_config)
39+
40+
# Create quantization recipe
41+
quant_recipe = QuantizationRecipe(
42+
quantizer=quantizer,
43+
)
44+
45+
# Create export recipe
46+
return ExportRecipe(
47+
name="dynamic_quant_recipe",
48+
quantization_recipe=quant_recipe,
49+
partitioners=[XnnpackPartitioner()],
50+
pre_edge_transform_passes=duplicate_dynamic_quant_chain_pass,
51+
)
52+
53+
54+
def get_8a4w_config(group_size: int=32) -> ExportRecipe:
55+
# Create quantization recipe
56+
quant_recipe = QuantizationRecipe(
57+
quantizer=None,
58+
ao_base_config=[
59+
int8_dynamic_activation_int4_weight(group_size=32),
60+
],
61+
)
62+
63+
# Create export recipe
64+
return ExportRecipe(
65+
name="8a4w_quant_recipe",
66+
quantization_recipe=quant_recipe,
67+
partitioners=[XnnpackPartitioner()],
68+
)
69+
70+
71+
RECIPE_MAP: dict[str, Callable[[], ExportRecipe]] = {
72+
"FP32_CPU_ACCELERATED_RECIPE": get_generic_fp32_cpu_recipe,
73+
"DYNAMIC_QUANT_CPU_ACCELERATED_RECIPE": get_dynamic_quant_recipe,
74+
"8A4W_CPU_ACCELERATED_RECIPE": get_8a4w_config,
75+
}
76+
77+
78+
def get_xnnpack_recipe(recipe_name: str, **kwargs: Any) -> ExportRecipe:
79+
assert recipe_name in RECIPE_MAP, f"Recipe {recipe_name} not found."
80+
return RECIPE_MAP[recipe_name](**kwargs)

backends/xnnpack/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def define_common_targets():
6767
"//executorch/extension/threadpool:threadpool",
6868
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
6969
"//executorch/runtime/executor:pte_data_map" + aten_suffix,
70+
"//executorch/backends/xnnpack/recipes:xnnpack_recipes",
7071
],
7172
# XnnpackBackend.cpp needs to compile with executor as whole
7273
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)

backends/xnnpack/test/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,14 @@ runtime.python_test(
9494
"libtorch",
9595
],
9696
)
97+
98+
runtime.python_test(
99+
name = "test_xnnpack_recipes",
100+
srcs = glob([
101+
"recipes/*.py",
102+
]),
103+
deps = [
104+
"//executorch/backends/xnnpack:xnnpack_delegate",
105+
"//executorch/export:lib",
106+
],
107+
)
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
import torch
12+
from executorch.backends.xnnpack import get_xnnpack_recipe
13+
from executorch.exir.schema import DelegateCall, Program
14+
from executorch.export import export
15+
from torch import nn
16+
from torch.testing._internal.common_quantization import TestHelperModules
17+
18+
19+
class TestXnnpackRecipes(unittest.TestCase):
20+
def setUp(self) -> None:
21+
super().setUp()
22+
23+
def tearDown(self) -> None:
24+
super().tearDown()
25+
26+
def check_fully_delegated(self, program: Program) -> None:
27+
instructions = program.execution_plan[0].chains[0].instructions
28+
assert instructions is not None
29+
self.assertEqual(len(instructions), 1)
30+
self.assertIsInstance(instructions[0].instr_args, DelegateCall)
31+
32+
def test_basic_recipe(self) -> None:
33+
m_eager = TestHelperModules.TwoLinearModule().eval()
34+
example_inputs = [(torch.randn(9, 8),)]
35+
session = export(
36+
model=m_eager,
37+
example_inputs=example_inputs,
38+
export_recipe=get_xnnpack_recipe("FP32_CPU_ACCELERATED_RECIPE"),
39+
)
40+
self.assertTrue(
41+
torch.allclose(
42+
session.run_method("forward", example_inputs[0])[0],
43+
m_eager(*example_inputs[0]),
44+
)
45+
)
46+
self.check_fully_delegated(session.get_executorch_program())
47+
48+
def test_dynamic_quant_recipe(self) -> None:
49+
with torch.no_grad():
50+
m_eager = TestHelperModules.TwoLinearModule().eval()
51+
example_inputs = [(torch.randn(9, 8),)]
52+
session = export(
53+
model=m_eager,
54+
example_inputs=example_inputs,
55+
export_recipe=get_xnnpack_recipe(
56+
"DYNAMIC_QUANT_CPU_ACCELERATED_RECIPE"
57+
),
58+
)
59+
self.assertTrue(
60+
torch.allclose(
61+
session.run_method("forward", example_inputs[0])[0],
62+
m_eager(*example_inputs[0]),
63+
atol=1e-1,
64+
)
65+
)
66+
self.check_fully_delegated(session.get_executorch_program())
67+
68+
def test_8a4w_recipe(self) -> None:
69+
class SimpleLinearModel(nn.Module):
70+
def __init__(self) -> None:
71+
super(SimpleLinearModel, self).__init__()
72+
self.layer1 = nn.Linear(32, 2)
73+
74+
def forward(self, x) -> torch.Tensor:
75+
x = self.layer1(x)
76+
return x
77+
78+
model = SimpleLinearModel()
79+
example_inputs = [(torch.randn(1, 32),)]
80+
session = export(
81+
model=model,
82+
example_inputs=example_inputs,
83+
export_recipe=get_xnnpack_recipe(
84+
"8A4W_CPU_ACCELERATED_RECIPE", group_size=32
85+
),
86+
)
87+
self.assertTrue(
88+
torch.allclose(
89+
session.run_method("forward", example_inputs[0])[0],
90+
model(*example_inputs[0]),
91+
atol=1e-1,
92+
)
93+
)
94+
self.check_fully_delegated(session.get_executorch_program())

0 commit comments

Comments
 (0)