Closed
Labels: backend tester, module: qnn
Description
🐛 Describe the bug
Various reduction ops, including amin, amax, argmin, and argmax, fail to lower when dim is None (i.e., reduce over all dimensions). Currently, lowering a model containing such an operator fails with an internal exception: "tuple index out of range". The expected behavior is for the model to lower, either by partitioning the node and reducing over all dimensions, or by leaving the node unpartitioned if QNN does not support this feature.
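For context, eager PyTorch treats dim=None as a reduction over every dimension, which is the behavior the delegate would need to reproduce. A quick sanity check using a throwaway tensor:

import torch

x = torch.randn(2, 3, 4)
# dim=None reduces over all dimensions and yields a 0-d tensor
assert torch.amax(x, dim=None).shape == torch.Size([])
# Equivalent to enumerating every dimension explicitly
assert torch.equal(torch.amax(x, dim=None), torch.amax(x, dim=(0, 1, 2)))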
Repro:
import torch
from executorch.backends.qualcomm.utils.utils import (
    QcomChipset,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    get_soc_to_chipset_map,
    to_edge_transform_and_lower_to_qnn,
)
from executorch.exir import ExecutorchBackendConfig
from executorch.runtime import Runtime
from typing import List, Optional, Tuple, Union


class AmaxModel(torch.nn.Module):
    def __init__(
        self,
        dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None,
        keepdim: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, x):
        return torch.amax(x, dim=self.dim, keepdim=self.keepdim)


inputs = (torch.randn(2, 3, 4),)
model = AmaxModel()

# HTP Compiler Configuration
backend_options = generate_htp_compiler_spec(
    use_fp16=True,
)

# QNN Compiler Spec
compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,  # Your target SoC
    backend_options=backend_options,
)

# Lower to QNN backend
delegated_program = to_edge_transform_and_lower_to_qnn(
    model,
    inputs,
    compile_spec,
)

# Export to ExecuTorch format
executorch_program = delegated_program.to_executorch(
    config=ExecutorchBackendConfig(extract_delegate_segments=False)
)

# Run the lowered program
runtime = Runtime.get()
program = runtime.load_program(executorch_program.buffer)
method = program.load_method("forward")
output: List[torch.Tensor] = method.execute([*inputs])
Output:
Traceback (most recent call last):
File "/home/gregory/src/executorch/test_qnn.py", line 45, in <module>
delegated_program = to_edge_transform_and_lower_to_qnn(
...
File "/home/gregory/miniconda3/envs/executorch/lib/python3.12/site-packages/torch/fx/passes/infra/partitioner.py", line 87, in _is_node_supported
return self.operator_support.is_node_supported(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gregory/src/executorch/src/executorch/backends/qualcomm/partition/qnn_partitioner.py", line 100, in is_node_supported
op_wrapper = self.node_visitors[node.target.__name__].define_node(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gregory/src/executorch/src/executorch/backends/qualcomm/builders/op_amax.py", line 43, in define_node
mean_dims = cast(List[int], node.args[1])
~~~~~~~~~^^^
IndexError: tuple index out of range
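From the traceback, define_node in op_amax.py unconditionally reads node.args[1], which does not exist when dim is omitted. A minimal sketch of a defensive guard, assuming the input tensor's rank is reachable through the usual FX node.meta["val"] convention (this is an illustration of one possible fix, not a tested patch):

from typing import List, cast

# Sketch only: guard the optional dim argument instead of indexing blindly
if len(node.args) > 1 and node.args[1] is not None:
    dims = cast(List[int], node.args[1])
else:
    # dim omitted or None: reduce over every dimension of the input tensor
    dims = list(range(node.args[0].meta["val"].dim()))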
Versions
Commit 335de46
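A possible workaround until this is fixed (an untested assumption based on the failing code path): pass the reduction dimensions explicitly so the builder never sees a missing dim argument.

# Hypothetical workaround: enumerate all dims instead of relying on dim=None
model = AmaxModel(dim=(0, 1, 2))  # matches the rank of the (2, 3, 4) input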