
Commit c2dab19

RahulC7 authored and facebook-github-bot committed
Changing logic to deal with graphs with derived quantization spec (#16357)
Summary:
We want to add a test for `default_addmm_A8W8` to finish testing `CadenceDefaultQuantizer`. Doing so requires two changes to the testing function.

## Change 1: Allow `None` entries in the list of expected `QuantizationSpec`s

The addmm op has three inputs: `bias`, `mat1`, and `mat2`. The bias uses a `DerivedQuantizationSpec`, which is constructed dynamically with references to the actual graph nodes (`mat1` and `mat2`). We can't construct an identical `DerivedQuantizationSpec` in the test, because we'd need to reference the exact same node objects that the quantizer creates internally. Since we can't compare it directly, we use `None` to skip validation for that input. If `mat1` and `mat2` are quantized correctly, the derived bias spec will be correct too. (A sketch of how such a spec is built appears after this summary.)

https://www.internalfb.com/code/fbsource/[2cfdb40fd8b628da2f46366115516408cfb9f50f]/xplat/executorch/backends/cadence/aot/quantizer/patterns.py?lines=91-103

## Change 2: Iterate `input_qspec_map` by argument index

`input_qspec_map` is a dictionary mapping input nodes to their qspecs. Its iteration order is insertion order, which follows how the quantizer processes `PartitionAnchors`. Each `QuantizationPattern` implements a `get_anchors()` method that returns a `PartitionAnchors` describing which arguments are inputs, weights, and biases, and which node is the output. This matters because for addmm the `PartitionAnchors` lists them as `inputs=[(node, 1)]`, `weights=[(node, 2)]`, `biases=[(node, 0, ...)]`, so the map may iterate in the order `mat1, mat2, bias` (arg indices 1, 2, 0) rather than `bias, mat1, mat2` (arg indices 0, 1, 2). The previous iteration assumed the enumeration index matched the argument position, so it breaks here. We now iterate like this:

```
for input_node, input_qspec in annotation.input_qspec_map.items():
    # Find the index of this input node in the op's args
    arg_index = None
    for i, arg in enumerate(op_node.args):
        if arg is input_node:
            arg_index = i
            break
    self.assertIsNotNone(
        arg_index,
        f"Input node {input_node} not found in op_node.args",
    )
    # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
    if expected_input_qspecs[arg_index] is not None:
        self.assertEqual(
            input_qspec,
            expected_input_qspecs[arg_index],
            f"Input qspec mismatch at arg index {arg_index}",
        )
```

The new code looks up the argument index that each `input_node` corresponds to by searching through `op_node.args`, rather than assuming the enumeration index `i` matches the argument position. (A toy example of the ordering mismatch also follows this summary.)

Reviewed By: hsharma35

Differential Revision: D88955761
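For context on Change 1, here is a minimal sketch of how a bias `DerivedQuantizationSpec` is typically constructed. This is not the exact code in `patterns.py`: the helper names (`make_bias_qspec`, `_derive_bias_qparams`), the int32 quant range, and the qscheme are illustrative assumptions. The point it demonstrates is that `derived_from` holds references to the live `mat1`/`mat2` node objects inside the captured graph, which a test cannot recreate independently:

```python
# Illustrative sketch only; names and qparams are assumptions, not the
# exact ExecuTorch Cadence implementation.
from typing import List, Tuple

import torch
from torch.ao.quantization.quantizer import DerivedQuantizationSpec


def _derive_bias_qparams(obs_or_fqs: List) -> Tuple[torch.Tensor, torch.Tensor]:
    # The bias scale is derived from the activation and weight observers:
    # bias_scale = act_scale * weight_scale, with a zero zero-point.
    act_scale, _ = obs_or_fqs[0].calculate_qparams()
    weight_scale, _ = obs_or_fqs[1].calculate_qparams()
    bias_scale = act_scale * weight_scale
    return bias_scale, torch.zeros_like(bias_scale, dtype=torch.int64)


def make_bias_qspec(addmm_node: torch.fx.Node) -> DerivedQuantizationSpec:
    # addmm args are (bias, mat1, mat2); the spec is derived from the
    # (input_node, user_node) edges of mat1 and mat2. These are the live
    # node objects in the quantizer's graph, so an equal spec cannot be
    # rebuilt outside the quantizer.
    return DerivedQuantizationSpec(
        derived_from=[
            (addmm_node.args[1], addmm_node),  # mat1 edge
            (addmm_node.args[2], addmm_node),  # mat2 edge
        ],
        derive_qparams_fn=_derive_bias_qparams,
        dtype=torch.int32,
        quant_min=-(2**31),
        quant_max=2**31 - 1,
        qscheme=torch.per_tensor_affine,
    )
```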
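And for Change 2, a tiny self-contained example (plain strings standing in for graph nodes and qspecs) of why the enumeration index over `input_qspec_map` cannot be used as the argument index:

```python
# Anchors are processed as inputs, then weights, then biases, so the
# map's insertion order is (mat1, mat2, bias) while addmm's positional
# args are (bias, mat1, mat2).
op_args = ("bias", "mat1", "mat2")
input_qspec_map = {
    "mat1": "act_qspec",     # inserted first,  arg index 1
    "mat2": "weight_qspec",  # inserted second, arg index 2
    "bias": "bias_qspec",    # inserted last,   arg index 0
}

for i, (node, qspec) in enumerate(input_qspec_map.items()):
    arg_index = op_args.index(node)  # what the fixed test computes
    print(f"enumerate index {i} vs arg index {arg_index} for {node}")
# enumerate index 0 vs arg index 1 for mat1
# enumerate index 1 vs arg index 2 for mat2
# enumerate index 2 vs arg index 0 for bias
```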
1 parent 2c85e6b · commit c2dab19

File tree

1 file changed: +55 −16 lines

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 55 additions & 16 deletions

```diff
@@ -64,14 +64,15 @@
 # Test case definitions for quantizer annotation tests.
 # Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
 # Adding a new quantizer test only requires adding a tuple to this list.
+# Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
 QUANTIZER_ANNOTATION_TEST_CASES: list[
     tuple[
         str,
         GraphBuilderFn,
         CadenceQuantizer,
         OpOverload,
         QuantizationSpec,
-        list[QuantizationSpec],
+        list[QuantizationSpec | None],
     ]
 ] = [
     (
@@ -192,6 +193,16 @@
         # For relu: only input_activation
         [qconfig_A8W8.input_activation],
     ),
+    (
+        "default_addmm_A8W8",
+        lambda self: self._build_addmm_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.addmm.default,
+        qconfig_A8W8.output_activation,
+        # For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
+        # Use None to skip comparison for bias since it's a DerivedQuantizationSpec
+        [None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
+    ),
 ]
 
 # Derive the set of tested quantizer classes from the test cases.
@@ -408,6 +419,31 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
         return gm, relu_nodes[0]
 
+    def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with an addmm operation."""
+        builder = GraphBuilder()
+        # addmm: bias + (mat1 @ mat2)
+        # args: (bias, mat1, mat2)
+        bias = builder.placeholder("bias", torch.randn(5))
+        mat1 = builder.placeholder("mat1", torch.randn(1, 10))
+        mat2 = builder.placeholder("mat2", torch.randn(10, 5))
+        addmm = builder.call_operator(
+            op=torch.ops.aten.addmm.default,
+            args=(bias, mat1, mat2),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("addmm", torch.ops.aten.addmm.default)]}
+            ),
+        )
+        builder.output([addmm])
+        gm = builder.get_graph_module()
+
+        addmm_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.addmm.default,
+        )
+        self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
+        return gm, addmm_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
@@ -416,7 +452,7 @@ def test_quantizer_annotation(
         quantizer: CadenceQuantizer,
         target: OpOverload,
         expected_output_qspec: QuantizationSpec,
-        expected_input_qspecs: list[QuantizationSpec],
+        expected_input_qspecs: list[QuantizationSpec | None],
     ) -> None:
         """Parameterized test for quantizer annotations."""
         gm, op_node = graph_builder_fn(self)
@@ -431,21 +467,24 @@ def test_quantizer_annotation(
 
         # Verify input annotations
         self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-        for i, (input_node, input_qspec) in enumerate(
-            annotation.input_qspec_map.items()
-        ):
-            expected_arg = op_node.args[i]
-            assert isinstance(expected_arg, torch.fx.Node)
-            self.assertEqual(
-                input_node,
-                expected_arg,
-                f"Input node mismatch at index {i}",
-            )
-            self.assertEqual(
-                input_qspec,
-                expected_input_qspecs[i],
-                f"Input qspec mismatch at index {i}",
+        for input_node, input_qspec in annotation.input_qspec_map.items():
+            # Find the index of this input node in the op's args
+            arg_index = None
+            for i, arg in enumerate(op_node.args):
+                if arg is input_node:
+                    arg_index = i
+                    break
+            self.assertIsNotNone(
+                arg_index,
+                f"Input node {input_node} not found in op_node.args",
             )
+            # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
+            if expected_input_qspecs[arg_index] is not None:
+                self.assertEqual(
+                    input_qspec,
+                    expected_input_qspecs[arg_index],
+                    f"Input qspec mismatch at arg index {arg_index}",
+                )
 
     def test_all_quantizers_have_annotation_tests(self) -> None:
         """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
```
