Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions test/quantization/pt2e/test_quantize_pt2e.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -1215,6 +1216,108 @@ def validate(self, model: torch.fx.GraphModule) -> None:
self.assertIsNot(observers[0], observers[2])
self.assertIsNot(observers[1], observers[2])

def test_allow_implicit_sharing_with_shared_input_edge(self):
    """This tests implicit sharing when an input edge x is shared between
    two ops in the following manner:

        /-----------------> eq(a, x) -> b
       /                 /
    x -> clone(x) -> a -/

    Clone is annotated such that (x, clone) uses a QuantizationSpec and
    its output (clone) a SharedQuantizationSpec pointing to its input
    (x, clone).

    Eq is annotated such that (clone, eq) uses a QuantizationSpec and
    (x, eq) uses a SharedQuantizationSpec to the former.
    The output (eq) is not quantized (bool output).

    Verify that the input to clone and its output share the same observer;
    inputs to eq should also share that same observer due to implicit
    sharing.

    Context: This test used to trigger a cyclic recursion bug in the
    following manner:
    1) Processing edge (x, clone): implicit sharing sees that eq is
       another user of x with an identical qspec, so (x, clone) starts
       sharing with (x, eq) by pointing to it.
    2) Processing edge (clone, eq): implicit sharing tries to share this
       input edge with the producer output clone. But clone's output
       uses SharedQuantizationSpec((x, clone)), and from step (1),
       (x, clone) already points to (x, eq). Therefore unwrapping leads to
       (x, eq) and (clone, eq) is set to share with (x, eq) by pointing to
       it.
    3) Processing edge (x, eq): when resolving its qspec, the algorithm
       follows the shared reference to (clone, eq), which immediately
       points back to (x, eq) from step (2). This created a cycle and the
       unwrap logic recursed endlessly.
    """

    class BackendAQuantizer(Quantizer):
        def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
            for node in model.graph.nodes:
                if node.target in [
                    torch.ops.aten.clone.default,
                    torch.ops.aten.eq.Tensor,
                ]:
                    input_qspec_map = {}
                    qspec = QuantizationSpec(
                        dtype=torch.uint8,
                        quant_min=0,
                        quant_max=255,
                        qscheme=torch.per_tensor_affine,
                        is_dynamic=False,
                        observer_or_fake_quant_ctr=observer.default_observer,
                    )
                    # Shared qspec pointing to this node's first input edge.
                    shared_qspec = SharedQuantizationSpec((node.args[0], node))

                    if node.target is torch.ops.aten.clone.default:
                        # (x, clone) gets a concrete qspec; clone's output
                        # shares with its own input edge (x, clone).
                        input_qspec_map[node.args[0]] = qspec
                        output_qspec = shared_qspec
                    elif node.target is torch.ops.aten.eq.Tensor:
                        # (clone, eq) gets a concrete qspec; (x, eq) shares
                        # with (clone, eq).
                        input_qspec_map[node.args[0]] = qspec
                        input_qspec_map[node.args[1]] = shared_qspec
                        # Output is bool, quantization not applicable
                        output_qspec = None
                    else:
                        # Unreachable: guarded by the membership test above.
                        # Raise instead of `assert` so the guard survives
                        # running under `python -O`.
                        raise AssertionError(f"unexpected target {node.target}")

                    node.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=output_qspec,
                        allow_implicit_sharing=True,
                        _annotated=True,
                    )
            # The signature promises a GraphModule: return the annotated model.
            return model

        def validate(self, model: torch.fx.GraphModule) -> None:
            pass

    class M(torch.nn.Module):
        def forward(self, x):
            a = x.clone()
            b = torch.eq(a, x)
            return b

    m = M().eval()
    example_inputs = (torch.randn(1, 5),)
    m = torch.export.export(m, example_inputs, strict=True).module()
    prepare_pt2e(m, BackendAQuantizer())
    m(*example_inputs)
    observers = []
    for n in m.graph.nodes:
        if n.target == torch.ops.aten.clone.default:
            # Input observer of clone and the observer on its output must
            # be the very same module instance.
            input_obs1 = getattr(m, n.args[0].target)
            output_obs = getattr(m, next(iter(n.users)).target)
            self.assertIs(input_obs1, output_obs)
            observers.append(input_obs1)
        elif n.target == torch.ops.aten.eq.Tensor:
            # Both inputs of eq must share one observer.
            input_obs1 = getattr(m, n.args[0].target)
            input_obs2 = getattr(m, n.args[1].target)
            self.assertIs(input_obs1, input_obs2)
            observers.append(input_obs1)
    # Exactly one clone node and one eq node should have been visited, and
    # implicit sharing must collapse everything onto a single observer.
    # Use a unittest assertion (not a bare `assert`, which `-O` strips).
    self.assertEqual(len(observers), 2)
    self.assertIs(observers[0], observers[1])

@parametrize("dtype", (torch.float32, torch.bfloat16))
@parametrize("quant_dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
def test_quantization_dtype(self, dtype, quant_dtype):
Expand Down
33 changes: 24 additions & 9 deletions torchao/quantization/pt2e/prepare.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -146,18 +147,31 @@ def _union(
parent: EdgeOrNode,
child: EdgeOrNode,
shared_with_map: dict[EdgeOrNode, EdgeOrNode],
edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
) -> None:
"""Merge the subtree for `child` with `parent`, the order is important here"""
root_parent = _find_root_edge_or_node(parent, shared_with_map)
root_child = _find_root_edge_or_node(child, shared_with_map)
# union the two trees by pointing the root of child to root of parent
shared_with_map[root_child] = root_parent

parent_qspec = edge_or_node_to_qspec[root_parent]
if (
isinstance(parent_qspec, SharedQuantizationSpec)
and parent_qspec.edge_or_node == root_child
):
# Parent already references child with a shared qspec. We would create
# a cycle if we formed an edge from the child to the parent. Therefore,
# we reverse the edge in this particular case.
Comment on lines +161 to +163
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this be more confusing than just assign an ordering before hand?

Copy link
Contributor

@jerryzh168 jerryzh168 Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also I haven't thought through, but wondering if it's possible that root_child can go around and end up pointing to root_parent again

Copy link
Contributor

@jerryzh168 jerryzh168 Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @kimishpatel is this the cycle detection you have in mind?

seems OK to me, if this is the only thing that's needed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure. Isn't the cycle already formed (parent_qspec.edge_or_node == root_child) before we come here? It feels like we are detecting that and correcting it. I might be wrong though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kimishpatel Prior to forming the problematic union, we are in the state that is shown in the figure below, about to assign shared_with_map[(clone,eq)] = (x,eq). So there is not already a union that we are correcting. We are just reversing the green edge by assigning shared_with_map[(x,eq)] = (clone,eq) to make the edge point in the same direction as the blue one (edge_or_node_to_qspec).

prepare_state (2)

shared_with_map[root_parent] = root_child
else:
# union the two trees by pointing the root of child to root of parent
shared_with_map[root_child] = root_parent


def _update_shared_with(
child: EdgeOrNode,
qspec: QuantizationSpecBase,
shared_with_map: dict[EdgeOrNode, EdgeOrNode],
edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
):
"""Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec`
configuration and establishes the relationship between `edge_or_node` and the edge/node that it
Expand All @@ -167,7 +181,7 @@ def _update_shared_with(
parent = qspec.edge_or_node
# we point from edge_or_node to the node that it is sharing_with, e.g.
# qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
_union(parent, child, shared_with_map)
_union(parent, child, shared_with_map, edge_or_node_to_qspec)


def _unwrap_shared_qspec(
Expand Down Expand Up @@ -249,7 +263,7 @@ def _union_input_edge_with(
# since dtype is the same (we may want to extend this to be a more strict check
# in the future)
# so we point from `input_edge` to `arg` (output of the argument)
_union(edge_or_node, input_edge, shared_with_map)
_union(edge_or_node, input_edge, shared_with_map, edge_or_node_to_qspec)


def _get_edge_or_node_to_group_id(
Expand Down Expand Up @@ -311,7 +325,9 @@ def _get_edge_or_node_to_group_id(
for edge_or_node, qspec in edge_or_node_to_qspec.items():
if isinstance(edge_or_node, torch.fx.Node):
output_node = edge_or_node
_update_shared_with(output_node, qspec, shared_with_map)
_update_shared_with(
output_node, qspec, shared_with_map, edge_or_node_to_qspec
)
else:
input_edge = edge_or_node
input_edge_root_qspec = _unwrap_shared_qspec(
Expand All @@ -332,9 +348,6 @@ def _get_edge_or_node_to_group_id(
# because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3)
# Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll
# have a circular dependency
# the following order works around this issue, but this does not allow arbitrary configuration
# of sharing so it might break in a different case in the future, when it breaks
# quantizer writer can check the notes here to debug the issue

# sharing with other users of the producer node
# (arg, user)
Expand Down Expand Up @@ -363,7 +376,9 @@ def _get_edge_or_node_to_group_id(
shared_with_map,
)

_update_shared_with(input_edge, qspec, shared_with_map)
_update_shared_with(
input_edge, qspec, shared_with_map, edge_or_node_to_qspec
)

# now that we get the sharing relations between all edges and nodes, we can assign group ids
cur_group_id = 0
Expand Down
Loading