
QNN Transposed Conv2d with Dilation appears Incorrect #13611

@GregoryComer


🐛 Describe the bug

When running ConvTranspose2d with dilation > 1 on QNN, the outputs differ from PyTorch eager mode by much more than normal numerical deviation would explain. The computation itself appears to be incorrect: the SNR of the delegated output measured against eager is negative.
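
To pin down what the delegated output should match, here is a minimal pure-Python sketch of the dilated transposed convolution that PyTorch computes. It assumes padding=0, output_padding=0 and groups=1, and `conv_transpose2d_ref` is a hypothetical helper, not part of PyTorch or ExecuTorch. Each input pixel is scattered through the kernel, and tap (kh, kw) lands `dilation * kh` rows and `dilation * kw` columns away from the pixel's strided position. With dilation > 1, a single input pixel therefore writes to non-adjacent output positions.

```python
from typing import List, Optional

Tensor3D = List[List[List[float]]]  # [channel][row][col]
# PyTorch ConvTranspose2d weight layout: [in_channels][out_channels][kH][kW]
Weight4D = List[List[List[List[float]]]]


def conv_transpose2d_ref(
    x: Tensor3D,
    weight: Weight4D,
    bias: Optional[List[float]] = None,
    stride: int = 1,
    dilation: int = 1,
) -> Tensor3D:
    """Naive ConvTranspose2d sketch (assumes padding=0, output_padding=0, groups=1)."""
    c_in, h_in, w_in = len(x), len(x[0]), len(x[0][0])
    c_out = len(weight[0])
    k_h, k_w = len(weight[0][0]), len(weight[0][0][0])
    # Output size with padding=0: (H_in - 1) * stride + dilation * (k - 1) + 1
    h_out = (h_in - 1) * stride + dilation * (k_h - 1) + 1
    w_out = (w_in - 1) * stride + dilation * (k_w - 1) + 1
    # Start every output channel at its bias value (or zero if there is no bias).
    out = [
        [[(bias[co] if bias is not None else 0.0)] * w_out for _ in range(h_out)]
        for co in range(c_out)
    ]
    # Scatter each input pixel through the dilated kernel into the output.
    for ci in range(c_in):
        for i in range(h_in):
            for j in range(w_in):
                v = x[ci][i][j]
                for co in range(c_out):
                    for kh in range(k_h):
                        for kw in range(k_w):
                            r = i * stride + dilation * kh
                            c = j * stride + dilation * kw
                            out[co][r][c] += v * weight[ci][co][kh][kw]
    return out


# Usage: one input pixel through an all-ones 3x3 kernel with dilation=2
# fills a 5x5 output with 1.0 at even (row, col) positions only.
y = conv_transpose2d_ref([[[1.0]]], [[[[1.0] * 3 for _ in range(3)]]], dilation=2)
```

For the repro's 2x7x7 input with kernel_size=3 and dilation=2, the shape formula gives an 11x11 output per channel. Because the weight layout follows PyTorch's, this sketch should agree with `torch.nn.functional.conv_transpose2d(x, weight, bias, dilation=2)` up to floating-point rounding.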

Minimal repro:

from typing import List

import torch

from executorch.backends.qualcomm.utils.utils import (
    QcomChipset,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    to_edge_transform_and_lower_to_qnn,
)
from executorch.exir import ExecutorchBackendConfig
from executorch.runtime import Runtime

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # in_channels=2, out_channels=2, kernel_size=3, dilation=2 (stride=1, padding=0)
        self.conv = torch.nn.ConvTranspose2d(2, 2, 3, dilation=2)

    def forward(self, x):
        return self.conv(x)

# Unbatched (C, H, W) input: 2 channels, 7x7 spatial
inputs = (torch.randn(2, 7, 7),)
model = Model()

# HTP Compiler Configuration
backend_options = generate_htp_compiler_spec(
    use_fp16=True,
)

# QNN Compiler Spec
compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,  # Your target SoC
    backend_options=backend_options,
)


# Lower to QNN backend
delegated_program = to_edge_transform_and_lower_to_qnn(
    model,
    inputs,
    compile_spec
)

# Export to ExecuTorch format
executorch_program = delegated_program.to_executorch(
    config=ExecutorchBackendConfig(extract_delegate_segments=False)
)

runtime = Runtime.get()
program = runtime.load_program(executorch_program.buffer)
method = program.load_method("forward")
output: List[torch.Tensor] = method.execute([*inputs])

print(f"Eager: {model(*inputs)}")
print(f"QNN: {output[0]}")
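
To quantify the mismatch instead of comparing the printed tensors by eye, the SNR can be computed directly. This minimal sketch assumes the common definition 10 * log10(sum(ref^2) / sum((ref - test)^2)). The backend test suite's exact metric may differ, and `snr_db` is a hypothetical helper. Under this definition a negative SNR means the error energy exceeds the energy of the eager output itself.

```python
import math
from typing import Sequence


def snr_db(reference: Sequence[float], test: Sequence[float]) -> float:
    """SNR of `test` against `reference` in dB (assumed definition, see lead-in)."""
    if len(reference) != len(test):
        raise ValueError("reference and test must have the same length")
    signal = sum(r * r for r in reference)
    noise = sum((r - t) ** 2 for r, t in zip(reference, test))
    if noise == 0.0:
        return math.inf  # identical outputs
    if signal == 0.0:
        return -math.inf  # all-zero reference, nonzero error
    return 10.0 * math.log10(signal / noise)


# Usage with the repro's tensors (flattened to Python floats):
#   eager = model(*inputs).flatten().tolist()
#   qnn = output[0].flatten().tolist()
#   print(f"SNR: {snr_db(eager, qnn):.2f} dB")
```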

Versions

Commit 335de46

cc @cccclai @cbilgin

Labels

backend tester (This bug was found by the backend test suite.)
module: qnn (Issues related to Qualcomm's QNN delegate and code under backends/qualcomm/)
