-
Notifications
You must be signed in to change notification settings - Fork 27
Onnx wrapper #59
base: master
Are you sure you want to change the base?
Onnx wrapper #59
Conversation
… instead of Martin's)
:param model: a keras model with a function `preprocess_input` | ||
that will later be called on the loaded numpy image |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this comment accurate? looks like it pertains to keras rather than onnx?
return self._extractor(*args, **kwargs) | ||
|
||
def get_activations(self, images, layer_names): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
empty newline after function definition, remove
# create directory to store ONNX models | ||
import os | ||
if not os.path.exists("ONNX Partial Models"): | ||
os.makedirs("ONNX Partial Models") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is very ad-hoc, I would:
- move the definition of this directory onto the file level
- allow this to be changed from environment variables
i.e. ONNX_MODEL_DIRECTORY = os.getenv('ONNX_MODEL_DIRECTORY', '~/.onnx_models')
onnx_model = self._model | ||
model_name = self.identifier |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these re-assignments are unnecessary and imo more confusing than helpful, just use self._model
and self.identifier
in-place
for layer in layer_names: | ||
|
||
# handle logits case - get last layer activations | ||
if layer_names[0] == 'logits': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't this be if layer == 'logits':
?
ort_outs = ort_session.run(None, ort_inputs) | ||
|
||
# compare ONNX Runtime and PyTorch results | ||
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh I see -- this, together with the above run, should probably be its own method
# compare ONNX Runtime and PyTorch results | ||
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) | ||
|
||
print("Exported model has been tested with ONNXRuntime, and the result looks good!") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use logging
rather than print
|
||
def get_final_model(framework, batch_size, in_channels, image_size, model, model_name): | ||
|
||
# print(batch_size, in_channels, image_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dead code
|
||
|
||
def get_final_model(framework, batch_size, in_channels, image_size, model, model_name): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete empty line(s)
layers = get_layers(onnx_model) | ||
return onnx_model, layers | ||
|
||
# unknown model format. In the future, I hope to add automatic conversion to ONNX for other platforms |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that future should be now for this PR to be merged, otherwise we're not winning anything
I looked through it a bit, but stopped after realizing that the current version of this PR only supports pytorch->onnx. Let me know when tf/keras are also supported and tested |
ONNX Wrapper PR - Changes still incoming, initial commit.