Skip to content

Commit

Permalink
fix high tolerance for RNN tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vvchernov committed Sep 3, 2021
1 parent 3f3c067 commit 2c8fe0e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,12 +726,13 @@ def gru_cell(
b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1)
b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1)
r_gate += b_ir + b_hr
r_gate = rz_act(r_gate)
z_gate += b_iz + b_hz
i_n += b_in
h_n = _op.nn.dense((r_gate * hidden_state), w_hn) + b_hn
else:
r_gate = rz_act(r_gate)
h_n = _op.nn.dense((r_gate * hidden_state), w_hn)
r_gate = rz_act(r_gate)
z_gate = rz_act(z_gate)
n_gate = n_act(i_n + h_n)

Expand Down
26 changes: 23 additions & 3 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3311,6 +3311,8 @@ def verify_rnn(
use_peep=False,
linear_before_reset=False,
directions=1,
rtol=1e-5,
atol=1e-5,
target=None,
dev=None,
):
Expand Down Expand Up @@ -3433,7 +3435,7 @@ def register(name, shape, proto_type):
model = helper.make_model(graph, producer_name="rnn_test")

verify_with_ort_with_inputs(
model, input_values, output_shapes, atol=1e-2, rtol=1e-2, target=target, dev=dev
model, input_values, output_shapes, atol=atol, rtol=rtol, target=target, dev=dev
)


Expand Down Expand Up @@ -3599,19 +3601,21 @@ def test_gru(target, dev):
use_bias=False,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
# large batch.
# large batch. linear after reset = True
verify_rnn(
seq_length=4,
batch_size=8,
input_size=16,
hidden_size=32,
use_bias=True,
rnn_type="GRU",
linear_before_reset=True,
directions=directions,
linear_before_reset=True,
target=target,
dev=dev,
)
Expand All @@ -3624,6 +3628,8 @@ def test_gru(target, dev):
use_bias=True,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
Expand All @@ -3636,6 +3642,8 @@ def test_gru(target, dev):
use_bias=True,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
Expand All @@ -3648,6 +3656,8 @@ def test_gru(target, dev):
use_bias=True,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
Expand All @@ -3660,6 +3670,8 @@ def test_gru(target, dev):
use_bias=True,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
Expand All @@ -3675,6 +3687,8 @@ def test_gru(target, dev):
activations=["HardSigmoid", "Softsign"] * directions,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
Expand All @@ -3690,6 +3704,8 @@ def test_gru(target, dev):
betas=[0.3, 0.0] * directions,
rnn_type="GRU",
directions=directions,
rtol=1e-8,
atol=1e-8,
target=target,
dev=dev,
)
Expand All @@ -3705,6 +3721,8 @@ def test_gru(target, dev):
betas=[0.3, 0.1] * directions,
rnn_type="GRU",
directions=directions,
rtol=1e-8,
atol=1e-8,
target=target,
dev=dev,
)
Expand All @@ -3719,6 +3737,8 @@ def test_gru(target, dev):
use_initial_state=True,
rnn_type="GRU",
directions=directions,
rtol=1e-6,
atol=1e-6,
target=target,
dev=dev,
)
Expand Down

0 comments on commit 2c8fe0e

Please sign in to comment.