
Failure when converting the ReduceMax op type, perhaps because torch.max's return value is not considered #33

Closed
maybeLee opened this issue Dec 21, 2021 · 2 comments · Fixed by #37


@maybeLee

I was trying to convert a Keras model to PyTorch through ONNX but failed.
The summary of my target model is as follows:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 32, 3)]           0
_________________________________________________________________
global_max_pooling1d         (None, 3)                 0
_________________________________________________________________
dense (Dense)                (None, 96)                384
_________________________________________________________________
reshape (Reshape)            (None, 32, 3)             0
=================================================================
Total params: 384
Trainable params: 384
Non-trainable params: 0
_________________________________________________________________

The script to reproduce is as follows:

# Build the model
import tensorflow as tf
import tf2onnx
x = tf.keras.layers.Input((32,3))
in1 = tf.keras.layers.GlobalMaxPooling1D()(x)
in2 = tf.keras.layers.Dense(96)(in1)
y = tf.keras.layers.Reshape((32,3))(in2)
model = tf.keras.Model(x, y)
model.summary()

# Convert the model
input_shape = model.layers[0].input_shape[0]
spec = (tf.TensorSpec(input_shape, tf.float32, name="input"),)
_, _ = tf2onnx.convert.from_keras(model, input_signature=spec, \
        opset=15, output_path="temp.onnx")
from onnx2pytorch import ConvertModel
import onnx
onnx_model = onnx.load("temp.onnx")
torch_model = ConvertModel(onnx_model, experimental=True)

# Predict
import torch
import numpy as np
input = np.random.rand(10, *input_shape[1:])
input = torch.from_numpy(input)
torch_model.double()
pred = torch_model(input)
pred = pred.detach().numpy()
print("The prediction is: ", pred.shape)

You may access the code here:
https://colab.research.google.com/drive/1EtuxhHjy3QdmCf4v6DSpN9jsde2SeNpW?usp=sharing

The crash information is as follows:

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
   1846     if has_torch_function_variadic(input, weight, bias):
   1847         return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848     return torch._C._nn.linear(input, weight, bias)
   1849 
   1850 

TypeError: linear(): argument 'input' (position 1) must be Tensor, not torch.return_types.max

Without a very deep investigation, I assume this problem is caused by torch.max()'s output, which is torch.return_types.max rather than a plain tensor, while the linear layer expects its input to be a tensor.

I guess the fix for this bug would be to change torch.max(**kwargs) to torch.max(**kwargs)[0] so that the output is a tensor. But I am new to this project and don't know how to write a proper fix. Can you check whether this is actually a bug and how we could fix it?
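For illustration, here is a minimal standalone sketch (outside onnx2pytorch) of the mismatch described above: when torch.max is called with a dim argument it returns a (values, indices) named tuple of type torch.return_types.max rather than a Tensor, and indexing [0] (equivalently, .values) recovers the tensor that downstream layers expect.

```python
import torch

x = torch.rand(10, 32, 3)

# With a `dim` argument, torch.max returns a named tuple
# (values, indices) of type torch.return_types.max, not a Tensor.
out = torch.max(x, dim=1)
print(type(out))

# Indexing [0] (equivalently, out.values) yields the plain Tensor
# that downstream layers such as nn.Linear expect.
values = torch.max(x, dim=1)[0]
print(values.shape)  # torch.Size([10, 3])

# Passing the raw named tuple into a linear layer reproduces the
# TypeError shown in the traceback below.
linear = torch.nn.Linear(3, 96)
try:
    linear(out)
except TypeError as e:
    print("TypeError:", e)
```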

@calvinmccarter-at-lightmatter
Contributor

@maybeLee -- PR #37 should fix this for you.

@maybeLee
Author

maybeLee commented Jan 4, 2022

Thanks for your contribution.

@maybeLee maybeLee closed this as completed Jan 4, 2022