
Output from MXNet model differs from sample from onnx/models #42

ThomasDelteil opened this issue Mar 13, 2018 · 5 comments
@ThomasDelteil (Contributor)

Hi,

I downloaded a model (inceptionv2) from the onnx/models repo and loaded it in MXNet. I ran the sample input data through it and got an output that differs from the output in the models folder. Is this a known issue?

This is the notebook I am using:
https://github.com/ThomasDelteil/Gluon_ONNX/blob/master/Fine-tuning_ONNX.ipynb

The code is reproduced below in case there is some obvious mistake:

import numpy as np
import onnx
import onnx_mxnet
import mxnet as mx
from collections import namedtuple

inception_v2 = "https://s3.amazonaws.com/download.onnx/models/inception_v2.tar.gz"
model_links = [inception_v2]
model_folder = "model"
model_name = "inception_v2"

# Download the models
for link in model_links:
    !mkdir -p $model_folder
    !wget -P $model_folder $link -nc -nv
# Extract the chosen model
!tar -xzf $model_folder/*.tar.gz -C $model_folder

# Helper function to load the sample input/output data shipped with the model
def load_sample(index=0):
    # `model_path` is defined below, before this helper is first called
    numpy_path = "{}/test_data_{}.npz".format(model_path, index)
    sample = np.load(numpy_path, encoding='bytes')
    inputs = sample['inputs'][0]
    outputs = sample['outputs'][0]
    return inputs, outputs

# load the model in MXNet using `onnx-mxnet`

# Set the file-paths
model_path = "{}/{}".format(model_folder, model_name)
onnx_path = "{}/model.onnx".format(model_path)

# Load the model: `import_model` returns the symbol and a single params dict
sym, params = onnx_mxnet.import_model(onnx_path)

# We pick the mxnet compute context:
ctx = mx.cpu()

# Get some sample data to infer the shapes
inputs, outputs = load_sample(0)

# By default, 'input_0' is an input of the imported model.
mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=ctx, label_names=None)
mod.bind(for_training=False, data_shapes=[('input_0', inputs.shape)], label_shapes=None)
# `import_model` returned one params dict, so it is passed for both arg and aux params
mod.set_params(arg_params=params, aux_params=params, allow_missing=False, allow_extra=False)

# Test the model using sample data
Batch = namedtuple('Batch', ['data'])

inputs, outputs = load_sample(0)

# forward on the provided data batch
mod.forward(Batch([mx.nd.array(inputs)]))
model_output = mod.get_outputs()

print(model_output[0][0][0])
# [  2.84355319e-05]
# <NDArray 1 @cpu(0)>
print(outputs[0][0])
# 0.00016256621
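
For reference, here is a quick way to quantify the mismatch (a minimal sketch reusing `model_output` and `outputs` from above; it assumes the two arrays have matching shapes, and the tolerance is an arbitrary choice):

import numpy as np

# Pull the MXNet output back into NumPy and compare it to the reference output
mx_out = model_output[0].asnumpy()
ref_out = np.asarray(outputs)

print("max abs diff:", np.abs(mx_out - ref_out).max())
print("allclose (rtol=1e-3):", np.allclose(mx_out, ref_out, rtol=1e-3))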
@rajanksin (Collaborator)

@ThomasDelteil Yes, it's a known issue. We are working towards resolving these issues.

@ThomasDelteil (Contributor, Author)

Do you know how 'bad' this is, in terms of accuracy lost compared to the original model, say on ImageNet image-classification tasks? Is it due to different implementations of the layers, or to floating-point approximations?

Could I go ahead and start loading ONNX models in MXNet for production use?

@rajanksin (Collaborator)

@ThomasDelteil These accuracy differences are due to issues with the import operator implementations. ONNX has been changing and adding operator definitions over the last couple of months, and we are actively working to stay on par with them. Currently, we are focusing on:

  1. Moving onnx-mxnet inside the MXNet repo
  2. Fixing all operator-related issues (failures when testing against the ONNX backend tests; a sketch of that harness is below)
  3. Fixing the standard models, "inception" being one of them

We are hoping to have most of these issues fixed by the end of this month.
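
For context on item 2, the ONNX backend test suite is typically wired up along these lines (a sketch only; `onnx_mxnet.backend` is a hypothetical module path for illustration, not a confirmed import):

import unittest
import onnx.backend.test
# Hypothetical module path, used here only to illustrate the pattern
import onnx_mxnet.backend as mxnet_backend

# Generate the standard ONNX backend test cases against this backend
backend_test = onnx.backend.test.BackendTest(mxnet_backend, __name__)

# Expose the generated cases so unittest can discover and run them
globals().update(backend_test.enable_report().test_cases)

if __name__ == '__main__':
    unittest.main()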

@ThomasDelteil (Contributor, Author)

Thanks @spidydev for the update!

@rajanksin (Collaborator)

@ThomasDelteil The difference in accuracy is due to a difference in MXNet's average-pooling implementation:
apache/mxnet#10194
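
To illustrate (a minimal NumPy sketch of the two divisor conventions, not the actual framework code): ONNX's AveragePool excludes padded cells from the divisor by default, whereas MXNet's pooling at the time divided by the full kernel size, padding included.

import numpy as np

x = np.array([[1., 2.],
              [3., 4.]])
padded = np.pad(x, 1, mode='constant')  # zero-pad to 4x4

# Top-left 2x2 window of the padded input: [[0, 0], [0, 1]]
window = padded[0:2, 0:2]

# MXNet at the time: divide by the full kernel size (padding included)
avg_include_pad = window.sum() / window.size  # 1 / 4 = 0.25

# ONNX AveragePool default: divide by the number of non-padded cells
avg_exclude_pad = window.sum() / 1            # 1 / 1 = 1.0

print(avg_include_pad, avg_exclude_pad)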
