
Commit ca1cd7f

Ninja91 authored and facebook-github-bot committed
Add 16A8W linear ops support and test
Summary:
- Adds a linear ops test using the 16A8W config in the INT16 profile.
- Adds INT16 dtype support to view op validation.
- Validated with the TOSA pipeline test. Note: not yet verified against the TOSA reference model.

Differential Revision: D80308822
1 parent b52a083 commit ca1cd7f
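
The new config plugs into the existing Arm quantizer flow. A minimal sketch of direct usage outside the test pipeline, assuming TOSAQuantizer exposes set_global (as other quantizers in this backend do) and that tosa_spec is a TosaSpecification built with the int16 extension:

    from executorch.backends.arm.quantizer.arm_quantizer import (
        TOSAQuantizer,
        get_16a8w_quantization_config,
    )

    # tosa_spec: assumed to be a TosaSpecification with the int16 extension enabled
    quantizer = TOSAQuantizer(tosa_spec)
    quantizer.set_global(get_16a8w_quantization_config(is_per_channel=False))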

File tree

5 files changed (+43, -7 lines)

backends/arm/operators/op_view.py
backends/arm/quantizer/arm_quantizer.py
backends/arm/test/ops/test_linear.py
backends/arm/test/pytest.ini
backends/arm/test/tester/test_pipeline.py

backends/arm/operators/op_view.py

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
             output.tosa_spec,
         )
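
For context, a rough standalone sketch of what this check enforces (check_dtypes is a hypothetical stand-in; the real validate_valid_dtype lives in the Arm operator validation utilities):

    def check_dtypes(target, tensors, allowed_dtypes, tosa_spec):
        # Reject any tensor whose dtype is outside the allowed set for this op.
        for tensor in tensors:
            if tensor.dtype not in allowed_dtypes:
                raise ValueError(
                    f"{target}: dtype {tensor.dtype} is not supported by "
                    f"{tosa_spec}; expected one of {allowed_dtypes}"
                )

With ts.DType.INT16 added to the allowed list, view nodes carrying 16-bit activations now pass this validation.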

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 1 deletion

@@ -188,7 +188,7 @@ def get_16a8w_quantization_config(

     # 16-bit activation quantization spec
     act_quantization_spec = QuantizationSpec(
-        dtype=torch.int32,
+        dtype=torch.int16,
         quant_min=torch.iinfo(torch.int16).min,  # -32768
         quant_max=torch.iinfo(torch.int16).max,  # 32767
         qscheme=torch.per_tensor_affine,
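
The dtype previously read torch.int32 even though quant_min and quant_max were already the int16 bounds; this change makes the dtype agree with the range. Pieced together, the activation spec now amounts to roughly the following (the observer is an assumption for illustration; the actual constructor is whatever get_16a8w_quantization_config wires up):

    import torch
    from torch.ao.quantization.observer import HistogramObserver
    from torch.ao.quantization.quantizer import QuantizationSpec

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int16,
        quant_min=torch.iinfo(torch.int16).min,  # -32768
        quant_max=torch.iinfo(torch.int16).max,  # 32767
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=HistogramObserver,  # assumed observer choice
    )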

backends/arm/test/ops/test_linear.py

Lines changed: 34 additions & 0 deletions

@@ -11,6 +11,9 @@
 import pytest
 
 import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_16a8w_quantization_config,
+)
 from executorch.backends.arm.test import common
 
 from executorch.backends.arm.test.tester.test_pipeline import (
@@ -258,3 +261,34 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
         per_channel_quantization=per_channel_quantization,
     )
     pipeline.run()
+
+
+@pytest.mark.xfail(
+    reason="TOSA backend has limited INT16 support - view operations only support INT8/INT32/FP32/BOOL"
+)
+@common.parametrize("test_data", test_data_rank1_INT)
+def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
+    """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    test_data, out_features, has_bias, per_channel_quantization = test_data()
+    in_features = test_data.shape[-1]
+
+    # Create pipeline with custom 16A8W quantization config
+    pipeline = TosaPipelineINT[input_t1](
+        Linear(
+            in_features=in_features,
+            out_features=out_features,
+            bias=has_bias,
+        ),
+        (test_data,),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        quantization_config=get_16a8w_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+        tosa_extensions=["int16"],
+    )
+
+    # Run the pipeline
+    pipeline.run()
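
To exercise just the new test locally, a standard pytest invocation along these lines should work from the repository root:

    pytest backends/arm/test/ops/test_linear.py -k test_linear_16a8w_tosa_INT -v

The test is expected to fail for now (xfail) while INT16 coverage in the TOSA backend remains partial.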

backends/arm/test/pytest.ini

Lines changed: 1 addition & 0 deletions

@@ -3,3 +3,4 @@ addopts = --strict-markers
 markers =
     slow: Tests that take long time
     tosa_ref_model: Tests that use TOSA reference model # Temporary!
+    flaky: Tests that are known to be flaky
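
Because addopts includes --strict-markers, pytest rejects markers that are not registered here, so this entry must exist before any test can be tagged as flaky. A hypothetical example:

    import pytest

    @pytest.mark.flaky
    def test_known_to_be_flaky():
        ...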

backends/arm/test/tester/test_pipeline.py

Lines changed: 6 additions & 5 deletions

@@ -108,7 +108,6 @@ def __init__(
             Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
         ] = None,
     ):
-
         self.tester = ArmTester(
             module,
             example_inputs=test_data,
@@ -341,6 +340,7 @@ def __init__(
         qtol: int = 1,
         dynamic_shapes: Optional[Tuple[Any]] = None,
         tosa_extensions: Optional[List[str]] = None,
+        quantization_config: Optional[Any] = None,
     ):
         if tosa_extensions is None:
             tosa_extensions = []
@@ -356,9 +356,11 @@ def __init__(
         )
 
         quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
-        quantization_config = get_symmetric_quantization_config(
-            is_per_channel=per_channel_quantization
-        )
+        # Use custom quantization config if provided, otherwise use default
+        if quantization_config is None:
+            quantization_config = get_symmetric_quantization_config(
+                is_per_channel=per_channel_quantization
+            )
         if symmetric_io_quantization:
             quantizer.set_io(quantization_config)
         quant_stage = Quantize(quantizer, quantization_config)
@@ -916,7 +918,6 @@ def __init__(
         ] = None,
         tosa_extensions: Optional[List[str]] = None,
     ):
-
         if tosa_extensions is None:
             tosa_extensions = []
         tosa_spec = TosaSpecification.create_from_string(
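
Resolving the default inside __init__ rather than in the signature keeps existing TosaPipelineINT call sites unchanged: omitting quantization_config still selects the symmetric config, while callers such as test_linear_16a8w_tosa_INT can now inject get_16a8w_quantization_config. When symmetric_io_quantization is set, whichever config is active is also applied to the model IO via quantizer.set_io.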
