Skip to content

Commit

Permalink
common GRU was additionaly updated. tuned pytorch GRU was strongly ac…
Browse files Browse the repository at this point in the history
…celerated
  • Loading branch information
vvchernov committed Aug 18, 2021
1 parent b2e5db0 commit 36de7b2
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,30 +707,35 @@ def gru_cell(
outputs_list = []
for x_t in input_seqs if not backwards else reversed(input_seqs):
xwt = _op.nn.dense(x_t, w_inp)
i_r, i_z, i_n = _op.split(xwt, 3, axis=1)
w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0)
r_gate = i_r + _op.nn.dense(hidden_state, w_hr)
z_gate = i_z + _op.nn.dense(hidden_state, w_hz)
# TODO(vvchernov): It is assumed that both bias are or not
if b_inp is not None:
b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1)
b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1)
r_gate += b_ir + b_hr
r_gate = rz_act(r_gate)
z_gate += b_iz + b_hz
if linear_before_reset:
n_gate = i_n + b_in + (r_gate * (_op.nn.dense(hidden_state, w_hn) + b_hn))
else:
n_gate = i_n + b_in + _op.nn.dense((r_gate * hidden_state), w_hn) + b_hn
if linear_before_reset:
hwt = _op.nn.dense(hidden_state, w_hid)
# TODO(vvchernov): It is assumed that both bias are or not
if b_inp is not None:
xwt += b_inp
hwt += b_hid
i_r, i_z, i_n = _op.split(xwt, 3, axis=-1)
h_r, h_z, h_n = _op.split(hwt, 3, axis=-1)
r_gate = rz_act(i_r + h_r)
z_gate = rz_act(i_z + h_z)
n_gate = n_act(i_n + r_gate * h_n)
else:
r_gate = rz_act(r_gate)
if linear_before_reset:
n_gate = i_n + (r_gate * (_op.nn.dense(hidden_state, w_hn)))
i_r, i_z, i_n = _op.split(xwt, 3, axis=1)
w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0)
r_gate = i_r + _op.nn.dense(hidden_state, w_hr)
z_gate = i_z + _op.nn.dense(hidden_state, w_hz)
# TODO(vvchernov): It is assumed that both bias are or not
if b_inp is not None:
b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1)
b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1)
r_gate += b_ir + b_hr
z_gate += b_iz + b_hz
i_n += b_in
h_n = _op.nn.dense((r_gate * hidden_state), w_hn) + b_hn
else:
n_gate = i_n + _op.nn.dense((r_gate * hidden_state), w_hn)

z_gate = rz_act(z_gate)
n_gate = n_act(n_gate)
h_n = _op.nn.dense((r_gate * hidden_state), w_hn)
r_gate = rz_act(r_gate)
z_gate = rz_act(z_gate)
n_gate = n_act(i_n + h_n)

hidden_state = (hidden_state - n_gate) * z_gate + n_gate

Expand Down

0 comments on commit 36de7b2

Please sign in to comment.