Description
🐛 Describe the bug
I have noticed that the Arm backend is unable to quantize the following model, because quantization gets stuck in infinite recursion:
import torch

class RecursionBugModel(torch.nn.Module):
    def forward(self, x, y):
        a = torch.minimum(x, y)  # op 1
        b = a == x  # op 2
        return b

test_data = {
    "one_tensors": lambda: (torch.randn(1,), torch.randn(1,)),
}
I'm trying to debug the issue but need help understanding what is going on. From what I can see, the bug requires two ops that both use SharedQuantizationSpec. Op 1 (torch.minimum in the example above) has a quantization spec for args[0], and both args[1] and the output share that spec. Op 2 (==) similarly has a quantization spec for args[0], and args[1] shares it (the output spec does not matter for this op). Note also that x is the first argument of op 1 and the second argument of op 2.
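To make the annotation pattern concrete, here is a minimal pure-Python sketch. It does not use torchao: `QSpec`, `Shared`, and `Annotation` are hypothetical stand-ins for `QuantizationSpec`, `SharedQuantizationSpec`, and a node's quantization annotation, and graph nodes are plain strings. Assuming an input edge is keyed as `(arg, user)` and an output as the node itself, flattening the two annotations reproduces the `edge_or_node_to_qspec` dict printed further down.

```python
from dataclasses import dataclass, field
from typing import Optional, Union

Edge = tuple[str, str]  # (input node, user node)
EdgeOrNode = Union[Edge, str]


@dataclass(frozen=True)
class QSpec:
    """Hypothetical stand-in for torchao's QuantizationSpec."""
    dtype: str = "int8"


@dataclass(frozen=True)
class Shared:
    """Hypothetical stand-in for SharedQuantizationSpec."""
    edge_or_node: EdgeOrNode


@dataclass
class Annotation:
    """Hypothetical stand-in for a node's quantization annotation."""
    input_qspec_map: dict = field(default_factory=dict)
    output_qspec: Optional[object] = None


int8 = QSpec()

# op 1: a = torch.minimum(x, y); y and the output share the spec of edge (x, minimum)
# op 2: b = a == x; x (args[1]) shares the spec of edge (minimum, eq)
annotations = {
    "minimum": Annotation(
        input_qspec_map={"x": int8, "y": Shared(("x", "minimum"))},
        output_qspec=Shared(("x", "minimum")),
    ),
    "eq": Annotation(
        input_qspec_map={"minimum": int8, "x": Shared(("minimum", "eq"))},
    ),
}


def flatten(annotations: dict) -> dict:
    """Key each input spec by its (arg, user) edge and each output spec by its node."""
    edge_or_node_to_qspec = {}
    for node, ann in annotations.items():
        for arg, spec in ann.input_qspec_map.items():
            edge_or_node_to_qspec[(arg, node)] = spec
        if ann.output_qspec is not None:
            edge_or_node_to_qspec[node] = ann.output_qspec
    return edge_or_node_to_qspec


for key, spec in flatten(annotations).items():
    print(key, ":", spec)
```

The key point is that x feeds both ops, and on the `(x, eq)` edge it is the sharing side rather than the owning side.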
When quantizing this model, _unwrap_shared_qspec recurses indefinitely until Python raises a RecursionError, with the following stack trace:
src/executorch/backends/arm/test/tester/test_pipeline.py:308: in run
super().run()
src/executorch/backends/arm/test/tester/test_pipeline.py:281: in run
raise e
src/executorch/backends/arm/test/tester/test_pipeline.py:278: in run
stage()
src/executorch/backends/arm/test/tester/test_pipeline.py:89: in __call__
self.func(*self.args, **self.kwargs)
src/executorch/backends/arm/test/tester/arm_tester.py:344: in quantize
return super().quantize(quantize_stage)
src/executorch/backends/test/harness/tester.py:195: in quantize
return self._run_stage(
src/executorch/backends/test/harness/tester.py:189: in _run_stage
stage_instance.run(prev_stage_artifact, inputs=inputs, *args, **kwargs) # noqa
src/executorch/backends/test/harness/stages/quantize.py:57: in run
prepared = prepare_pt2e(captured_graph, self.quantizer)
../../../.pyenv/versions/3.10.4/lib/python3.10/site-packages/torchao/quantization/pt2e/quantize_pt2e.py:119: in prepare_pt2e
model = prepare(
../../../.pyenv/versions/3.10.4/lib/python3.10/site-packages/torchao/quantization/pt2e/prepare.py:651: in prepare
edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
../../../.pyenv/versions/3.10.4/lib/python3.10/site-packages/torchao/quantization/pt2e/prepare.py:332: in _get_edge_or_node_to_group_id
input_edge_root_qspec = _unwrap_shared_qspec(
../../../.pyenv/versions/3.10.4/lib/python3.10/site-packages/torchao/quantization/pt2e/prepare.py:194: in _unwrap_shared_qspec
return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
../../../.pyenv/versions/3.10.4/lib/python3.10/site-packages/torchao/quantization/pt2e/prepare.py:194: in _unwrap_shared_qspec
return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
E RecursionError: maximum recursion depth exceeded while calling a Python object
The recursion never terminates because of a cycle: edge_or_node_to_qspec[(x, eq)] is SharedQuantizationSpec(edge_or_node=(minimum, eq)), while shared_with_map[(minimum, eq)] is (x, eq). Since _unwrap_shared_qspec uses these two dicts for its lookups, it bounces between them forever: the root of (minimum, eq) is (x, eq), and the qspec of (x, eq) points back at (minimum, eq). Here is a full print of the dicts:
shared_with_map={(x, minimum): (x, eq), (y, minimum): (x, eq), minimum: (x, eq), (minimum, eq): (x, eq), (x, eq): (x, eq)}
edge_or_node_to_qspec:
(x, minimum) : QuantizationSpec(dtype=torch.int8, observer_or_fake_quant_ctr=functools.partial(<class 'torchao.quantization.pt2e.observer.HistogramObserver'>, eps=0.000244140625){}, quant_min=-128, quant_max=127, qscheme=torch.per_tensor_affine, ch_axis=None, is_dynamic=False)
(y, minimum) : SharedQuantizationSpec(edge_or_node=(x, minimum))
minimum : SharedQuantizationSpec(edge_or_node=(x, minimum))
(minimum, eq) : QuantizationSpec(dtype=torch.int8, observer_or_fake_quant_ctr=functools.partial(<class 'torchao.quantization.pt2e.observer.HistogramObserver'>, eps=0.000244140625){}, quant_min=-128, quant_max=127, qscheme=torch.per_tensor_affine, ch_axis=None, is_dynamic=False)
(x, eq) : SharedQuantizationSpec(edge_or_node=(minimum, eq))
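The cycle can be reproduced without torchao. The sketch below is a simplified model of the lookup, written from the description above and the frames in the stack trace rather than copied from torchao: `find_root` follows `shared_with_map` until it reaches an entry that maps to itself, and `unwrap` resolves a shared spec to its root's spec and recurses. Nodes are plain strings and `Shared` is a hypothetical stand-in for `SharedQuantizationSpec`. With the two dicts printed above, `unwrap` fails with the same RecursionError; `unwrap_checked` is a hypothetical variant that tracks visited roots so the cycle is reported instead.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shared:
    """Hypothetical stand-in for SharedQuantizationSpec."""
    edge_or_node: object


INT8 = "int8 QuantizationSpec"  # placeholder for the two concrete specs

# The two dicts from the debug print above, with nodes as strings.
shared_with_map = {
    ("x", "minimum"): ("x", "eq"),
    ("y", "minimum"): ("x", "eq"),
    "minimum": ("x", "eq"),
    ("minimum", "eq"): ("x", "eq"),
    ("x", "eq"): ("x", "eq"),
}
edge_or_node_to_qspec = {
    ("x", "minimum"): INT8,
    ("y", "minimum"): Shared(("x", "minimum")),
    "minimum": Shared(("x", "minimum")),
    ("minimum", "eq"): INT8,
    ("x", "eq"): Shared(("minimum", "eq")),
}


def find_root(edge_or_node, shared_with_map):
    """Follow parent links until reaching an entry that maps to itself."""
    while shared_with_map[edge_or_node] != edge_or_node:
        edge_or_node = shared_with_map[edge_or_node]
    return edge_or_node


def unwrap(qspec, edge_or_node_to_qspec, shared_with_map):
    """Resolve a shared spec to its root's spec, recursing as in the trace."""
    if isinstance(qspec, Shared):
        root = find_root(qspec.edge_or_node, shared_with_map)
        qspec = edge_or_node_to_qspec[root]
        return unwrap(qspec, edge_or_node_to_qspec, shared_with_map)
    return qspec


def unwrap_checked(qspec, edge_or_node_to_qspec, shared_with_map):
    """Same lookup done iteratively, raising if a root is visited twice."""
    seen = []
    while isinstance(qspec, Shared):
        root = find_root(qspec.edge_or_node, shared_with_map)
        if root in seen:
            raise ValueError(f"cycle through roots: {seen + [root]}")
        seen.append(root)
        qspec = edge_or_node_to_qspec[root]
    return qspec


start = edge_or_node_to_qspec[("x", "eq")]
try:
    unwrap(start, edge_or_node_to_qspec, shared_with_map)
except RecursionError:
    print("RecursionError, as in the report")

try:
    unwrap_checked(start, edge_or_node_to_qspec, shared_with_map)
except ValueError as err:
    print(err)
```

Because every key in `shared_with_map` resolves to the single root `(x, eq)`, and the root's own entry is a shared spec pointing back into the same group, no concrete QuantizationSpec is ever reached.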
The code I'm referring to is in the torchao repo, but I'm not sure whether we could annotate the graph differently so that we don't end up in this situation. A comment at torchao/quantization/pt2e/prepare.py#L323 says that the algorithm works around cyclic dependencies and that the workaround does not handle all inputs; I believe this is such a case.
One thing I tried was setting the flag allow_implicit_sharing to False in the Arm quantizer, and that resolves the problem. The question is whether disabling this flag globally could have other side effects; as far as I can see, no other backend disables implicit sharing.
So, does anyone here know how to tackle this issue?
Versions
Collecting environment information...
PyTorch version: 2.9.0.dev20250811+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.31.6
Libc version: glibc-2.39
Python version: 3.10.4 (main, May 5 2025, 09:27:44) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-78-generic-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 14
On-line CPU(s) list: 0-13
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) Ultra 7 165U
CPU family: 6
Model: 170
Thread(s) per core: 2
Core(s) per socket: 12
Socket(s): 1
Stepping: 4
CPU(s) scaling MHz: 38%
CPU max MHz: 4900.0000
CPU min MHz: 400.0000
BogoMIPS: 5376.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb intel_ppin ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect user_shstk avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi vnmi umip pku ospke waitpkg gfni vaes vpclmulqdq tme rdpid bus_lock_detect movdiri movdir64b fsrm md_clear serialize pconfig arch_lbr ibt flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 352 KiB (10 instances)
L1i cache: 640 KiB (10 instances)
L2 cache: 10 MiB (5 instances)
L3 cache: 12 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-13
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS Not affected; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] executorch==0.8.0a0+0988ad6
[pip3] flake8==6.1.0
[pip3] flake8-breakpoint==1.1.0
[pip3] flake8-bugbear==24.4.26
[pip3] flake8-comprehensions==3.14.0
[pip3] flake8-plugin-utils==1.3.3
[pip3] flake8-pyi==23.5.0
[pip3] mypy==1.14.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pytorch_tokenizers==0.1.0
[pip3] torch==2.9.0.dev20250811+cpu
[pip3] torchao==0.13.0+git1526dfe50
[pip3] torchaudio==2.8.0.dev20250811+cpu
[pip3] torchdata==0.11.0
[pip3] torchsr==1.0.4
[pip3] torchtune==0.6.1
[pip3] torchvision==0.24.0.dev20250811+cpu
[pip3] triton==3.3.0
[conda] Could not collect
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @kimishpatel @jerryzh168