Skip to content
This repository has been archived by the owner on Jan 5, 2024. It is now read-only.

Onnx wrapper #59

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open

Onnx wrapper #59

wants to merge 6 commits into from

Conversation

mike-ferguson
Copy link
Member

ONNX Wrapper PR - Changes still incoming, initial commit.

Comment on lines +20 to +21
:param model: a keras model with a function `preprocess_input`
that will later be called on the loaded numpy image
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this comment accurate? looks like it pertains to keras rather than onnx?

return self._extractor(*args, **kwargs)

def get_activations(self, images, layer_names):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

empty newline after function definition, remove

Comment on lines +44 to +47
# create directory to store ONNX models
import os
if not os.path.exists("ONNX Partial Models"):
os.makedirs("ONNX Partial Models")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is very ad-hoc, I would:

  1. move the definition of this directory onto the file level
  2. allow this to be changed from environment variables

i.e. ONNX_MODEL_DIRECTORY = os.getenv('ONNX_MODEL_DIRECTORY', '~/.onnx_models')

Comment on lines +50 to +51
onnx_model = self._model
model_name = self.identifier
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these re-assignments are unnecessary and imo more confusing than helpful, just use self._model and self.identifier in-place

for layer in layer_names:

# handle logits case - get last layer activations
if layer_names[0] == 'logits':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be if layer == 'logits':?

ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I see -- this, together with the above run, should probably be its own method

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use logging rather than print


def get_final_model(framework, batch_size, in_channels, image_size, model, model_name):

# print(batch_size, in_channels, image_size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dead code



def get_final_model(framework, batch_size, in_channels, image_size, model, model_name):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete empty line(s)

layers = get_layers(onnx_model)
return onnx_model, layers

# unknown model format. In the future, I hope to add automatic conversion to ONNX for other platforms
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that future should be now for this PR to be merged, otherwise we're not winning anything

@mschrimpf
Copy link
Member

I looked through it a bit, but stopped after realizing that the current version of this PR only supports pytorch->onnx. Let me know when tf/keras are also supported and tested

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants