Closed
Labels
backend tester, module: qnn
Description
🐛 Describe the bug
When running ConvTranspose2d with dilation > 1 on QNN, the outputs differ from PyTorch eager mode by far more than would be expected from normal numerical deviation (e.g. fp16 rounding). The computation itself appears to be incorrect: the SNR of the delegated output against eager is negative, meaning the error power exceeds the power of the eager output.
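For reference, here is a minimal sketch of the SNR metric assumed above. It uses the common definition 10·log10(signal power / error power), taking eager as the reference. The helper name `snr_db` is hypothetical and not part of ExecuTorch or the backend test suite, and the suite may compute SNR differently. It uses only the standard library, so tensors need to be flattened to Python lists first.

```python
import math
from typing import Sequence


def snr_db(reference: Sequence[float], test: Sequence[float]) -> float:
    """Hypothetical helper: SNR of `test` against `reference`, in dB.

    SNR = 10 * log10(sum(ref**2) / sum((ref - test)**2)).
    A negative value means the error power exceeds the signal power.
    """
    if len(reference) != len(test):
        raise ValueError("reference and test must have the same length")
    signal = sum(r * r for r in reference)
    noise = sum((r - t) ** 2 for r, t in zip(reference, test))
    if noise == 0.0:
        return math.inf  # outputs match exactly
    if signal == 0.0:
        return -math.inf  # all-zero reference but non-zero error
    return 10.0 * math.log10(signal / noise)
```

With the repro below, something like `snr_db(model(*inputs).detach().flatten().tolist(), output[0].flatten().tolist())` would compare eager and QNN outputs. An output that is merely fp16-rounded should score well above 0 dB.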
Minimal repro:
```python
from typing import List

import torch
from executorch.backends.qualcomm.utils.utils import (
    QcomChipset,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    to_edge_transform_and_lower_to_qnn,
)
from executorch.exir import ExecutorchBackendConfig
from executorch.runtime import Runtime


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # in_channels=2, out_channels=2, kernel_size=3, dilation=2
        self.conv = torch.nn.ConvTranspose2d(2, 2, 3, dilation=2)

    def forward(self, x):
        return self.conv(x)


# Unbatched (C, H, W) input
inputs = (torch.randn(2, 7, 7),)
model = Model()

# HTP compiler configuration (fp16)
backend_options = generate_htp_compiler_spec(
    use_fp16=True,
)

# QNN compiler spec
compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,  # Your target SoC
    backend_options=backend_options,
)

# Lower to the QNN backend
delegated_program = to_edge_transform_and_lower_to_qnn(
    model,
    inputs,
    compile_spec,
)

# Export to ExecuTorch format
executorch_program = delegated_program.to_executorch(
    config=ExecutorchBackendConfig(extract_delegate_segments=False)
)

# Load and run the delegated program with the ExecuTorch Python runtime
runtime = Runtime.get()
program = runtime.load_program(executorch_program.buffer)
method = program.load_method("forward")
output: List[torch.Tensor] = method.execute([*inputs])

print(f"Eager: {model(*inputs)}")
print(f"QNN: {output[0]}")
```
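To cross-check which side is wrong independently of both PyTorch and QNN, a naive reference can help. The sketch below is a hypothetical pure-Python ConvTranspose2d written as the scatter form `out[co][i*stride + a*dilation][j*stride + b*dilation] += x[ci][i][j] * weight[ci][co][a][b]`. It assumes padding=0, output_padding=0, and groups=1, and it uses PyTorch's weight layout (in_channels, out_channels, kH, kW). It is not the PyTorch or QNN implementation. The output size follows the formula in the PyTorch `ConvTranspose2d` docs: H_out = (H_in - 1)·stride - 2·padding + dilation·(kernel_size - 1) + output_padding + 1. For the 7×7 input above, that gives (7 - 1) + 2·(3 - 1) + 1 = 11.

```python
from typing import List, Sequence


def transposed_conv_out_size(n: int, k: int, stride: int = 1, padding: int = 0,
                             dilation: int = 1, output_padding: int = 0) -> int:
    """Output length of a transposed convolution along one spatial axis."""
    return (n - 1) * stride - 2 * padding + dilation * (k - 1) + output_padding + 1


def conv_transpose2d_ref(x: Sequence, weight: Sequence, bias: Sequence[float],
                         stride: int = 1, dilation: int = 1) -> List[List[List[float]]]:
    """Hypothetical naive ConvTranspose2d (padding=0, groups=1).

    x:      [in_ch][H][W]
    weight: [in_ch][out_ch][kH][kW]  (PyTorch layout)
    bias:   [out_ch]
    """
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    c_out = len(weight[0])
    kh, kw = len(weight[0][0]), len(weight[0][0][0])
    h_out = transposed_conv_out_size(h, kh, stride, dilation=dilation)
    w_out = transposed_conv_out_size(w, kw, stride, dilation=dilation)
    # Start every output element at its channel's bias.
    out = [[[float(bias[co])] * w_out for _ in range(h_out)] for co in range(c_out)]
    # Scatter each input pixel through the dilated kernel.
    for ci in range(c_in):
        for i in range(h):
            for j in range(w):
                v = x[ci][i][j]
                for co in range(c_out):
                    for a in range(kh):
                        for b in range(kw):
                            out[co][i * stride + a * dilation][j * stride + b * dilation] += (
                                v * weight[ci][co][a][b]
                            )
    return out
```

One possible comparison against the repro is `conv_transpose2d_ref(inputs[0].tolist(), model.conv.weight.detach().tolist(), model.conv.bias.detach().tolist(), dilation=2)`, checked against `model(*inputs)` and `output[0]`.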
Versions
Commit 335de46