
Support for op_type "Resize" #35

Status: Draft. Wants to merge 15 commits into base: master.
Conversation

@jnnks (Contributor) commented Jun 27, 2022

WIP Implementation for Resize Operator

fixes #34

@jnnks (Contributor, Author) commented Jun 27, 2022

@seanmor5 Is this the right place for the method? :D

@seanmor5 (Contributor)

@jnnks Yes!

@seanmor5 (Contributor) left a comment

@jnnks This looks good so far! I added some review comments. Also, ONNX provides a lot of built-in node tests if you want to use those to verify the operator behaves as expected! Check deserialize_test for examples

(4 resolved review threads on lib/axon_onnx/deserialize.ex)
Co-authored-by: Sean Moriarity <smoriarity.5@gmail.com>
@jnnks (Contributor, Author) commented Jul 1, 2022

Regarding the tests: where can I get the backend-test-tools?
mix test cannot find it when running System.cmd("backend-test-tools", ["generate-data"], [])

@seanmor5 (Contributor) commented Jul 1, 2022

I believe it comes with ONNX runtime Python package!
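For reference, and as an assumption worth verifying: the backend-test-tools console script appears to be installed by the onnx Python package, so a setup along these lines may be enough:

```shell
# Assumption: the onnx pip package provides the backend-test-tools
# console script, and this pip environment is the same one that
# `mix test` shells out to via System.cmd/3
pip install onnx

# Generate the node test data the test suite consumes
backend-test-tools generate-data
```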

@jnnks (Contributor, Author) commented Jul 8, 2022

I am mostly working to comply with the unit tests now. A couple of questions and comments:

  • roi, scales, sizes are input layers (%Axon{}); how can I get the values of interest, like above with params?
  • roi = region of interest? That means we have to crop the Tensor, right? Is there a function I can use for that?
    • Looks like Nx.slice may be helpful here
  • If I give all the values the first test (test_resize_downsample_scales_cubic) is expecting, I get an error I don't understand: ** (ArgumentError) cannot broadcast tensor of dimensions {1, 1, 3, 3} to {1, 1, 8, 8}

Seems like this will take a bit longer than expected :D

@NotQuiteLagom

Is there something I could do to help here? I could use this PR but it seems "frozen". Any major changes in the approach/library?

@seanmor5 (Contributor)

@jnnks @NotQuiteLagom Sorry I let this one fall off. There is an open issue upstream to fix Axon's resize operator. I will add it in later as I have a working implementation from another library, then I will come back to this PR for review :)

@seanmor5 (Contributor) left a comment

@jnnks I just updated Axon's resize implementation. It only works with 2D spatial inputs (so it needs output shapes of rank 4). These changes should help get you closer.

Note that constant! forces the input you're giving it to be a constant, so you can always convert it into a flat list. We enforce this at certain places because Nx can't handle dynamic shapes in certain instances, but usually this constraint isn't a big deal.

(6 resolved review threads on lib/axon_onnx/deserialize.ex)
@jnnks (Contributor, Author) commented Jul 31, 2022

Note that constant! forces the input you're giving it to be a constant, so you can always convert it into a flat list

This does not work, unfortunately. scales and sizes are not constant in the unit tests:

** (ArgumentError) unable to build model from ONNX graph, expected value scales to be constant value, but was :input

And when using the input! macro, the call to to_flat_list fails, as scales is an %Axon{} struct, not a Tensor or number:

 ** (ArgumentError) expected a %Nx.Tensor{} or a number, got: 
     -------------------------------------------------------------------------------
                                          Model
     ===============================================================================
      Layer              Shape   Policy              Parameters   Parameters Memory
     ===============================================================================
      scales ( input )   {4}     p=f32 c=f32 o=f32   0            0 bytes
     -------------------------------------------------------------------------------
     Total Parameters: 0
     Total Parameters Memory: 0 bytes
     Inputs: %{"scales" => {4}}

(1 resolved review thread on lib/axon_onnx/deserialize.ex)
@jnnks (Contributor, Author) commented Aug 1, 2022

@seanmor5 I need your wisdom :)

@jnnks (Contributor, Author) commented Aug 10, 2022

The values of scales and sizes are not known during deserialization; they only become available during inference.
So I am going for an Axon.layer with Axon.Layers.resize inside the lambda function.
This does not work either, since I don't know how to get the output shape as a tuple instead of a Tensor.

(I also removed a lot of unimportant features that will be added again later, when the function actually works)

Comment on lines 2042 to 2053

# resize function
fun = fn input, scales, _opts ->
  # this will return a Tensor, but we need a tuple;
  # we are in a Defn env now, so Nx.to_flat_list() does not work
  output_size =
    input
    |> Nx.shape()
    |> Tuple.to_list()
    |> Nx.tensor()
    |> Nx.multiply(scales)
    |> IO.inspect()

  Axon.Layers.resize(input, size: output_size, method: method)
@jnnks (Contributor, Author) commented:

This fails, since Axon.Layers.resize wants to have a Tuple instead of a Tensor (output_size)

@seanmor5 (Contributor)

@jnnks If the tests depend on non-constant inputs and raise, then it's okay to ignore them in that case. For a lot of functions we cannot implement the tests because of this dependency on constants; a better test is probably an integration test involving a model that uses a Resize.

@jnnks (Contributor, Author) commented Aug 11, 2022

It seems counterintuitive to me to adopt an industry standard and then deviate from it, especially since the standard comes with well-defined unit tests.

I'll convert the predefined onnx models to use constant inputs so we can test against those.

@seanmor5 (Contributor)

We don't really have a choice: Nx does not support dynamic input shapes, so we need to enforce this constraint. In most cases it ends up being fine; there are only a few edge cases you encounter in practice where dynamic shapes are needed, and for the most part they can be mitigated or refactored to meet a static-shape requirement.
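As an illustration of the static-shape route: when scales are compile-time constants, the Resize output shape can be resolved while building the graph. A minimal sketch (the helper name is hypothetical; it assumes ONNX's rule output_dim = floor(input_dim * scale)):

```python
import math

def static_output_shape(input_shape, scales):
    """Resolve a Resize output shape at graph-build time.
    Only possible when scales are known constants; applies the
    ONNX rule output_dim = floor(input_dim * scale)."""
    return tuple(math.floor(dim * scale) for dim, scale in zip(input_shape, scales))

# The {1, 1, 2, 4} input with scales [1.0, 1.0, 0.6, 0.6] from the
# failing test resolves to a static {1, 1, 1, 2} output shape.
print(static_output_shape((1, 1, 2, 4), [1.0, 1.0, 0.6, 0.6]))  # (1, 1, 1, 2)
```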

@jnnks (Contributor, Author) commented Sep 19, 2022

Looks like 'Axon.resize' does not deliver the results expected by the 'onnx' test case (test_resize_downsample_scales_nearest).
I rebuilt one of the onnx model files to use constant inputs, and it seems to work fine; however, the test case fails because of mismatched outputs:

Input

#Nx.Tensor<
  f32[1][1][2][4]
  [
    [
      [
        [1.0, 2.0, 3.0, 4.0],
        [5.0, 6.0, 7.0, 8.0]
      ]
    ]
  ]
>  

Scales

#Nx.Tensor<
  f32[4]
  [1.0, 1.0, 0.6000000238418579, 0.6000000238418579]
>

Output

     ** (RuntimeError) expected #Nx.Tensor<
       f32[1][1][1][2]
       [
         [
           [
             [6.0, 8.0]
           ]
         ]
       ]
     > to be within tolerance of #Nx.Tensor<
       f32[1][1][1][2]
       [
         [
           [
             [1.0, 3.0]
           ]
         ]
       ]
     >

...[1.0, 3.0]... is the output desired by onnx.

Since the scales are both 0.6, I would argue that onnx is expecting the wrong output.
What are your thoughts on this @seanmor5 ?
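For what it's worth, ONNX's expected [1.0, 3.0] is consistent with its default coordinate_transformation_mode of half_pixel and default nearest_mode of round_prefer_floor; if Axon's resize uses a different coordinate mapping, that alone would explain the mismatch. A pure-Python sketch of 1-D nearest resize under those defaults (the function name is hypothetical; applying it first along the height, then along the width, reproduces the expected output):

```python
import math

def nearest_resize_row(row, scale):
    """1-D nearest-neighbor resize using ONNX Resize defaults:
    coordinate_transformation_mode=half_pixel,
    nearest_mode=round_prefer_floor (assumption: these defaults
    apply to test_resize_downsample_scales_nearest)."""
    out_len = math.floor(len(row) * scale)
    out = []
    for x_out in range(out_len):
        # half_pixel: map the output pixel center back into input space
        x_in = (x_out + 0.5) / scale - 0.5
        # round_prefer_floor: round to nearest, ties toward floor
        idx = math.ceil(x_in - 0.5)
        idx = min(max(idx, 0), len(row) - 1)
        out.append(row[idx])
    return out

# Row [1.0, 2.0, 3.0, 4.0] at scale 0.6 yields ONNX's expected values
print(nearest_resize_row([1.0, 2.0, 3.0, 4.0], 0.6))  # [1.0, 3.0]
```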

@seanmor5 (Contributor)

@jnnks I think you might be right, but is there a parameter we're not handling that might affect the behavior here?

@jnnks (Contributor, Author) commented Sep 28, 2022

Not as far as I can see, but I will check the spec again for implicit defaults; that may be causing it.
Maybe it's a recurring pattern in the other tests as well.

@jnnks (Contributor, Author) commented Mar 26, 2023

We have two problems:

  • ONNX expects different values from what Axon.Layers.resize returns [1]
  • ONNX uses the scales and sizes as input values instead of constants in the graph [2]

To make the operator testable we need to do the following:

  • decode the test model and test data
  • decode the test dataset
    • can be done with OnnxTestHelper.pb_to_tensor(pb_path)
  • inject the scales/sizes into the decoded model text file [3]
  • re-encode the model file, also using protoc
  • run Axon.Layers.resize with the test data to get the expected resize result
  • re-encode the test data (not sure how yet)

[1] Different return values
input: #Nx.Tensor<
  f32[1][1][2][4]
  [
    [
      [
        [1.0, 2.0, 3.0, 4.0],
        [5.0, 6.0, 7.0, 8.0]
      ]
    ]
  ]
>

resize with: {"scales",
 #Nx.Tensor<
   f32[4]
   [1.0, 1.0, 0.6000000238418579, 0.6000000238418579]
 >}

expected_output: #Nx.Tensor<
  f32[1][1][1][2]
  [
    [
      [
        [2.6666665077209473, 4.3333330154418945]
      ]
    ]
  ]
>

actual_output: #Nx.Tensor<
  f32[1][1][1][2]
  [
    [
      [
        [4.199999809265137, 5.466666221618652]
      ]
    ]
  ]
>
[2] input values instead of constants
(ArgumentError) unable to build model from ONNX graph, expected value scales to be constant value, but was :input
[3] Nx.Tensor to Protobuf
input:
#Nx.Tensor<
   f32[4]
   [1.0, 1.0, 0.6, 0.6]
 >
output:
  initializer {
    dims: 4
    data_type: 1
    float_data: 1
    float_data: 1
    float_data: 0.6
    float_data: 0.6
    name: "scales"
  }

See: https://github.com/onnx/onnx/blob/4b2d50334914621835cc1e8dadd4fe82b6b9876c/onnx/onnx-ml.proto#L484
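The injection step above could be sketched as plain text generation (a hypothetical helper; it only assumes the textproto layout shown in [3], where data_type 1 = FLOAT):

```python
def initializer_text(name, values):
    """Render a float TensorProto initializer in protobuf text
    format, matching the layout in [3] above (hypothetical helper,
    not part of axon_onnx)."""
    lines = ["initializer {", f"  dims: {len(values)}", "  data_type: 1"]
    for v in values:
        # protoc prints integral floats without a fractional part
        rendered = int(v) if float(v).is_integer() else v
        lines.append(f"  float_data: {rendered}")
    lines.append(f'  name: "{name}"')
    lines.append("}")
    return "\n".join(lines)

print(initializer_text("scales", [1.0, 1.0, 0.6, 0.6]))
```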


Successfully merging this pull request may close these issues.

Unsupported op_type "Resize"
3 participants