
All-dim reduction operators fail with “IndexError: tuple index out of range” #13608

@GregoryComer

Description


🐛 Describe the bug

Several reduction operators, including amin, amax, argmin, and argmax, fail to lower when dim is None (that is, when reducing over all dimensions). Currently, lowering a model that contains such an operator fails with an internal exception: "IndexError: tuple index out of range". The expected behavior is that lowering succeeds and the operator reduces over all dimensions, either by partitioning the node to QNN or, if QNN does not support this case, by leaving the node unpartitioned.
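To pin down what "reduce over all dimensions" means for the output, here is a minimal pure-Python sketch of the shape rule for an amax-style reduction. It assumes, as we understand PyTorch's convention for amax/amin, that an omitted (None) or empty dim reduces over every dimension and that keepdim=True keeps each reduced axis as size 1. The helper names `normalize_dims` and `reduced_shape` are hypothetical and are not part of PyTorch or ExecuTorch.

```python
from typing import List, Optional, Sequence, Tuple, Union

# Hypothetical alias: the forms `dim` can take in the repro's AmaxModel.
Dim = Optional[Union[int, Sequence[int]]]


def normalize_dims(dim: Dim, rank: int) -> List[int]:
    """Return sorted, non-negative reduction dims; None or empty means all dims."""
    if dim is None:
        return list(range(rank))
    dims = [dim] if isinstance(dim, int) else list(dim)
    if not dims:
        # Assumed convention: an empty dim list also reduces over every dimension.
        return list(range(rank))
    normalized = []
    for d in dims:
        if not -rank <= d < rank:
            raise IndexError(f"dim {d} out of range for rank {rank}")
        normalized.append(d % rank)  # map negative dims, e.g. -1 -> rank - 1
    return sorted(set(normalized))


def reduced_shape(shape: Tuple[int, ...], dim: Dim, keepdim: bool) -> Tuple[int, ...]:
    """Output shape of an amax-style reduction of `shape` over `dim`."""
    dims = set(normalize_dims(dim, len(shape)))
    if keepdim:
        # Reduced axes are kept with size 1.
        return tuple(1 if i in dims else size for i, size in enumerate(shape))
    # Reduced axes are dropped entirely.
    return tuple(size for i, size in enumerate(shape) if i not in dims)


print(reduced_shape((2, 3, 4), None, keepdim=False))  # ()
print(reduced_shape((2, 3, 4), None, keepdim=True))   # (1, 1, 1)
print(reduced_shape((2, 3, 4), -1, keepdim=False))    # (2, 3)
```

For the repro's input of shape (2, 3, 4) with dim=None, the expected result is therefore a 0-d tensor (or shape (1, 1, 1) with keepdim=True), whichever path the node takes through lowering.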

Repro:

import torch

from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    QcomChipset,
    to_edge_transform_and_lower_to_qnn,
)
from executorch.exir import ExecutorchBackendConfig
from executorch.runtime import Runtime
from typing import Optional, Tuple, List, Union

class AmaxModel(torch.nn.Module):
    def __init__(
        self,
        dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None,
        keepdim: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, x):
        return torch.amax(x, dim=self.dim, keepdim=self.keepdim)

inputs = (torch.randn(2,3,4),)
model = AmaxModel()

# HTP Compiler Configuration
backend_options = generate_htp_compiler_spec(
    use_fp16=True,
)

# QNN Compiler Spec
compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,  # Your target SoC
    backend_options=backend_options,
)


# Lower to QNN backend
delegated_program = to_edge_transform_and_lower_to_qnn(
    model,
    inputs,
    compile_spec,
)

# Export to ExecuTorch format
executorch_program = delegated_program.to_executorch(
    config=ExecutorchBackendConfig(extract_delegate_segments=False)
)

runtime = Runtime.get()
program = runtime.load_program(executorch_program.buffer)
method = program.load_method("forward")
output: List[torch.Tensor] = method.execute([*inputs])

Output:

Traceback (most recent call last):
  File "/home/gregory/src/executorch/test_qnn.py", line 45, in <module>
    delegated_program = to_edge_transform_and_lower_to_qnn(
...
  File "/home/gregory/miniconda3/envs/executorch/lib/python3.12/site-packages/torch/fx/passes/infra/partitioner.py", line 87, in _is_node_supported
    return self.operator_support.is_node_supported(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gregory/src/executorch/src/executorch/backends/qualcomm/partition/qnn_partitioner.py", line 100, in is_node_supported
    op_wrapper = self.node_visitors[node.target.__name__].define_node(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gregory/src/executorch/src/executorch/backends/qualcomm/builders/op_amax.py", line 43, in define_node
    mean_dims = cast(List[int], node.args[1])
                                ~~~~~~~~~^^^
IndexError: tuple index out of range
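The traceback shows op_amax.py reading node.args[1] unconditionally. When dim is omitted, the exported node's args tuple appears to hold only the input, so index 1 does not exist. Below is a minimal sketch of two defensive measures, assuming a simplified stand-in for an FX node: read dim from args, then kwargs, and fall back to all dimensions; and have the support check report "not supported" instead of propagating the builder's exception, so the node is left undelegated. `FakeNode`, `get_reduction_dims`, `is_node_supported_safely`, and `fragile_define_node` are all hypothetical names for illustration. This is not the actual ExecuTorch code or its eventual fix.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node with only the fields used below."""
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any] = field(default_factory=dict)
    rank: int = 3  # input tensor rank; real code would derive it from node metadata


def get_reduction_dims(node: FakeNode) -> List[int]:
    """Read `dim` without assuming node.args[1] exists."""
    if len(node.args) > 1:
        dim = node.args[1]
    else:
        dim = node.kwargs.get("dim")
    if dim is None or (not isinstance(dim, int) and len(dim) == 0):
        # No dim given: reduce over every dimension.
        return list(range(node.rank))
    dims = [dim] if isinstance(dim, int) else list(dim)
    return sorted({d % node.rank for d in dims})  # normalize negative dims


def is_node_supported_safely(
    node: FakeNode, define_node: Callable[[FakeNode], Any]
) -> bool:
    """Treat a builder failure as 'not supported' so the node stays undelegated."""
    try:
        define_node(node)
    except Exception:
        return False
    return True


def fragile_define_node(node: FakeNode) -> List[int]:
    # Mirrors the failing pattern from the traceback: unconditional args[1].
    return list(node.args[1])


no_dim = FakeNode(args=("x",))
print(get_reduction_dims(no_dim))                             # [0, 1, 2]
print(is_node_supported_safely(no_dim, fragile_define_node))  # False
print(is_node_supported_safely(no_dim, get_reduction_dims))   # True
```

Catching every exception is a blunt choice: it keeps lowering from crashing, but it can also hide genuine builder bugs, so logging the exception or narrowing the except clause is probably preferable in practice.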

Versions

Commit 335de46

cc @cccclai @cbilgin

Metadata

Labels

backend tester (This bug was found by the backend test suite.)
module: qnn (Issues related to Qualcomm's QNN delegate and code under backends/qualcomm/)
