[Frontend] Unified LSTM cell (#8599)
* fuse dense sum

* remove excess copying

* dev LSTM in ONNX

* alternative implementation of LSTM in onnx frontend. It is quicker than the current one without tuning

* LSTM_dev2 was implemented in onnx frontend

* LSTM dev in pytorch frontend

* LSTM cell implementation was transferred to common place. Unnecessary code was removed

* lint fixes

* Weights permutation for LSTM layer in onnx frontend

* LSTM cell description was added

* arguments and values were renamed. Descriptions of some methods were added

* LSTM output shape and activations input format were fixed in onnx frontend

* empty. tvm-ci test

* unbind method was transferred from onnx frontend to common.py

* unbind method was transferred from pytorch frontend to common.py

* lstm cell was transferred from op/layers.py to frontend/common.py

* clean up weight dictionary initialization

* fix pytorch frontend wrapper over unbind method

* minor fix of comments

* empty. tvm-ci test restart

* empty. tvm-ci test restart

Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>
vvchernov authored Aug 6, 2021
1 parent dc5da05 commit 2c124c9
Showing 4 changed files with 315 additions and 243 deletions.
2 changes: 1 addition & 1 deletion python/tvm/contrib/target/onnx.py
@@ -655,7 +655,7 @@ def convert_attributes(cls, attrs):


class Cast(OpConverter):
""" Operator converter for Cast."""
"""Operator converter for Cast."""

@classmethod
def convert_attributes(cls, attrs):
125 changes: 125 additions & 0 deletions python/tvm/relay/frontend/common.py
@@ -624,3 +624,128 @@ def to_int_list(np_array):
cause problems in relay/TOPI.
"""
return [int(x) for x in np_array]


def unbind(data, axis=0):
"""
Unbind was taken from the PyTorch frontend. The operation removes the given tensor
dimension and returns a tuple of all slices along that dimension, with the axis removed.
TODO (vvchernov): such an operation is needed on the relay side to reduce the time
spent on squeeze operations.
Parameters
----------
data : relay.Expr
Input tensor
axis : int
Axis along which tensor is split.
Returns
-------
result : List[relay.Expr]
The sequence of computed tensors
"""
shape = infer_shape(data)
if axis >= len(shape):
msg = "Please check input dim, it shouldn't be greater than or equal to rank."
raise AttributeError(msg)

selections = shape[axis]
res_split = _op.split(data, selections, axis)
ret = []
for i in range(selections):
ret.append(_op.squeeze(res_split[i], axis=[axis]))
return _expr.TupleWrapper(_expr.Tuple(ret), selections)
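
For illustration only (this paragraph and sketch are not part of the diff): applying unbind along axis 0 of a made-up (3, 2, 4) input yields three (2, 4) tensors. The variable names and shapes below are invented; only unbind and its tvm.relay.frontend.common location come from this commit.

import tvm
from tvm import relay
from tvm.relay.frontend.common import unbind  # helper added by this commit

# hypothetical input: a sequence of 3 steps, batch 2, feature size 4
x = relay.var("x", shape=(3, 2, 4), dtype="float32")
slices = unbind(x, axis=0)  # TupleWrapper of three tensors, axis 0 removed

# wrap the tuple in a module and run type inference to check the (2, 4) slice shapes
mod = tvm.IRModule.from_expr(slices.astuple())
print(relay.transform.InferType()(mod))
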


def lstm_cell(
input_seqs,
hidden_state,
cell_state,
w_inp,
w_hid,
b_inp=None,
b_hid=None,
proj=None,
p_i=None,
p_f=None,
p_o=None,
f_act=_op.sigmoid,
g_act=_op.tanh,
h_act=_op.tanh,
backwards=False,
):
"""
Common implementation of LSTM cell for all frontends of TVM
TODO (vvchernov): currently it is used by onnx and pytorch. Extend for other frontends
Parameters
----------
input_seqs : List[relay.Expr]
The sequence of input tensors
Each input tensor should be 2d until issue #8412 is resolved
Shape = (batch, feature_size)
hidden_state : relay.Expr
Hidden state. shape = (batch, hidden_size)
cell_state : relay.Expr
Cell state. shape = (batch, hidden_size)
w_inp, w_hid : relay.Expr
weight matrices. wi shape = (4 * hidden_size, feature_size)
wh shape = (4 * hidden_size, hidden_size or proj_size)
NOTE: wi = (w_ii|w_if|w_ig|w_io) for input, forget, cell and output gates.
The order is important for correct LSTM calculation!
b_inp, b_hid : relay.Expr
bias matrices. The same order of internal parts as for weights. shape = (4 * hidden_size)
proj : relay.Expr
projection matrix. shape = (proj_size, hidden_size)
p_i, p_f, p_o : relay.Expr
peephole LSTM matrices. shape = (batch, hidden_size)
f_act, g_act, h_act : relay.op
activation functions
backwards : bool
Flag for reverse pass of LSTM
Returns
-------
result : List[relay.Expr], relay.Expr, relay.Expr
The sequence of computed results, the final hidden state and the final cell state
"""

outputs_list = []
for x_t in input_seqs if not backwards else reversed(input_seqs):
# x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size)
step = _op.concatenate([x_t, hidden_state], axis=1)
cat_w = _op.concatenate([w_inp, w_hid], axis=1)
# Instead of nn.dense(x_t, w_inp) + nn.dense(hidden_state, w_hid)
# nn.dense(step, cat_w) is used
# gates shape = (batch, 4 * hidden_size)
gates = _op.nn.dense(step, cat_w)
# Add biases
if b_inp is not None:
gates += b_inp
if b_hid is not None:
gates += b_hid
# any gate shape = (batch, hidden_size)
inp_gate, fgt_gate, cell_gate, otp_gate = _op.split(gates, 4, axis=-1)

if p_i is not None and p_f is not None:
inp_gate = f_act(inp_gate + p_i * cell_state)
fgt_gate = f_act(fgt_gate + p_f * cell_state)
else:
inp_gate = f_act(inp_gate)
fgt_gate = f_act(fgt_gate)

cell_gate = g_act(cell_gate)
cell_state = fgt_gate * cell_state + inp_gate * cell_gate
if p_o is not None:
otp_gate = f_act(otp_gate + p_o * cell_state)
else:
otp_gate = f_act(otp_gate)

hidden_state = otp_gate * h_act(cell_state)

if proj is not None:
hidden_state = _op.nn.dense(hidden_state, proj)

outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)]

return outputs_list, hidden_state, cell_state
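
A minimal sketch (not part of the diff) of driving the new lstm_cell directly from relay, assuming made-up sizes and variable names; only lstm_cell, unbind and the i|f|c|o weight packing described above come from this commit. Biases, projection and peepholes are left at their None defaults.

import tvm
from tvm import relay
from tvm.relay.frontend.common import lstm_cell, unbind  # helpers added by this commit

seq_len, batch, feature_size, hidden_size = 3, 2, 4, 5

X = relay.var("X", shape=(seq_len, batch, feature_size), dtype="float32")
h0 = relay.var("h0", shape=(batch, hidden_size), dtype="float32")
c0 = relay.var("c0", shape=(batch, hidden_size), dtype="float32")
# gate blocks must already be packed in input|forget|cell|output order
w_inp = relay.var("w_inp", shape=(4 * hidden_size, feature_size), dtype="float32")
w_hid = relay.var("w_hid", shape=(4 * hidden_size, hidden_size), dtype="float32")

# lstm_cell consumes a list of 2d steps, so the sequence axis is unbound first
outputs, h_t, c_t = lstm_cell(unbind(X, axis=0), h0, c0, w_inp, w_hid)

# stack the per-step hidden states back into (seq_len, batch, hidden_size)
body = relay.Tuple([relay.stack(outputs, axis=0), h_t, c_t])
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(body), body))
print(relay.transform.InferType()(mod))
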
180 changes: 88 additions & 92 deletions python/tvm/relay/frontend/onnx.py
@@ -46,6 +46,8 @@
infer_type,
infer_value,
new_var,
unbind,
lstm_cell,
)

__all__ = ["from_onnx"]
@@ -2155,58 +2157,44 @@ class LSTM(RNN):
"""Operator converter for LSTM"""

@classmethod
def generate_lstm(
cls, X_steps, H_t, C_t, W, R, B, p_i, p_f, p_o, f_act, g_act, h_act, backwards=False
def bidir_lstm_cell(
cls,
input_seqs,
weight_dicts,
acts,
):
"""Create an unrolled lstm loop.
See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math.
"""
h_list = []
seq_length = len(X_steps)
for i in range(seq_length):
step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)]
step = _op.squeeze(step, axis=[0])
gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R)
if B is not None:
WB, RB = _op.split(B, 2)
gates += WB + RB
i, o, f, c = _op.split(gates, 4, axis=-1)

if p_i != 0:
i = f_act(i + p_i * C_t)
else:
i = f_act(i)

if p_f != 0:
f = f_act(f + p_f * C_t)
else:
f = f_act(f)

c = g_act(c)
C = f * C_t + i * c
if p_o != 0:
o = f_act(o + p_o * C)
else:
o = f_act(o)

H = o * h_act(C)

H_t = H
C_t = C
h_list.append(_op.expand_dims(H, axis=0))
Bidirectional LSTM cell
"""
seq_len = len(input_seqs)
forward_outputs, fw_H_t, fw_C_t = lstm_cell(
input_seqs,
**weight_dicts[0],
f_act=acts[0],
g_act=acts[1],
h_act=acts[2],
)

if backwards:
# Canonical view is hidden states from the first token not last
h_list = h_list[::-1]
reverse_outputs, rev_H_t, rev_C_t = lstm_cell(
input_seqs,
**weight_dicts[1],
f_act=acts[3],
g_act=acts[4],
h_act=acts[5],
backwards=True,
)

# Concatenate outputs and add back in direction axis.
concatenated = _op.concatenate(h_list, 0)
output = _op.expand_dims(concatenated, axis=1)
H_t = _op.expand_dims(H_t, axis=0)
C_t = _op.expand_dims(C_t, axis=0)
final_outputs = []
for i in range(seq_len):
final_outputs.append(
_op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
)

return output, H_t, C_t
return (
_op.stack(final_outputs, axis=0),
_op.stack([fw_H_t, rev_H_t], axis=0),
_op.stack([fw_C_t, rev_C_t], axis=0),
)

@classmethod
def _impl_v7(cls, inputs, attr, params):
@@ -2237,12 +2225,6 @@ def _impl_v7(cls, inputs, attr, params):
Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
if Cp_0 is None:
Cp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
if Bp is None:
Bp = _op.zeros((num_directions, hidden_size * 8), W_dtype)
if Pp is not None:
p_i, p_o, p_f = _op.split(Pp, 3, axis=1)
else:
p_i = p_o = p_f = _op.zeros((num_directions, hidden_size), W_dtype)

if "activations" in attr:
activations = attr["activations"]
@@ -2273,53 +2255,67 @@ def _impl_v7(cls, inputs, attr, params):
else:
acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions

X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
result_output = []
result_H = []
result_C = []
# TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
X_steps = unbind(X, axis=0)

H_ts = _op.split(Hp_0, num_directions)
C_ts = _op.split(Cp_0, num_directions)
Ws = _op.split(Wp, num_directions)
Rs = _op.split(Rp, num_directions)
Bs = _op.split(Bp, num_directions)
p_is = _op.split(p_i, num_directions)
p_fs = _op.split(p_f, num_directions)
p_os = _op.split(p_o, num_directions)
for i in range(num_directions):
H_t = _op.squeeze(H_ts[i], axis=[0])
C_t = _op.squeeze(C_ts[i], axis=[0])
W = _op.squeeze(Ws[i], axis=[0])
R = _op.squeeze(Rs[i], axis=[0])
B = _op.squeeze(Bs[i], axis=[0])
p_i = _op.squeeze(p_is[i], axis=[0])
p_f = _op.squeeze(p_fs[i], axis=[0])
p_o = _op.squeeze(p_os[i], axis=[0])

f_act, g_act, h_act = acts[i * 3 : (i + 1) * 3]
output, H, C = LSTM.generate_lstm(
X_steps=X_steps,
H_t=H_t,
C_t=C_t,
W=W,
R=R,
B=B,
p_i=p_i,
p_f=p_f,
p_o=p_o,
f_act=f_act,
g_act=g_act,
h_act=h_act,
backwards=i == 1,
)
if Bp is not None:
Bs = _op.split(Bp, num_directions)
if Pp is not None:
p_i, p_o, p_f = _op.split(Pp, 3, axis=1)

result_output.append(output)
result_H.append(H)
result_C.append(C)
p_is = _op.split(p_i, num_directions)
p_fs = _op.split(p_f, num_directions)
p_os = _op.split(p_o, num_directions)

output = _op.concatenate(result_output, axis=1)
H = _op.concatenate(result_H, axis=0)
C = _op.concatenate(result_C, axis=0)
weights_dicts = []
for i in range(num_directions):
weights_dict = {}

weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0])
weights_dict["cell_state"] = _op.squeeze(C_ts[i], axis=[0])

# Weights permutation: onnx format i-o-f-c, lstm cell format i-f-c-o
mati, mato, matf, matc = _op.split(_op.squeeze(Ws[i], axis=[0]), 4)
weights_dict["w_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0)
mati, mato, matf, matc = _op.split(_op.squeeze(Rs[i], axis=[0]), 4)
weights_dict["w_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0)
if Bp is not None:
Bi, Bh = _op.split(Bs[i], 2, -1)
mati, mato, matf, matc = _op.split(_op.squeeze(Bi, axis=[0]), 4)
weights_dict["b_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0)
mati, mato, matf, matc = _op.split(_op.squeeze(Bh, axis=[0]), 4)
weights_dict["b_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0)
if Pp is not None:
weights_dict["p_i"] = _op.squeeze(p_is[i], axis=[0])
weights_dict["p_f"] = _op.squeeze(p_fs[i], axis=[0])
weights_dict["p_o"] = _op.squeeze(p_os[i], axis=[0])
weights_dicts.append(weights_dict)

if num_directions == 2:
output, H, C = LSTM.bidir_lstm_cell(
input_seqs=X_steps,
weight_dicts=weights_dicts,
acts=acts,
)
else:
# outputs shape = [seqs_num, (batch_size, hidden_size)]
outputs, H, C = lstm_cell(
input_seqs=X_steps,
**weights_dicts[0],
f_act=acts[0],
g_act=acts[1],
h_act=acts[2],
)

# output shape = (seqs_num, num_directions, batch_size, hidden_size)
output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
H = _op.expand_dims(H, axis=0)
C = _op.expand_dims(C, axis=0)

return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3)
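
The gate permutation in the converter above (ONNX packs the W, R and B gate blocks as i|o|f|c, while lstm_cell expects i|f|c|o) can be checked in isolation; the following numpy sketch is illustrative only, with invented sizes, and is not code from the commit.

import numpy as np

hidden_size, feature_size = 3, 4

# stand-in for one direction of the ONNX W tensor: gate blocks packed as i | o | f | c,
# each block filled with a marker value so the reordering is visible
w_onnx = np.concatenate(
    [np.full((hidden_size, feature_size), marker) for marker in (1.0, 4.0, 2.0, 3.0)]
)

# split into the four gate blocks and re-concatenate in the i | f | c | o order
mat_i, mat_o, mat_f, mat_c = np.split(w_onnx, 4, axis=0)
w_cell = np.concatenate([mat_i, mat_f, mat_c, mat_o], axis=0)

# the first row of each block now reads 1 (i), 2 (f), 3 (c), 4 (o)
print(w_cell[::hidden_size, 0])  # [1. 2. 3. 4.]
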
