# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
@@ -1215,6 +1216,108 @@ def validate(self, model: torch.fx.GraphModule) -> None:
        self.assertIsNot(observers[0], observers[2])
        self.assertIsNot(observers[1], observers[2])

+    def test_allow_implicit_sharing_with_shared_input_edge(self):
+        """This tests implicit sharing when an input edge x is shared between
+        two ops in the following manner:
+
+          /-----------------> eq(a, x) -> b
+         /                   /
+        x -> clone(x) -> a -/
+
+        Clone is annotated such that (x, clone) uses a QuantizationSpec and
+        its output (clone) a SharedQuantizationSpec pointing to its input
+        (x, clone).
+
+        Eq is annotated such that (clone, eq) uses a QuantizationSpec and
+        (x, eq) uses a SharedQuantizationSpec pointing to the former.
+        The output (eq) is not quantized (bool output).
+
+        Verify that the input to clone and its output share the same observer;
+        the inputs to eq should also share that same observer due to implicit
+        sharing.
+
+        Context: This test used to trigger a cyclic recursion bug in the
+        following manner:
+        1) Processing edge (x, clone): implicit sharing sees that eq is
+           another user of x with an identical qspec, so (x, clone) starts
+           sharing with (x, eq) by pointing to it.
+        2) Processing edge (clone, eq): implicit sharing tries to share this
+           input edge with the producer output clone. But clone's output
+           uses SharedQuantizationSpec((x, clone)), and from step (1),
+           (x, clone) already points to (x, eq). Unwrapping therefore leads
+           to (x, eq), and (clone, eq) is set to share with (x, eq) by
+           pointing to it.
+        3) Processing edge (x, eq): when resolving its qspec, the algorithm
+           follows the shared reference to (clone, eq), which immediately
+           points back to (x, eq) from step (2). This creates a cycle,
+           causing the unwrap logic to recurse endlessly.
+        """
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                for node in model.graph.nodes:
+                    if node.target in [
+                        torch.ops.aten.clone.default,
+                        torch.ops.aten.eq.Tensor,
+                    ]:
+                        input_qspec_map = {}
+                        qspec = QuantizationSpec(
+                            dtype=torch.uint8,
+                            quant_min=0,
+                            quant_max=255,
+                            qscheme=torch.per_tensor_affine,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=observer.default_observer,
+                        )
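+                        # Shared spec that refers back to this node's own
+                        # first input edge, (node.args[0], node).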
+                        shared_qspec = SharedQuantizationSpec((node.args[0], node))
+
+                        if node.target is torch.ops.aten.clone.default:
+                            input_qspec_map[node.args[0]] = qspec
+                            output_qspec = shared_qspec
+                        elif node.target is torch.ops.aten.eq.Tensor:
+                            input_qspec_map[node.args[0]] = qspec
+                            input_qspec_map[node.args[1]] = shared_qspec
+                            # Output is bool, quantization not applicable
+                            output_qspec = None
+                        else:
+                            assert False
+
+                        node.meta["quantization_annotation"] = QuantizationAnnotation(
+                            input_qspec_map=input_qspec_map,
+                            output_qspec=output_qspec,
+                            allow_implicit_sharing=True,
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                a = x.clone()
+                b = torch.eq(a, x)
+                return b
+
+        m = M().eval()
+        example_inputs = (torch.randn(1, 5),)
+        m = torch.export.export(m, example_inputs, strict=True).module()
+        prepare_pt2e(m, BackendAQuantizer())
+        m(*example_inputs)
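+        # After the calibration run above, each annotated edge has an
+        # observer submodule; implicit sharing should have collapsed them
+        # all into a single shared instance.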
+        observers = []
+        for n in m.graph.nodes:
+            if n.target == torch.ops.aten.clone.default:
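+                # n.args[0] is the observer call_module node feeding clone;
+                # its .target is the attribute name of the observer module.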
+                input_obs1 = getattr(m, n.args[0].target)
+                output_obs = getattr(m, next(iter(n.users)).target)
+                self.assertIs(input_obs1, output_obs)
+                observers.append(input_obs1)
+            if n.target == torch.ops.aten.eq.Tensor:
+                input_obs1 = getattr(m, n.args[0].target)
+                input_obs2 = getattr(m, n.args[1].target)
+                self.assertIs(input_obs1, input_obs2)
+                observers.append(input_obs1)
+        self.assertEqual(len(observers), 2)
+        self.assertIs(observers[0], observers[1])
+
    @parametrize("dtype", (torch.float32, torch.bfloat16))
    @parametrize("quant_dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
    def test_quantization_dtype(self, dtype, quant_dtype):
|
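The cyclic scenario described in the test's docstring is easy to reproduce in
isolation. Below is a minimal sketch (hypothetical names and data structures,
not PyTorch's actual implementation) of shared-edge resolution by pointer
chasing: with the pointer state left by steps (1) and (2), adding step (3)'s
reference closes the loop, and only a visited-set check keeps the unwrap from
running forever.

# Minimal sketch of the failure mode; edges are (producer, consumer) pairs.
def resolve(edge, shared_with):
    """Follow shared-spec pointers to the root edge, failing on cycles."""
    seen = set()
    while edge in shared_with:
        if edge in seen:
            raise RuntimeError(f"cyclic sharing detected at {edge}")
        seen.add(edge)
        edge = shared_with[edge]
    return edge

# Pointer state after steps (1) and (2): both input edges point at (x, eq).
shared_with = {
    ("x", "clone"): ("x", "eq"),
    ("clone", "eq"): ("x", "eq"),
}
assert resolve(("x", "clone"), shared_with) == ("x", "eq")

# Step (3) makes (x, eq) follow its SharedQuantizationSpec to (clone, eq),
# which points straight back; resolution now detects the cycle instead of
# recursing endlessly.
shared_with[("x", "eq")] = ("clone", "eq")
try:
    resolve(("x", "eq"), shared_with)
except RuntimeError as e:
    print(e)  # cyclic sharing detected at ('x', 'eq')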