Skip to content

Commit

Permalink
LSTM output shape and actvations input format were fixed in onnx fron…
Browse files Browse the repository at this point in the history
…tend
  • Loading branch information
vvchernov committed Aug 2, 2021
1 parent 0c5044b commit 6877742
Showing 1 changed file with 14 additions and 21 deletions.
35 changes: 14 additions & 21 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2174,9 +2174,7 @@ def bidir_lstm_cell(
cls,
input_seqs,
weight_dicts,
f_act,
g_act,
h_act,
acts,
):
"""
Bidirectional LSTM cell
Expand All @@ -2185,17 +2183,17 @@ def bidir_lstm_cell(
forward_outputs, fw_H_t, fw_C_t = _op.lstm_cell(
input_seqs,
**weight_dicts[0],
f_act=f_act,
g_act=g_act,
h_act=h_act,
f_act=acts[0],
g_act=acts[1],
h_act=acts[2],
)

reverse_outputs, rev_H_t, rev_C_t = _op.lstm_cell(
input_seqs,
**weight_dicts[1],
f_act=f_act,
g_act=g_act,
h_act=h_act,
f_act=acts[3],
g_act=acts[4],
h_act=acts[5],
backwards=True,
)

Expand Down Expand Up @@ -2267,11 +2265,8 @@ def _impl_v7(cls, inputs, attr, params):
beta = betas[beta_loc]
beta_loc += 1
acts.append(cls._activation_helper(activation, alpha, beta))
f_act, g_act, h_act = acts
else:
f_act = _op.sigmoid
g_act = _op.tanh
h_act = _op.tanh
acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions

# TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
X_steps = LSTM.unbind(X, axis=0)
Expand Down Expand Up @@ -2318,24 +2313,22 @@ def _impl_v7(cls, inputs, attr, params):
output, H, C = LSTM.bidir_lstm_cell(
input_seqs=X_steps,
weight_dicts=weights_dicts,
f_act=f_act,
g_act=g_act,
h_act=h_act,
acts=acts,
)
else:
# outputs shape = [seqs_num, (batch_size, hidden_size)]
outputs, H, C = _op.lstm_cell(
input_seqs=X_steps,
**weights_dicts[0],
f_act=f_act,
g_act=g_act,
h_act=h_act,
f_act=acts[0],
g_act=acts[1],
h_act=acts[2],
)

# output shape = (seqs_num, num_directions, batch_size, hidden_size)
output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
H = _op.expand_dims(H, axis=1)
C = _op.expand_dims(C, axis=1)
H = _op.expand_dims(H, axis=0)
C = _op.expand_dims(C, axis=0)

return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3)

Expand Down

0 comments on commit 6877742

Please sign in to comment.