Conversation
op = str(node["op"])
if op not in MXNetGraph.registry_:
    raise AttributeError("No conversion function registered for op type %s yet." % op)
convert_fun = MXNetGraph.registry_[op]
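The lookup above assumes a decorator-based registry that maps MXNet op names to conversion functions. A minimal, hypothetical sketch of that pattern (the names `register` and `lookup` are illustrative, not the PR's exact API):

```python
# Hypothetical sketch of a decorator-based op-conversion registry.
_registry = {}

def register(op_name):
    """Decorator that maps an op name to its conversion function."""
    def wrapper(func):
        _registry[op_name] = func
        return func
    return wrapper

@register("relu")
def convert_relu(node, **kwargs):
    # A real converter would emit an ONNX node; return the op name here.
    return "Relu"

def lookup(op):
    """Mirror of the lookup logic in the snippet above."""
    if op not in _registry:
        raise AttributeError("No conversion function registered for op type %s yet." % op)
    return _registry[op]
```

New operator support then reduces to writing one function and decorating it, which is why the reviewers focus on per-operator coverage below.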
change the name to convert_func
fixed
Thanks for awesome work @Roshrini @spidydev @anirudhacharya
Some comments below.
import mxnet as mx

def load_module(json_path, params_path, input_shape):
    """Loads the MXNet model file, retrieves symbol and parameters and returns.
nit: and returns MXNet symbol and params (weights).
Done
import logging
import mxnet as mx

def load_module(json_path, params_path, input_shape):
nit: json_path is too generic a name for the function and will be hard to maintain later. Can we be more specific? sym_filepath, params_filepath or something like that?
Done
    Model weights including both arg and aux params.
    """
    if not (os.path.isfile(json_path) and os.path.isfile(params_path)):
        raise ValueError("Provide valid path to the json and params file")
nit: It is always useful to have a specific error/warning message on what is wrong and why.
Done
        raise ValueError("Provide valid path to the json and params file")
    else:
        try:
            model_name = json_path.rsplit('.', 1)[0].rsplit('-', 1)[0]
nit: I understand this logic reads the symbol name and epochs from the sym.json file. But please add a code comment for this logic, for future bug fixes.
Done
            model_name = json_path.rsplit('.', 1)[0].rsplit('-', 1)[0]
            num_epochs = int(params_path.rsplit('.', 1)[0].rsplit('-', 1)[1])
        except IndexError:
            logging.info("Model and params name should be in format: "
Is the epoch necessary? Only for retraining the loaded model?
As a standard, a saved model need not have an epoch number; it is probably only necessary for saving checkpoint models, though MXNet as of today mandates it. But if we introduce a new API to save models without epochs attached, do we have any issue here?
Defaulting epochs to 0 if not provided with the model name.
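The filename parsing discussed above can be sketched as follows, assuming the conventional MXNet checkpoint naming of `<model>-symbol.json` and `<model>-<epoch>.params` (the function name is illustrative, not the PR's exact code):

```python
def parse_checkpoint_names(json_path, params_path):
    """Extract the model name and epoch from MXNet checkpoint filenames.

    "resnet-symbol.json" -> model name "resnet"
    "resnet-0012.params" -> epoch 12
    Falls back to epoch 0 when the params file has no epoch suffix.
    """
    # Drop the extension, then drop the trailing "-symbol" (or "-<epoch>") part.
    model_name = json_path.rsplit('.', 1)[0].rsplit('-', 1)[0]
    try:
        num_epochs = int(params_path.rsplit('.', 1)[0].rsplit('-', 1)[1])
    except (IndexError, ValueError):
        # No "-<epoch>" suffix present: default to 0, per the discussion above.
        num_epochs = 0
    return model_name, num_epochs
```

Catching `ValueError` as well as `IndexError` covers names like `resnet-final.params`, where a suffix exists but is not an integer.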
    name=name,
    epsilon=eps,
    momentum=momentum,
    spatial=1
always 1?
MXNet doesn't have spatial BatchNorm, so this should actually be set to 0. While importing an ONNX model we will ignore this attribute, but it might be an issue when exporting to Caffe2 or other frameworks that support spatial BN. Thanks for pointing this out.
    # Creating a dictionary here, but if this titlecase pattern
    # mxnet_name.title()
    act_types = {
        "tanh": "Tanh",
only tanh and relu supported?
Done !!
    onnx_pad_width = [0] * num_pad_values

    start_index = 0
    end_index = int(num_pad_values / 2)
nit: floor?
The number of pad values in MXNet's pad op (https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.pad) is always a multiple of two. Will add a comment to clarify.
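The layout conversion the snippet above is setting up can be sketched as a standalone function. This assumes MXNet's interleaved `(begin0, end0, begin1, end1, ...)` pad_width layout versus ONNX's `(begin0, begin1, ..., end0, end1, ...)` pads layout; the function name is illustrative:

```python
def mxnet_to_onnx_pad(mx_pad_width):
    """Reorder MXNet's interleaved (begin, end) pad pairs into ONNX's
    all-begins-then-all-ends layout.

    MXNet pad_width always has an even number of values (two per axis),
    so the integer halving below is exact.
    """
    num_pad_values = len(mx_pad_width)
    onnx_pad_width = [0] * num_pad_values
    half = num_pad_values // 2
    for axis in range(half):
        onnx_pad_width[axis] = mx_pad_width[2 * axis]             # begin value
        onnx_pad_width[half + axis] = mx_pad_width[2 * axis + 1]  # end value
    return onnx_pad_width
```

For example, a 2-axis MXNet pad_width `[1, 2, 3, 4]` becomes the ONNX pads `[1, 3, 2, 4]`.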
@mx_op.register("slice_axis")
def convert_slice_axis(node, **kwargs):
just slice operator?
The Slice operator will be added later; it is not used by any of the models tested yet :)
@@ -114,7 +114,7 @@ def maximum(attrs, inputs, proto_obj):
         for op_input in inputs[2:]:
             mxnet_op = symbol.maximum(mxnet_op, op_input)
     else:
-        mxnet_op = inputs[0]
+        mxnet_op = symbol.maximum(inputs[0], inputs[0])
maximum of same element?
Yes, ONNX has a case where, if there is only one input, it returns that input itself as the output. MXNet always needs 2 inputs.
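The single-input case discussed above can be sketched in plain Python, using the built-in `max` as a stand-in for the pairwise `symbol.maximum` in the PR (the function name is illustrative):

```python
from functools import reduce

def variadic_maximum(inputs, pairwise_max=max):
    """ONNX Max takes 1..N inputs, but the pairwise op (symbol.maximum
    in the PR) always needs two arguments.

    With a single input, applying the pairwise op to the input and
    itself is an identity, which keeps the code path uniform instead of
    special-casing a bare pass-through.
    """
    if len(inputs) == 1:
        return pairwise_max(inputs[0], inputs[0])
    # Fold the pairwise op across all inputs, two at a time.
    return reduce(pairwise_max, inputs)
```

The same fold-with-identity trick applies to the `minimum` converter mentioned in the commit list.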
import logging
import mxnet as mx

def load_module(sym_filepath, params_filepath, input_shape):
Question: what is the purpose of this function? Why couldn't it be replaced by a simple:
sym = mx.sym.load(sym_filepath)
params = mx.nd.load(params_filepath)
return sym, params
sym.load and nd.load work to get model and params objects from files, but if the model was trained using an old version of MXNet, they won't upgrade the model, so there is a compatibility issue.
For example, some models have "param" or "attr" instead of "attrs" in the json file.
Nice corner case that is hard to think of 👍
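The corner case above (legacy symbol files using "param" or "attr" where newer MXNet expects "attrs") could be handled with a key-upgrade pass over the symbol JSON. This is a hypothetical sketch, not the PR's actual upgrade path (MXNet's own checkpoint loader performs the real upgrade):

```python
import json

def upgrade_symbol_json(old_json_str):
    """Hypothetical sketch: rename legacy per-node attribute keys
    ("param", "attr") to the modern "attrs" key in a symbol JSON graph."""
    graph = json.loads(old_json_str)
    for node in graph.get("nodes", []):
        for legacy_key in ("param", "attr"):
            if legacy_key in node:
                node["attrs"] = node.pop(legacy_key)
    return json.dumps(graph)
```

A real implementation would also need to handle other schema drift between MXNet versions, which is exactly why the PR goes through the checkpoint loader instead of raw sym.load/nd.load.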
    model : str or symbol object
        Path to the json file or Symbol object
    weights : str or symbol object
        Path to the params file or Params object. (Including both arg_params and aux_params)
Is it a dictionary of parameters or something else?
It can be both; changed the description to be more explicit.
from .export_helper import load_module

def export_model(model, weights, input_shape, input_type=np.float32,
weights -> params, to be consistent with the rest of the file
changed.
    return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
                 for k, v in weights_dict.items()])

def create_onnx_graph_proto(self, sym, params, in_shape, in_type, log=False):
verbose=False
Done
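The prefix-stripping in the dict comprehension above can be isolated as a small helper. This sketch drops the `v.asnumpy()` conversion so it runs on plain values; in the PR the values are NDArrays converted to NumPy at the same time:

```python
def strip_param_prefixes(weights_dict):
    """Strip the "arg:"/"aux:" prefixes that MXNet checkpoint loading
    attaches to parameter keys, leaving bare parameter names."""
    return {k.replace("arg:", "").replace("aux:", ""): v
            for k, v in weights_dict.items()}
```

This is what lets the exporter look up weights by the plain names that appear in the symbol graph's null nodes.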
@@ -18,3 +18,4 @@
from ._import.import_model import import_model, get_model_metadata
This does not make sense to me. Why do you want to put public functions in the private module folders _import or _export and include them later?
import is a reserved keyword; we can't have a folder called import. We can probably rename the two folders to onnx_import and onnx_export and make their member files private, except for the modules that we are exposing to the user.
Folder names changed to be public: _import --> onnx2mx, _export --> mx2onnx. Also changed the files in the folders as per their usage.
from .export_helper import load_module

def export_model(model, weights, input_shape, input_type=np.float32,
use verbose=False
done
    # create module, passing cpu context
    ctx = context.cpu()
    test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
    test_mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
label_shapes may not always be None?
True, but the purpose of this function is just to get the shape of the output after a forward pass.
        self.output_tensors = []

    @staticmethod
    def register(op_name):
There are no docs for inputs and outputs throughout the static methods in this class.
Usually, detailed info is only added for public APIs.
op = node["op"]
name = node["name"]
if log:
    print("Converting idx: %d, op: %s, name: %s" % (idx, op, name))
use logging.xx
done
if log:
    print("Converting idx: %d, op: %s, name: %s" % (idx, op, name))

if op == "null" and name not in params:
Agreed, the logic is a bit confusing here; better to simplify it.
@classmethod
def prepare(cls, model, device='CPU', **kwargs):
Why not directly use mx.cpu()? Also, the device string is in capital letters without careful handling.
This method is declared by ONNX. Backends using the ONNX test framework derive from the "backend" class and implement these functions.
@@ -0,0 +1,98 @@
# Licensed to the Apache Software Foundation (ASF) under one
It seems to be added already, but what is the name python-pytest? There's already a folder tests/python/unittest.
MXNet unit tests use nosetests, but the ONNX backend test framework uses pytest, so to keep them separate we created another folder for pytest.
Just remove the python-pytest folder and use onnx; the name is pretty confusing and meaningless for an empty folder containing only onnx.
In the future, if there is another component built into MXNet that uses pytest instead of nosetests, then what will we do?
The point of naming it python-pytest is to separate these from the other tests, which use the nosetests framework.
Also, this naming was part of a previous PR #9963 and was suggested by @marcoabreu during the review process.
params.update(arg_params)
params.update(aux_params)

onnx_file = model_path.rsplit('/', 1)[0] + "/exported_" + model_name + ".onnx"
Using + to concatenate paths is not portable.
fixed.
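The portable version of the path construction above uses os.path.join instead of string concatenation with '/'. A minimal sketch (the helper name is illustrative):

```python
import os

def exported_onnx_path(model_path, model_name):
    """Build the output .onnx path next to the source model file,
    using os.path functions instead of '/'-concatenation so the code
    works across operating systems."""
    model_dir = os.path.dirname(model_path)
    return os.path.join(model_dir, "exported_%s.onnx" % model_name)
```

On POSIX systems this produces the same path as the original concatenation, but it also behaves correctly on Windows-style paths.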
Force-pushed from f28076a to 9922a00:
- Refactored test framework to support ONNX backend tests
- Added Operator support: Convolution2D, BatchNorm, Add
- Add, Sub, Mul, Div, Sum
- sigmoid, relu, pad (constant, edge, reflect), tanh; enabled corresponding ONNX backend tests
- Added Operators: Ceil, Floor
- MaxPool, AvgPool, GlobalMaxPool, GlobalAvgPool, matmul
- ArgMax, ArgMin, maximum, minimum
…dded only for these. Fixed logic error with convert_string_to_list()
Force-pushed from a723644 to 43788cf:
Changed underlying files to public or private as per usage; resolved conflicts with the latest.
Added some error checking.
@aaronmarkham Can you review the docs part of this PR?
@sandeep-krishnamurthy @zhreshold Thank you for reviewing the code. Addressed all the comments now.
Thank you @Roshrini @spidydev. Great work! Will be very useful for users in combination with ONNX model zoo.
Will wait for the other reviewers' approval and doc approval.
@zhreshold @aaronmarkham @ThomasDelteil
LGTM. Merging the changes.
@sandeep-krishnamurthy There are vetoes in effect from @zhreshold. Per agreement among committers, you are not supposed to merge it.
@zhreshold Could you take a look at this change again and see if your concerns are sufficiently addressed?
Sorry about the late update; there's one minor issue that needs to be addressed.
Since the conversation is really long, I might have missed some updates; please ping me directly if I am not responsive. Thanks!
@zhreshold Also, please let me know how to ping you directly. Do I do it on the Slack channel?
@zhreshold If this is a small issue to address, then let's request a patch from the author. Could you create an issue?
Since there is so much conversation in this thread, could you please list the open issues?
@szha - I explicitly pinged all the reviewers and waited for 6 days before merging the PR. I also tried, to the best of my ability, to gather data from contributors on whether the changes suggested by other reviewers were addressed before merging the PR.
I opened a new issue regarding my concerns in #11475
@sandeep-krishnamurthy Thanks for the efforts. Please respect "request changes" as vetoes nonetheless, and try to reach @zhreshold, especially given that you sit in the same office. Much appreciated.
* Resolve conflicts
* Export module Test Framework
* refactoring export to work with pretrained models
* comments added
* 1. Refactored export module. 2. Refactored test framework to support ONNX backend tests. 3. Added Operator support: Convolution2D, BatchNorm, Add
* Added Arithmetic operators: Add, Sub, Mul, Div, Sum
* Added operator support: sigmoid, relu, pad (constant, edge, reflect), tanh; enabled corresponding ONNX backend tests
* Enabled ONNX tests: test_conv, test_basic_conv; added Operators: Ceil, Floor
* Added support for: MaxPool, AvgPool, GlobalMaxPool, GlobalAvgPool, matmul
* adding more operators
* Added Operator support: ArgMax, ArgMin, maximum, minimum
* Enabled more BASIC_MODEL tests
* Added power operator tests
* Added support for reshape. ONNX only supports 0, -1 special values. Added only for these. Fixed logic error with convert_string_to_list()
* some tests enabled
* enabling squeezenet
* LRN Op support
* mul_scalar modified to take scalar input
* cleaning some code
* Resolving conflicts on rebase
* Resolving rebase conflicts
* id mapping updated for all operators
* save onnx models added, some code cleanup
* enabled more tests
* conv pad calc fixed
* reshape op fix
* Added support for elu, leakyRelu, prelu
* Cleanup: removed run_node (not needed anymore); used correct get_metadata API
* valueinfoproto fix, googlenet test added
* Removed redundant code: run_node; using correct get_metadata_api
* dilation added
* Lint fixes
* lint fixes
* some fixes to make export work with onnx 1.2.1
* enabled more tests
* mxnet_export_test file added
* duplicate file deleted
* reduce ops added
* some small fixes
* some lint fixes
* Add tests for inception_v1 and inception_v2
* Add CI runs for export module
* docstring added
* lint fixes, pooling attr fix
* fix
* fix global_pool
* CI run fix
* code cleanup
* lint fix
* some code cleanup
* pad in pooling added
* slicechannel notimplementederror raised
* Added required license comments
* Lint fixes
* lint fix
* lint fix
* lint fix
* lint fix
* Correct license statement
* Adding onnx as a runtime dependency
* Fix import module error for string_types
* Making ONNX a runtime dependency
* fixing some comments
* addressing some comments
* params rename
* lint fixes
* fixes
* spatial disabled, path fixed
* fixing some comments
* Added support for remaining act_type (softsign, sigmoid, softrelu) in Activation operator
* changing import
* adding some comments
* Add squeeze op
* Refactored logic to handle extra node (output label node) for saved MXNet model; added comments
* minor fix for squeeze operator; also added error handling
* identity operator added
* scalar ops added
* Renamed onnx support folders to mark them as public folders; changed underlying files to public or private as per usage; resolved conflicts with the latest
* Added support for L2Normalization op; added some error checking
* added comments and warning
* added comments and warning
* doc API ref added
Description
This PR adds MXNet-to-ONNX exporter APIs to export MXNet-trained models to ONNX protobuf, so that those models can be imported into other frameworks for inference.
Test framework:
Currently, we import ONNX models into MXNet, export them back to ONNX, and import them into MXNet again to verify inference results.
Working models:
@spidydev @anirudhacharya @piiswrong @sandeep-krishnamurthy @nswamy @anirudh2290
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes