Commit 6af0907

Ninja91 authored and facebook-github-bot committed
Add 16A8W linear ops support and test (#13448)
Summary:
- Adds a linear ops test using the 16A8W config in the INT16 profile.
- Adds INT16 dtype support to the view op's validation.
- Validated with the TOSA pipeline test.
- Confirmed that tests previously marked flaky are no longer flaky and removed the markers.

Note: Not verified with a TOSA reference model run.

Reviewed By: digantdesai

Differential Revision: D80308822
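For context, "16A8W" denotes symmetric quantization with 16-bit activations and 8-bit weights. The snippet below is an illustrative sketch of the arithmetic that scheme implies, not code from this commit:

```python
# Illustrative sketch of 16A8W symmetric quantization ranges; not code
# from this commit, just the arithmetic the scheme implies.
import torch


def symmetric_scale(x: torch.Tensor, bits: int) -> float:
    # Symmetric scheme: zero-point is 0; the scale maps max(|x|) onto
    # the largest representable signed integer for the given bit width.
    qmax = 2 ** (bits - 1) - 1
    return x.abs().max().item() / qmax


activations = torch.randn(4, 8)
weights = torch.randn(8, 8)

act_scale = symmetric_scale(activations, bits=16)  # int16 range [-32768, 32767]
w_scale = symmetric_scale(weights, bits=8)         # int8 range [-128, 127]

# Quantize-dequantize round trip; the wider activation range gives
# finer resolution than the 8-bit weight range.
q_act = (activations / act_scale).round().clamp(-32768, 32767)
deq_act = q_act * act_scale
print(f"activation scale={act_scale:.6f}, weight scale={w_scale:.6f}")
```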
1 parent 1d93b76 commit 6af0907

File tree

2 files changed: +75 −4 lines changed


backends/arm/operators/op_view.py

Lines changed: 7 additions & 1 deletion
@@ -44,7 +44,13 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
+            [
+                ts.DType.INT8,
+                ts.DType.INT16,
+                ts.DType.INT32,
+                ts.DType.FP32,
+                ts.DType.BOOL,
+            ],
             output.tosa_spec,
         )
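For context on what this hunk changes: validate_valid_dtype rejects graphs whose tensors carry dtypes the op's TOSA lowering cannot handle, and the commit adds ts.DType.INT16 to the allowed set for the view op so 16A8W-quantized graphs can pass. Below is a minimal standalone sketch of that kind of gate; the target string is illustrative and the real ExecuTorch helper has a different signature and richer error reporting:

```python
# Minimal sketch of a dtype gate analogous to validate_valid_dtype;
# not the actual ExecuTorch implementation.
from enum import Enum, auto


class DType(Enum):
    INT8 = auto()
    INT16 = auto()
    INT32 = auto()
    FP32 = auto()
    BOOL = auto()


def check_valid_dtypes(target, tensor_dtypes, valid_dtypes, tosa_spec):
    # Every input/output dtype must be in the op's supported set;
    # otherwise the lowering should reject the node.
    for dtype in tensor_dtypes:
        if dtype not in valid_dtypes:
            raise ValueError(
                f"{target}: dtype {dtype.name} unsupported for {tosa_spec}"
            )


# Before this commit INT16 was missing from the view op's set, so a
# 16A8W graph failed a check like this; with the hunk above it passes.
check_valid_dtypes(
    "aten.view_copy",
    [DType.INT16, DType.INT16],
    [DType.INT8, DType.INT16, DType.INT32, DType.FP32, DType.BOOL],
    "TOSA-1.0+INT+int16",
)
```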

backends/arm/test/ops/test_linear.py

Lines changed: 68 additions & 3 deletions
@@ -9,9 +9,12 @@
 from typing import Tuple
 
 import pytest
-
 import torch
-from executorch.backends.arm.test import common
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common, conftest
 
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
@@ -20,6 +23,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 aten_op = "torch.ops.aten.linear.default"
 
@@ -143,7 +148,6 @@ def test_linear_tosa_FP(test_data: torch.Tensor):
     pipeline.run()
 
 
-@pytest.mark.flaky(reruns=5)  # TODO: Investigate flakyness.
 @common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
 def test_linear_tosa_INT(test_data: torch.Tensor):
     test_data, out_features, has_bias, per_channel_quantization = test_data()
@@ -258,3 +262,64 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
         per_channel_quantization=per_channel_quantization,
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_linear_quantizer(
+    u55_config=False, per_channel_quantization=False
+):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+    quantizer.set_module_type(
+        torch.nn.Linear,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
+@pytest.mark.xfail(
+    reason="missing int16 linear ops support; fails at TOSA reference model run with Invalid TOSA graph"
+)
+def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
+    """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    test_data, out_features, has_bias, per_channel_quantization = test_data()
+    in_features = test_data.shape[-1]
+
+    # Create pipeline with custom 16A8W quantization config
+    pipeline = TosaPipelineINT[input_t1](
+        Linear(
+            in_features=in_features,
+            out_features=out_features,
+            bias=has_bias,
+        ),
+        (test_data,),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_linear_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    # Run the pipeline
+    pipeline.run()
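The new helper and test wire three pieces together: a TOSA spec with the int16 extension, the a16w8 symmetric config applied both globally and to torch.nn.Linear, and a Quantize stage swapped into the pipeline via change_args. The test is marked xfail because int16 linear support is still missing at the TOSA reference model stage. As a standalone sketch, the same composition outside the test pipeline, restating only calls that appear in the diff (the internals of the Quantize stage are assumed, not shown):

```python
# Hedged sketch: composing the 16A8W quantizer outside the test pipeline,
# using only the APIs imported in the diff above.
import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize

# The int16 extension must be enabled in the TOSA spec for 16-bit activations.
spec = TosaSpecification.create_from_string("TOSA-1.0+INT+int16")
config = get_symmetric_a16w8_quantization_config(is_per_channel=False)

quantizer = TOSAQuantizer(spec)
quantizer.set_global(config)                        # 16A8W as the default
quantizer.set_module_type(torch.nn.Linear, config)  # and explicitly for Linear

quantize_stage = Quantize(quantizer, config)
```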
