
Commit e4c01f2

Nitin Jain authored and facebook-github-bot committed
Add 16A8W quantization configuration utility for ARM backend (#13175)
Summary:
Pull Request resolved: #13175

This diff implements a 16A8W (16-bit activations, 8-bit weights) quantization configuration utility for the ExecutorTorch ARM backend, following the feedback from D79746479.

## Key Changes

**1. New Quantization Configuration Function**

- Add `get_16a8w_quantization_config()` in `fbcode/executorch/backends/arm/quantizer/arm_quantizer.py`
- Provides 16-bit activations with HistogramObserver (better precision than 8A8W)
- Keeps weights at 8 bits with MinMaxObserver/PerChannelMinMaxObserver (memory efficient)
- **Technically supported by TOSA through the [EXT-INT16 extension/profile](https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d)**

## Benefits

- **Better Precision**: 16-bit activations provide higher precision than 8-bit activations, which is useful for carrying precision through recurrent neural networks.

Differential Revision: D79763381
1 parent ecb639a commit e4c01f2
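
Before the diff itself, here is a minimal usage sketch of the new helper. It is a hedged illustration and not part of this commit: it assumes the import path `executorch.backends.arm.quantizer.arm_quantizer` and `executorch.backends.arm.tosa_specification`, that `TosaSpecification.create_from_string("TOSA-1.0+INT")` is a valid way to build a TOSA spec, and that `TOSAQuantizer` exposes an XNNPACKQuantizer-style `set_global()` method.

```python
# Hypothetical usage sketch -- not part of this diff.
from executorch.backends.arm.quantizer.arm_quantizer import (
    TOSAQuantizer,
    get_16a8w_quantization_config,
)
from executorch.backends.arm.tosa_specification import TosaSpecification  # assumed module path

# Build a quantizer for an integer TOSA profile (exact spec string is an assumption).
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
quantizer = TOSAQuantizer(tosa_spec)

# Apply 16-bit activations / 8-bit per-channel weights globally
# (set_global is assumed to behave like the XNNPACKQuantizer API).
quantizer.set_global(get_16a8w_quantization_config(is_per_channel=True))
```

From there the usual `prepare_pt2e` / calibrate / `convert_pt2e` flow would presumably apply, just as with the existing `get_symmetric_quantization_config()`.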

File tree

1 file changed: +105 -2 lines changed

backends/arm/quantizer/arm_quantizer.py

Lines changed: 105 additions & 2 deletions
@@ -143,6 +143,111 @@ def get_symmetric_quantization_config(
     return quantization_config


+@functools.lru_cache
+def get_16a8w_quantization_config(
+    is_per_channel: bool = True,
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+    weight_qmin: int = -127,
+    weight_qmax: int = 127,
+):
+    """
+    16A8W quantization config: 16-bit activations, 8-bit weights.
+
+    This configuration provides better accuracy than 8A8W while maintaining
+    reasonable memory usage through 8-bit weights.
+
+    Args:
+        is_per_channel: Whether to use per-channel quantization for weights
+        is_qat: Whether this is for Quantization Aware Training
+        is_dynamic: Whether to use dynamic quantization
+        weight_qmin: Minimum quantization value for weights
+        weight_qmax: Maximum quantization value for weights
+
+    Returns:
+        QuantizationConfig with 16-bit activations and 8-bit weights
+    """
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
+    # Setup observer/fake-quant for 16-bit activations
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            # HistogramObserver works well for 16-bit range
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    # 16-bit activation quantization spec
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int16,
+        quant_min=torch.iinfo(torch.int16).min,  # -32768
+        quant_max=torch.iinfo(torch.int16).max,  # 32767
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        ),
+    )
+
+    # Setup quantization config for weights (same as 8A8W - use 8-bit weights)
+    weight_qscheme = (
+        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
+    )
+    weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
+        MinMaxObserver
+    )
+    # Determine the right observer/fake-quant constructor
+    if is_qat:
+        # Set plain fake-quant with true min/max
+        weight_observer_or_fake_quant_ctr = FakeQuantize
+    else:
+        # PTQ: set min/max observer
+        weight_observer_or_fake_quant_ctr = (
+            PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
+        )
+
+    weight_extra_args = {"eps": 2**-12}
+
+    # 8-bit weight quantization spec (keep weights at 8-bit for memory efficiency)
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=weight_qmin,
+        quant_max=weight_qmax,
+        qscheme=weight_qscheme,
+        ch_axis=0,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **weight_extra_args
+        ),
+    )
+
+    bias_quantization_spec = None
+    if is_dynamic:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            None,
+            weight_quantization_spec,  # 8-bit weights
+            bias_quantization_spec,
+        )
+    else:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            act_quantization_spec,  # 16-bit output activations
+            weight_quantization_spec,  # 8-bit weights
+            bias_quantization_spec,
+        )
+    return quantization_config
+
+
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
@@ -216,11 +321,9 @@ def not_module_type_or_name_filter(n: Node) -> bool:


 class TOSAQuantizer(Quantizer):
-
     def __init__(
         self, compile_spec_or_tosa_spec: Union[TosaSpecification, List[CompileSpec]]
     ) -> None:
-
         super().__init__()
         if isinstance(compile_spec_or_tosa_spec, TosaSpecification):
             self.tosa_spec = compile_spec_or_tosa_spec
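
To make the defaults above concrete, the short sketch below calls the new function with no arguments (`is_per_channel=True`, `is_qat=False`, `is_dynamic=False`) and checks the specs it produces. The attribute names `input_activation`, `output_activation`, `weight`, and `bias` are inferred from the positional `QuantizationConfig(...)` calls and their inline comments in the diff, so treat them as an assumption rather than documented API.

```python
# Hedged sketch: inspect the config produced by the new helper with its defaults.
import torch

from executorch.backends.arm.quantizer.arm_quantizer import get_16a8w_quantization_config

config = get_16a8w_quantization_config()

# Activations: int16, per-tensor affine, full int16 range.
assert config.input_activation.dtype == torch.int16         # assumed field name
assert config.input_activation.quant_min == -32768
assert config.input_activation.quant_max == 32767
assert config.output_activation is config.input_activation  # non-dynamic branch above

# Weights: int8, symmetric per-channel, range [-127, 127].
assert config.weight.dtype == torch.int8                     # assumed field name
assert config.weight.quant_min == -127
assert config.weight.quant_max == 127
assert config.weight.qscheme == torch.per_channel_symmetric

# No bias spec is attached in this configuration.
assert config.bias is None                                   # assumed field name
```

Note that the weight range defaults to [-127, 127] rather than the full [-128, 127] int8 range, a common choice for symmetric weight quantization that keeps the range centered on zero.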
