Skip to content

Commit

Permalink
[ONNX] [Test] fix GRU modification and reduce tolerance for RNN tests (
Browse files Browse the repository at this point in the history
…apache#8923)

* fix high tolerance for RNN tests

* random seed was added to GRU test reproduction

Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>
  • Loading branch information
2 people authored and ylc committed Jan 13, 2022
1 parent 7ffc71e commit 4a0e472
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,12 +726,13 @@ def gru_cell(
b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1)
b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1)
r_gate += b_ir + b_hr
r_gate = rz_act(r_gate)
z_gate += b_iz + b_hz
i_n += b_in
h_n = _op.nn.dense((r_gate * hidden_state), w_hn) + b_hn
else:
r_gate = rz_act(r_gate)
h_n = _op.nn.dense((r_gate * hidden_state), w_hn)
r_gate = rz_act(r_gate)
z_gate = rz_act(z_gate)
n_gate = n_act(i_n + h_n)

Expand Down
26 changes: 24 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3323,6 +3323,8 @@ def verify_rnn(
use_peep=False,
linear_before_reset=False,
directions=1,
rtol=1e-5,
atol=1e-5,
target=None,
dev=None,
):
Expand Down Expand Up @@ -3445,7 +3447,7 @@ def register(name, shape, proto_type):
model = helper.make_model(graph, producer_name="rnn_test")

verify_with_ort_with_inputs(
model, input_values, output_shapes, atol=1e-2, rtol=1e-2, target=target, dev=dev
model, input_values, output_shapes, atol=atol, rtol=rtol, target=target, dev=dev
)


Expand Down Expand Up @@ -3601,6 +3603,8 @@ def test_lstm(target, dev):

@tvm.testing.parametrize_targets
def test_gru(target, dev):
# Set seed for test reproduction
np.random.seed(137)
for directions in [1, 2]:
# No bias.
verify_rnn(
Expand All @@ -3611,10 +3615,12 @@ def test_gru(target, dev):
use_bias=False,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
# large batch.
# large batch. linear before reset
verify_rnn(
seq_length=4,
batch_size=8,
Expand All @@ -3636,6 +3642,8 @@ def test_gru(target, dev):
use_bias=True,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
Expand All @@ -3648,6 +3656,8 @@ def test_gru(target, dev):
use_bias=True,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
Expand All @@ -3660,6 +3670,8 @@ def test_gru(target, dev):
use_bias=True,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
Expand All @@ -3672,6 +3684,8 @@ def test_gru(target, dev):
use_bias=True,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
Expand All @@ -3687,6 +3701,8 @@ def test_gru(target, dev):
activations=["HardSigmoid", "Softsign"] * directions,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
Expand All @@ -3702,6 +3718,8 @@ def test_gru(target, dev):
betas=[0.3, 0.0] * directions,
rnn_type="GRU",
directions=directions,
rtol=1e-8,
atol=1e-8,
target=target,
dev=dev,
)
Expand All @@ -3717,6 +3735,8 @@ def test_gru(target, dev):
betas=[0.3, 0.1] * directions,
rnn_type="GRU",
directions=directions,
rtol=1e-8,
atol=1e-8,
target=target,
dev=dev,
)
Expand All @@ -3731,6 +3751,8 @@ def test_gru(target, dev):
use_initial_state=True,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
Expand Down

0 comments on commit 4a0e472

Please sign in to comment.