-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[v1.x] ONNX export support for RNN and sum_axis #20226
Conversation
Hey @waytrue17 , Thanks for submitting the PR
CI supported jobs: [edge, unix-gpu, centos-cpu, unix-cpu, clang, centos-gpu, windows-gpu, windows-cpu, miscellaneous, sanity, website] Note: |
@pytest.mark.parametrize('input_size', [16, 32, 64]) | ||
@pytest.mark.parametrize('num_layers', [1, 2]) | ||
@pytest.mark.parametrize('batch_size', [1, 2, 4]) | ||
@pytest.mark.parametrize('seq_length', [16, 32]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why removing 32?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Saw the same assertion issue with large seq_length/state_size/input_size. Will need to further decrease one of them to pass rnn
, just like what we did before for lstm
and gru
.
else: | ||
op_export_test('rnn', M, [x, param, state], tmp_path) | ||
op_export_test('rnn', M, [x, param, state], tmp_path, atol=1e-2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we know how large was the difference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From my local runs, the diff is between 1e-3 to 1e-2.
"""Map MXNet's sum_axis operator. | ||
sum_axis is equivalent to sum in MXNet | ||
""" | ||
return convert_sum(node, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of wrapping convert_sum() here, couldn't we just add the decorator to register "sum_axis" to convert_sum() ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registered it under convert_sum. Thanks for the good point!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Description
Add onnx conversion logic for RNN and sum_axis. Add unittest for the two.
Checklist
Essentials
Changes
Comments