Skip to content

Commit 0b96757

Browse files
martinlsm (Martin Lindström)
authored
Prevent union-finding cycles for shared qspecs (#3011)
* Prevent union-finding cycles for shared qspecs

  Certain graphs caused infinite recursion when unioning nodes to groups
  based on shared quantization specs while implicit sharing was enabled.
  The problem occurred when `EdgeOrNode` A with its own `QuantizationSpec`
  received an edge (in `shared_with_map`) to `EdgeOrNode` B which in turn
  had a `SharedQuantizationSpec` pointing back to A.

  Remedy this problem by checking if B, from the scenario above, has a
  `SharedQuantizationSpec` pointing to A; if that is the case, don't union
  them together by letting A point back to B. Avoiding the union/cycle
  preserves correctness because the nodes are effectively already united.

* Add test case of implicit sharing with two ops sharing one input

  - Add test case of implicit sharing for a model where one input is
    shared between two different ops.
  - Add code comments to `_union_if_no_cycle`

* Simplify the model that reproduces the recursion bug

* Avoid forming a cycle by reversing the edge

* Add context of recursion bug in the test case

* Swap order of branches in if statement

---------

Co-authored-by: Martin Lindström <Martin.Lindstroem@arm.com>
1 parent a53a4db commit 0b96757

File tree

2 files changed

+127
-9
lines changed

2 files changed

+127
-9
lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD 3-Clause license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -1215,6 +1216,108 @@ def validate(self, model: torch.fx.GraphModule) -> None:
12151216
self.assertIsNot(observers[0], observers[2])
12161217
self.assertIsNot(observers[1], observers[2])
12171218

1219+
def test_allow_implicit_sharing_with_shared_input_edge(self):
1220+
"""This tests implicit sharing when an input edge x is shared between
1221+
two ops in the following manner:
1222+
1223+
/-----------------> eq(a, x) -> b
1224+
/ /
1225+
x -> clone(x) -> a -/
1226+
1227+
Clone is annotated such that (x, clone) uses a QuantizationSpec and
1228+
its output (clone) a SharedQuantizationSpec pointing to its input
1229+
(x, clone).
1230+
1231+
Eq is annotated such that (clone, eq) uses a QuantizationSpec and
1232+
(x, eq) uses a SharedQuantizationSpec to the former.
1233+
The output (eq) is not quantized (bool output).
1234+
1235+
Verify that the input to clone and its output share the same observer;
1236+
inputs to eq should also share that same observer due to implicit
1237+
sharing.
1238+
1239+
Context: This test used to trigger a cyclic recursion bug in the
1240+
following manner:
1241+
1) Processing edge (x, clone): implicit sharing sees that eq is
1242+
another user of x with an identical qspec, so (x, clone) starts
1243+
sharing with (x, eq) by pointing to it.
1244+
2) Processing edge (clone, eq): implicit sharing tries to share this
1245+
input edge with the producer output clone. But clone's output
1246+
uses SharedQuantizationSpec((x, clone)), and from step (1),
1247+
(x, clone) already points to (x, eq). Therefore unwrapping leads to
1248+
(x, eq) and (clone, eq) is set to share with (x, eq) by pointing to
1249+
it.
1250+
3) Processing edge (x, eq): when resolving its qspec, the algorithm
1251+
follows the shared reference to (clone, eq), which immediately
1252+
points back to (x, eq) from step (2). This created a cycle and the
1253+
unwrap logic recursed endlessly.
1254+
"""
1255+
1256+
class BackendAQuantizer(Quantizer):
1257+
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
1258+
for node in model.graph.nodes:
1259+
if node.target in [
1260+
torch.ops.aten.clone.default,
1261+
torch.ops.aten.eq.Tensor,
1262+
]:
1263+
input_qspec_map = {}
1264+
qspec = QuantizationSpec(
1265+
dtype=torch.uint8,
1266+
quant_min=0,
1267+
quant_max=255,
1268+
qscheme=torch.per_tensor_affine,
1269+
is_dynamic=False,
1270+
observer_or_fake_quant_ctr=observer.default_observer,
1271+
)
1272+
shared_qspec = SharedQuantizationSpec((node.args[0], node))
1273+
1274+
if node.target is torch.ops.aten.clone.default:
1275+
input_qspec_map[node.args[0]] = qspec
1276+
output_qspec = shared_qspec
1277+
elif node.target is torch.ops.aten.eq.Tensor:
1278+
input_qspec_map[node.args[0]] = qspec
1279+
input_qspec_map[node.args[1]] = shared_qspec
1280+
# Output is bool, quantization not applicable
1281+
output_qspec = None
1282+
else:
1283+
assert False
1284+
1285+
node.meta["quantization_annotation"] = QuantizationAnnotation(
1286+
input_qspec_map=input_qspec_map,
1287+
output_qspec=output_qspec,
1288+
allow_implicit_sharing=True,
1289+
_annotated=True,
1290+
)
1291+
1292+
def validate(self, model: torch.fx.GraphModule) -> None:
1293+
pass
1294+
1295+
class M(torch.nn.Module):
1296+
def forward(self, x):
1297+
a = x.clone()
1298+
b = torch.eq(a, x)
1299+
return b
1300+
1301+
m = M().eval()
1302+
example_inputs = (torch.randn(1, 5),)
1303+
m = torch.export.export(m, example_inputs, strict=True).module()
1304+
prepare_pt2e(m, BackendAQuantizer())
1305+
m(*example_inputs)
1306+
observers = []
1307+
for n in m.graph.nodes:
1308+
if n.target == torch.ops.aten.clone.default:
1309+
input_obs1 = getattr(m, n.args[0].target)
1310+
output_obs = getattr(m, next(iter(n.users)).target)
1311+
self.assertIs(input_obs1, output_obs)
1312+
observers.append(input_obs1)
1313+
if n.target == torch.ops.aten.eq.Tensor:
1314+
input_obs1 = getattr(m, n.args[0].target)
1315+
input_obs2 = getattr(m, n.args[1].target)
1316+
self.assertIs(input_obs1, input_obs2)
1317+
observers.append(input_obs1)
1318+
assert len(observers) == 2
1319+
self.assertIs(observers[0], observers[1])
1320+
12181321
@parametrize("dtype", (torch.float32, torch.bfloat16))
12191322
@parametrize("quant_dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
12201323
def test_quantization_dtype(self, dtype, quant_dtype):

torchao/quantization/pt2e/prepare.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD 3-Clause license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -146,18 +147,31 @@ def _union(
146147
parent: EdgeOrNode,
147148
child: EdgeOrNode,
148149
shared_with_map: dict[EdgeOrNode, EdgeOrNode],
150+
edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
149151
) -> None:
150152
"""Merge the subtree for `child` with `parent`, the order is important here"""
151153
root_parent = _find_root_edge_or_node(parent, shared_with_map)
152154
root_child = _find_root_edge_or_node(child, shared_with_map)
153-
# union the two trees by pointing the root of child to root of parent
154-
shared_with_map[root_child] = root_parent
155+
156+
parent_qspec = edge_or_node_to_qspec[root_parent]
157+
if (
158+
isinstance(parent_qspec, SharedQuantizationSpec)
159+
and parent_qspec.edge_or_node == root_child
160+
):
161+
# Parent already references child with a shared qspec. We would create
162+
# a cycle if we formed an edge from the child to the parent. Therefore,
163+
# we reverse the edge in this particular case.
164+
shared_with_map[root_parent] = root_child
165+
else:
166+
# union the two trees by pointing the root of child to root of parent
167+
shared_with_map[root_child] = root_parent
155168

156169

157170
def _update_shared_with(
158171
child: EdgeOrNode,
159172
qspec: QuantizationSpecBase,
160173
shared_with_map: dict[EdgeOrNode, EdgeOrNode],
174+
edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
161175
):
162176
"""Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec`
163177
configuration and establishes the relationship between `edge_or_node` and the edge/node that it
@@ -167,7 +181,7 @@ def _update_shared_with(
167181
parent = qspec.edge_or_node
168182
# we point from edge_or_node to the node that it is sharing_with, e.g.
169183
# qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
170-
_union(parent, child, shared_with_map)
184+
_union(parent, child, shared_with_map, edge_or_node_to_qspec)
171185

172186

173187
def _unwrap_shared_qspec(
@@ -249,7 +263,7 @@ def _union_input_edge_with(
249263
# since dtype is the same (we may want to extend this to be a more strict check
250264
# in the future)
251265
# so we point from `input_edge` to `arg` (output of the argument)
252-
_union(edge_or_node, input_edge, shared_with_map)
266+
_union(edge_or_node, input_edge, shared_with_map, edge_or_node_to_qspec)
253267

254268

255269
def _get_edge_or_node_to_group_id(
@@ -311,7 +325,9 @@ def _get_edge_or_node_to_group_id(
311325
for edge_or_node, qspec in edge_or_node_to_qspec.items():
312326
if isinstance(edge_or_node, torch.fx.Node):
313327
output_node = edge_or_node
314-
_update_shared_with(output_node, qspec, shared_with_map)
328+
_update_shared_with(
329+
output_node, qspec, shared_with_map, edge_or_node_to_qspec
330+
)
315331
else:
316332
input_edge = edge_or_node
317333
input_edge_root_qspec = _unwrap_shared_qspec(
@@ -332,9 +348,6 @@ def _get_edge_or_node_to_group_id(
332348
# because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3)
333349
# Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll
334350
# have a circular dependency
335-
# the following order works around this issue, but this does not allow arbitrary configuration
336-
# of sharing so it might break in a different case in the future, when it breaks
337-
# quantizer writer can check the notes here to debug the issue
338351

339352
# sharing with other users of the producer node
340353
# (arg, user)
@@ -363,7 +376,9 @@ def _get_edge_or_node_to_group_id(
363376
shared_with_map,
364377
)
365378

366-
_update_shared_with(input_edge, qspec, shared_with_map)
379+
_update_shared_with(
380+
input_edge, qspec, shared_with_map, edge_or_node_to_qspec
381+
)
367382

368383
# now that we get the sharing relations between all edges and nodes, we can assign group ids
369384
cur_group_id = 0

0 commit comments

Comments
 (0)