RuntimeError: ONNX export failed: Couldn't export operator aten::adaptive_avg_pool2d #63

Closed
MBocchi opened this issue Aug 23, 2018 · 9 comments

Comments

@MBocchi

MBocchi commented Aug 23, 2018

Hi, I've tried to export a PyTorch model to ONNX as follows:
import torch

input_shape = (3, 224, 224)
model_onnx_path = "torch_model.onnx"
# `model` is the network being exported (its definition is not shown here)
dummy_input = torch.randn(1, *input_shape).cuda()
torch.onnx.export(model, dummy_input, model_onnx_path, verbose=False)

And I got this error:
UserWarning: ONNX export failed on ATen operator adaptive_avg_pool2d because torch.onnx.symbolic.adaptive_avg_pool2d does not exist
RuntimeError: ONNX export failed: Couldn't export operator aten::adaptive_avg_pool2d

How can I fix it? Thank you.

@houseroad
Member

@cnaaq

cnaaq commented Dec 12, 2018

@MBocchi Did you resolve the problem?
I'm facing the same problem, and honestly I didn't get much out of PytorchAddExportSupport.md.

@jjkislele

@cnaaq @MBocchi Sorry to bother you, but I think I may have found a solution.

Please see pytorch/pytorch#14395 (comment)

When converting a PyTorch model to ONNX with torch.onnx.export, you can pass operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, which exports an operator as an ATen op when it cannot be found in the ONNX operator set. It worked for me when I used torch.nn.AdaptiveAvgPool2d in a PyTorch network.

@fontno

fontno commented Jan 30, 2020

Appreciate this, @jjkislele. It solved my problem.

@PatrickNa

PatrickNa commented Sep 23, 2020

I ran into a similar issue to the one described above:

RuntimeError: ONNX export failed: Couldn't export operator aten::max_unpool2d

I was able to export the model with the flag operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, but the problem reappears when the exported model is loaded with onnxruntime:

Fail: [ONNXRuntimeError] : 1 : FAIL : Fatal error: ATen is not a registered function/op

So I don't consider this a real solution. Does anyone have other ideas to try?

@jjkislele

@PatrickNa
Yes, honestly it is a bad workaround; it's like an ostrich burying its head in the sand. You could open an issue asking the developers to support the op you need, or modify the source code yourself to add support for it.

In fact, I ran into this problem when converting my PyTorch model to TensorRT/OpenVINO via ONNX. I switched to a different PyTorch pooling operation, nn.AvgPool2d, instead of nn.AdaptiveAvgPool2d. So perhaps you could likewise reimplement the max unpooling with other PyTorch operations.
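For the nn.AvgPool2d replacement, the kernel and stride can be derived per spatial dimension from the input and output sizes with stride = floor(in / out) and kernel = in - (out - 1) * stride. A minimal pure-Python sketch; adaptive_to_fixed_pool_params and is_exact are hypothetical names, and the approach assumes the input resolution is fixed at export time. The replacement reproduces AdaptiveAvgPool2d exactly only when each input size is divisible by the matching output size; otherwise the output shape matches but the pooling windows differ.

```python
def adaptive_to_fixed_pool_params(in_size, out_size):
    """Return (kernel_size, stride) tuples for a fixed-window average pool.

    Per dimension: stride = in // out, kernel = in - (out - 1) * stride,
    which always yields an output of length `out`.
    """
    kernel, stride = [], []
    for i, o in zip(in_size, out_size):
        # A fixed pool cannot upsample, so out must not exceed in
        if o <= 0 or o > i:
            raise ValueError(f"output size {o} must be in [1, {i}]")
        s = i // o
        stride.append(s)
        kernel.append(i - (o - 1) * s)
    return tuple(kernel), tuple(stride)


def is_exact(in_size, out_size):
    """True when the fixed-window pool matches adaptive pooling exactly."""
    return all(i % o == 0 for i, o in zip(in_size, out_size))
```

The resulting tuples can then be passed as nn.AvgPool2d(kernel_size=kernel, stride=stride). For a 7x7 feature map pooled to 1x1 this gives kernel (7, 7) and stride (7, 7), i.e. plain global average pooling.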

Hope this helps. :-D

@rGitcy

rGitcy commented Oct 15, 2021

I used the code below to convert AdaptiveAvgPool2d to AvgPool2d, and the ONNX export succeeded!
My code:


import torch
import torch.nn as nn

x = torch.randn(64, 26, 512, 1)  # my input shape (N, C, H, W)

# original layer: keep H, pool W down to 1
m = nn.AdaptiveAvgPool2d((None, 1))
output = m(x)
print(output.shape)

# convert AdaptiveAvgPool2d -> AvgPool2d
# (exact when each input size is divisible by the matching output size)
input_size = tuple(x.shape[-2:])        # (H_in, W_in)
output_size = tuple(output.shape[-2:])  # (H_out, W_out)

stride = tuple(i // o for i, o in zip(input_size, output_size))
print("stride size", stride)

kernel = tuple(i - (o - 1) * s for i, o, s in zip(input_size, output_size, stride))
print("kernel size", kernel)

# test: both layers should produce the same result
adp = nn.AdaptiveAvgPool2d(output_size)
avg = nn.AvgPool2d(kernel_size=kernel, stride=stride)
print(torch.allclose(adp(x), avg(x)))


@garymm

garymm commented Mar 15, 2022

Some support for this op was added in pytorch/pytorch#17412.
Support for remaining cases is tracked by pytorch/pytorch#42653.
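To see why sizes where the input is divisible by the output are the easy case, it helps to look at the windows adaptive average pooling averages over. As I understand PyTorch's implementation, output index i covers input positions from floor(i * in / out) up to ceil((i + 1) * in / out). A 1-D pure-Python sketch (adaptive_windows, fixed_windows and adaptive_avg_pool_1d are hypothetical names) comparing those windows with the fixed windows of a plain average pool:

```python
def adaptive_windows(in_size, out_size):
    """[start, end) window for each output index of 1-D adaptive pooling."""
    return [
        # floor(i * in / out) and ceil((i + 1) * in / out), in integer math
        ((i * in_size) // out_size, ((i + 1) * in_size + out_size - 1) // out_size)
        for i in range(out_size)
    ]


def fixed_windows(in_size, kernel, stride):
    """[start, end) windows visited by a plain average pool (no padding)."""
    count = (in_size - kernel) // stride + 1
    return [(j * stride, j * stride + kernel) for j in range(count)]


def adaptive_avg_pool_1d(xs, out_size):
    """Reference 1-D adaptive average pooling over a list of numbers."""
    return [sum(xs[a:b]) / (b - a) for a, b in adaptive_windows(len(xs), out_size)]
```

With in = 8 and out = 4 the adaptive windows coincide with kernel = stride = 2, so a single AveragePool reproduces the layer. With in = 10 and out = 4 the adaptive windows are (0, 3), (2, 5), (5, 8), (7, 10), whose start positions are not evenly spaced, while kernel 4 and stride 2 visit (0, 4), (2, 6), (4, 8), (6, 10); no single kernel/stride pair matches them, which is presumably why non-divisible sizes need separate handling.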

@jcwchen can you please close this issue?

@nvhungv2k

nvhungv2k commented Jun 17, 2024

@jjkislele Your suggestion helped me get past this, and I finally exported the *.onnx file.
But when I loaded it, I hit a new problem: "Fatal error: ATen is not a registered function/op".

10 participants