diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 517b63dca2731..e0dce1e212c25 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -707,30 +707,35 @@ def gru_cell( outputs_list = [] for x_t in input_seqs if not backwards else reversed(input_seqs): xwt = _op.nn.dense(x_t, w_inp) - i_r, i_z, i_n = _op.split(xwt, 3, axis=1) - w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0) - r_gate = i_r + _op.nn.dense(hidden_state, w_hr) - z_gate = i_z + _op.nn.dense(hidden_state, w_hz) - # TODO(vvchernov): It is assumed that both bias are or not - if b_inp is not None: - b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1) - b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1) - r_gate += b_ir + b_hr - r_gate = rz_act(r_gate) - z_gate += b_iz + b_hz - if linear_before_reset: - n_gate = i_n + b_in + (r_gate * (_op.nn.dense(hidden_state, w_hn) + b_hn)) - else: - n_gate = i_n + b_in + _op.nn.dense((r_gate * hidden_state), w_hn) + b_hn + if linear_before_reset: + hwt = _op.nn.dense(hidden_state, w_hid) + # TODO(vvchernov): It is assumed that both bias are or not + if b_inp is not None: + xwt += b_inp + hwt += b_hid + i_r, i_z, i_n = _op.split(xwt, 3, axis=-1) + h_r, h_z, h_n = _op.split(hwt, 3, axis=-1) + r_gate = rz_act(i_r + h_r) + z_gate = rz_act(i_z + h_z) + n_gate = n_act(i_n + r_gate * h_n) else: - r_gate = rz_act(r_gate) - if linear_before_reset: - n_gate = i_n + (r_gate * (_op.nn.dense(hidden_state, w_hn))) + i_r, i_z, i_n = _op.split(xwt, 3, axis=1) + w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0) + r_gate = i_r + _op.nn.dense(hidden_state, w_hr) + z_gate = i_z + _op.nn.dense(hidden_state, w_hz) + # TODO(vvchernov): It is assumed that both bias are or not + if b_inp is not None: + b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1) + b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1) + r_gate += b_ir + b_hr + z_gate += b_iz + b_hz + i_n += b_in + h_n = _op.nn.dense((r_gate * hidden_state), w_hn) + b_hn else: - n_gate = i_n + _op.nn.dense((r_gate * hidden_state), w_hn) - - z_gate = rz_act(z_gate) - n_gate = n_act(n_gate) + h_n = _op.nn.dense((r_gate * hidden_state), w_hn) + r_gate = rz_act(r_gate) + z_gate = rz_act(z_gate) + n_gate = n_act(i_n + h_n) hidden_state = (hidden_state - n_gate) * z_gate + n_gate