diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 077b942ddf01..ce048105ae8b 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -658,6 +658,90 @@ def unbind(data, axis=0):
     return _expr.TupleWrapper(_expr.Tuple(ret), selections)
 
 
+def gru_cell(
+    input_seqs,
+    hidden_state,
+    w_inp,
+    w_hid,
+    b_inp=None,
+    b_hid=None,
+    rz_act=_op.sigmoid,
+    n_act=_op.tanh,
+    backwards=False,
+    linear_before_reset=True,
+):
+    """
+    Common implementation of the GRU cell for all TVM frontends
+    TODO(vvchernov): currently it is used by the pytorch and ONNX frontends. Extend for others
+
+    Parameters
+    ----------
+    input_seqs : List[relay.Expr]
+        The sequence of input tensors
+        Input tensors should be 2d until issue #8412 is resolved
+        Shape = (batch, feature_size)
+    hidden_state : relay.Expr
+        Hidden state. shape = (batch_size, hidden_size)
+    w_inp, w_hid : relay.Expr
+        weight matrices. wi shape = (3 * hidden_size, feature_size)
+        wh shape = (3 * hidden_size, hidden_size)
+        NOTE: wi = (w_ir|w_iz|w_in) for the reset, update and new gates.
+        The order is important for correct GRU calculation!
+    b_inp, b_hid : relay.Expr
+        bias vectors. The same order of internal parts as for the weights. shape = (3 * hidden_size)
+    rz_act : relay.op
+        activation function for the reset and update gates. it is sigmoid by default
+    n_act : relay.op
+        activation function for the new gate. it is tanh by default
+    backwards : bool
+        Flag for reverse pass of GRU
+    linear_before_reset : bool
+        If True, the hidden state is multiplied by w_hid before the reset gate is applied
+        (pytorch semantics); otherwise the reset gate is applied to the hidden state first
+
+    Returns
+    -------
+    result : List[relay.Expr], relay.Expr
+        The sequence of computed outputs and the final hidden state
+    """
+
+    outputs_list = []
+    for x_t in input_seqs if not backwards else reversed(input_seqs):
+        xwt = _op.nn.dense(x_t, w_inp)
+        if linear_before_reset:
+            hwt = _op.nn.dense(hidden_state, w_hid)
+            if b_inp is not None and b_hid is not None:
+                xwt += b_inp
+                hwt += b_hid
+            i_r, i_z, i_n = _op.split(xwt, 3, axis=-1)
+            h_r, h_z, h_n = _op.split(hwt, 3, axis=-1)
+            r_gate = rz_act(i_r + h_r)
+            z_gate = rz_act(i_z + h_z)
+            n_gate = n_act(i_n + r_gate * h_n)
+        else:
+            i_r, i_z, i_n = _op.split(xwt, 3, axis=-1)
+            w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0)
+            r_gate = i_r + _op.nn.dense(hidden_state, w_hr)
+            z_gate = i_z + _op.nn.dense(hidden_state, w_hz)
+            has_biases = b_inp is not None and b_hid is not None
+            if has_biases:
+                b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1)
+                b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1)
+                r_gate += b_ir + b_hr
+                z_gate += b_iz + b_hz
+                i_n += b_in
+            # NOTE: the reset gate must be activated before it scales the hidden state
+            r_gate = rz_act(r_gate)
+            z_gate = rz_act(z_gate)
+            h_n = _op.nn.dense((r_gate * hidden_state), w_hn)
+            if has_biases:
+                h_n += b_hn
+            n_gate = n_act(i_n + h_n)
+
+        hidden_state = (hidden_state - n_gate) * z_gate + n_gate
+
+        outputs_list.append(hidden_state)  # [seq_num, (batch, hidden_size)]
+
+    return outputs_list, hidden_state
+
+
 def lstm_cell(
     input_seqs,
     hidden_state,
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 0f78c32ef59f..5471f67ea106 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -47,6 +47,7 @@
     infer_value,
     new_var,
     unbind,
+    gru_cell,
     lstm_cell,
 )
 
@@ -2349,56 +2350,41 @@ class GRU(RNN):
     """Operator convert for GRU"""
 
     @classmethod
-    def generate_gru(
-        cls, X_steps, H_t, W, R, B, linear_before_reset, f_act, g_act, W_dtype, backwards=False
+    def bidir_gru_cell(
+        cls,
+        input_seqs,
+        weight_dicts,
+        acts,
     ):
-        """Create an unrolled gru loop.
-
-        See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math.
-        """
-        h_list = []
-        seq_length = len(X_steps)
-        for i in range(seq_length):
-            step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)]
-            step = _op.squeeze(step, axis=[0])
-            current = _op.nn.dense(step, W)
-            cz, cr, ch = _op.split(current, 3, axis=1)
-            rz, rr, rh = _op.split(R, 3, axis=0)
-            z = cz + _op.nn.dense(H_t, rz)
-            r = cr + _op.nn.dense(H_t, rr)
-            if B is not None:
-                WB, RB = _op.split(B, 2)
-                wbz, wbr, wbh = _op.split(WB, 3, axis=-1)
-                rbz, rbr, rbh = _op.split(RB, 3, axis=-1)
-                z += wbz + rbz
-                r += wbr + rbr
-                if linear_before_reset:
-                    h = ch + (r * (_op.nn.dense(H_t, rh) + rbh)) + wbh
-                else:
-                    h = ch + _op.nn.dense((r * H_t), rh) + wbh + rbh
-            else:
-                if linear_before_reset:
-                    h = ch + (r * (_op.nn.dense(H_t, rh)))
-                else:
-                    h = ch + _op.nn.dense((r * H_t), rh)
-
-            z = f_act(z)
-            r = f_act(r)
-            h = g_act(h)
-
-            H_t = ((_expr.const(1, dtype=W_dtype) - z) * h) + (z * H_t)
-            h_list.append(_op.expand_dims(H_t, axis=0))
+        """
+        Bidirectional GRU cell
+        """
+        seq_len = len(input_seqs)
+        forward_outputs, fw_H_t = gru_cell(
+            input_seqs,
+            **weight_dicts[0],
+            rz_act=acts[0],
+            n_act=acts[1],
+        )
 
-        if backwards:
-            # Canonical view is hidden states from the first token not last
-            h_list = h_list[::-1]
+        reverse_outputs, rev_H_t = gru_cell(
+            input_seqs,
+            **weight_dicts[1],
+            rz_act=acts[2],
+            n_act=acts[3],
+            backwards=True,
+        )
 
-        # Concatenate outputs and add back in direction axis.
-        concatenated = _op.concatenate(h_list, 0)
-        output = _op.expand_dims(concatenated, axis=1)
-        H_t = _op.expand_dims(H_t, axis=0)
+        final_outputs = []
+        for i in range(seq_len):
+            final_outputs.append(
+                _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
+            )
 
-        return output, H_t
+        return (
+            _op.stack(final_outputs, axis=0),
+            _op.stack([fw_H_t, rev_H_t], axis=0),
+        )
 
     @classmethod
     def _impl_v7(cls, inputs, attr, params):
@@ -2416,20 +2402,14 @@ def _impl_v7(cls, inputs, attr, params):
         W_dtype = infer_type(Wp).checked_type.dtype
 
         if num_directions not in [1, 2]:
-            raise NotImplementedError(
-                f"Directions for GRUs should be either 1 or 2 got {num_directions}"
-            )
+            raise ValueError("num_directions must be either 1 or 2!")
 
         X_shape = infer_shape(X)
         hidden_size = infer_shape(Rp)[-1]
         batch_size = X_shape[1]
 
-        # Initialize state if not provided.
-        # Otherwise remove bidirectional axis.
         if Hp_0 is None:
             Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
-        if Bp is None:
-            Bp = _op.zeros((num_directions, hidden_size * 6), W_dtype)
 
         if "activations" in attr:
             activations = attr["activations"]
@@ -2460,39 +2440,54 @@ def _impl_v7(cls, inputs, attr, params):
         else:
             acts = [_op.sigmoid, _op.tanh] * 2
 
-        result_output = []
-        result_H = []
+        # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
+        X_steps = unbind(X, axis=0)
 
-        X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
         H_ts = _op.split(Hp_0, num_directions)
         Ws = _op.split(Wp, num_directions)
         Rs = _op.split(Rp, num_directions)
-        Bs = _op.split(Bp, num_directions)
+        if Bp is not None:
+            Bs = _op.split(Bp, num_directions)
+
+        weights_dicts = []
         for i in range(num_directions):
-            H_t = _op.squeeze(H_ts[i], axis=[0])
-            W = _op.squeeze(Ws[i], axis=[0])
-            R = _op.squeeze(Rs[i], axis=[0])
-            B = _op.squeeze(Bs[i], axis=[0])
-            f_act, g_act = acts[i * 2 : (i + 1) * 2]
-            output, H = GRU.generate_gru(
-                X_steps=X_steps,
-                H_t=H_t,
-                W=W,
-                R=R,
-                B=B,
-                linear_before_reset=linear_before_reset,
-                f_act=f_act,
-                g_act=g_act,
-                W_dtype=W_dtype,
-                backwards=i == 1,
-            )
+            weights_dict = {}
+
+            weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0])
+            weights_dict["linear_before_reset"] = linear_before_reset
+
+            # Weights permutation: the ONNX gate order is z-r-n, gru_cell expects r-z-n
+            matz, matr, matn = _op.split(_op.squeeze(Ws[i], axis=[0]), 3)
+            weights_dict["w_inp"] = _op.concatenate([matr, matz, matn], axis=0)
+            matz, matr, matn = _op.split(_op.squeeze(Rs[i], axis=[0]), 3)
+            weights_dict["w_hid"] = _op.concatenate([matr, matz, matn], axis=0)
+            if Bp is not None:
+                Bi, Bh = _op.split(Bs[i], 2, -1)
+                matz, matr, matn = _op.split(_op.squeeze(Bi, axis=[0]), 3)
+                weights_dict["b_inp"] = _op.concatenate([matr, matz, matn], axis=0)
+                matz, matr, matn = _op.split(_op.squeeze(Bh, axis=[0]), 3)
+                weights_dict["b_hid"] = _op.concatenate([matr, matz, matn], axis=0)
+            weights_dicts.append(weights_dict)
 
-            result_output.append(output)
-            result_H.append(H)
+        if num_directions == 2:
+            output, H = GRU.bidir_gru_cell(
+                input_seqs=X_steps,
+                weight_dicts=weights_dicts,
+                acts=acts,
+            )
+        else:
+            # outputs shape = [seqs_num, (batch_size, hidden_size)]
+            outputs, H = gru_cell(
+                input_seqs=X_steps,
+                **weights_dicts[0],
+                rz_act=acts[0],
+                n_act=acts[1],
+            )
 
-        output = _op.concatenate(result_output, axis=1)
-        H = _op.concatenate(result_H, axis=0)
+            # output shape = (seqs_num, num_directions, batch_size, hidden_size)
+            output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
+            H = _op.expand_dims(H, axis=0)
 
         return _expr.TupleWrapper(_expr.Tuple((output, H)), 2)
 
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 7c10889ce17e..613643f091d7 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -39,7 +39,7 @@
 from ..prelude import Prelude, StaticTensorArrayOps
 from ..ty import Any, TensorType, TupleType
 from . import qnn_torch
-from .common import AttrCvt, get_relay_op, unbind, lstm_cell
+from .common import AttrCvt, get_relay_op, unbind, lstm_cell, gru_cell
 from .common import infer_value as _infer_value
 from .common import infer_shape as _infer_shape
 from .common import infer_value_simulated as _infer_value_simulated
@@ -2315,6 +2315,192 @@ def flip(self, inputs, input_types):
         axis = inputs[1]
         return _op.transform.reverse(data, axis=axis[0])
 
+    def bidir_gru_cell(
+        self,
+        input_seqs,
+        weights_dicts,
+    ):
+        """
+        Bidirectional GRU cell
+        """
+        seq_len = len(input_seqs)
+        forward_outputs, fw_H_t = gru_cell(
+            input_seqs,
+            **weights_dicts[0],
+        )
+
+        reverse_outputs, rev_H_t = gru_cell(
+            input_seqs,
+            **weights_dicts[1],
+            backwards=True,
+        )
+
+        final_outputs = []
+        for i in range(seq_len):
+            final_outputs.append(
+                _op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1)
+            )
+
+        return final_outputs, _op.stack([fw_H_t, rev_H_t], axis=0)
+
+    def gru_layers(self, input_data, layer_weights_dicts, bidirectional, dropout_p=0.0):
+        """
+        The method iterates over the layers of a stacked GRU
+        """
+        layers_num = len(layer_weights_dicts)
+        # split the input sequence into a set of samples
+        input_seqs = unbind(input_data, 0)  # [seq_num, (batch, feature_size)]
+        output_hiddens = []
+        for i in range(layers_num):
+            weights_dicts = layer_weights_dicts[i]
+            # input_seqs shape = [seq_num, (batch, feature_size)] or
+            # [seq_num, (batch, 2*hidden_size)] after a bidirectional layer
+            if bidirectional:
+                input_seqs, H_t = self.bidir_gru_cell(input_seqs, weights_dicts)
+            else:
+                input_seqs, H_t = gru_cell(input_seqs, **weights_dicts[0])
+
+            output_hiddens.append(H_t)
+
+            # TODO (vvchernov): in pytorch implementation train is also checked
+            # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
+            # /aten/src/ATen/native/RNN.cpp#L1054
+            if dropout_p != 0 and i < layers_num - 1:
+                # for input in input_seqs:
+                #     input = _op.dropout(input, dropout_p)
+                raise NotImplementedError("Dropout for GRU is not supported yet!")
+
+        return _op.stack(input_seqs, 0), _op.stack(output_hiddens, 0)
+
+    def gru(self, inputs, input_types):
+        """
+        Description of GRU in pytorch:
+        https://pytorch.org/docs/stable/generated/torch.nn.GRU.html?highlight=gru#torch.nn.GRU
+        """
+        # TODO (vvchernov): support dropout
+        assert len(inputs) == 9, "Input of size 9 is expected"
+        # Unpack inputs, note that if optional and not provided then value will be None.
+        _X = inputs[0]
+        # _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)
+
+        hidden_state = inputs[1]
+        # Hidden state shape (hidden_layers_num, batch, hidden_size)
+
+        _weights = inputs[2]
+        # Wi layer[0] shape (3 * hidden_size, feature_size)
+        # Wh layer[0] shape (3 * hidden_size, hidden_size)
+        # Bi layer[0] shape (3 * hidden_size)
+        # Bh layer[0] shape (3 * hidden_size)
+
+        # Wi layer[>0] shape (3 * hidden_size, hidden_size * num_directions)
+        # Wh layer[>0] shape (3 * hidden_size, hidden_size)
+        # Bi layer[>0] shape (3 * hidden_size)
+        # Bh layer[>0] shape (3 * hidden_size)
+
+        # Scalar inputs
+        has_biases = inputs[3]
+        num_layers = inputs[4]
+        dropout_p = inputs[5]  # dropout probability, if 0.0 it means there is no dropout
+        # train = inputs[6]
+        bidirectional = inputs[7]
+        batch_first = inputs[8]
+
+        num_directions = 1
+        if bidirectional:
+            num_directions = 2
+
+        rsd = len(_weights) % num_layers
+        assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
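+        # Illustrative note (an assumption for clarity, not part of the converted graph):
+        # for a 2-layer bidirectional GRU with biases, _weights holds
+        # 2 layers * 2 directions * 4 tensors = 16 entries, ordered as
+        # [w_inp, w_hid, b_inp, b_hid] per direction per layer, so this check and
+        # the per-direction check below must both pass.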
+        rsd = (len(_weights) / num_layers) % num_directions
+        assert (
+            rsd == 0
+        ), "The number of weights per layer must be a multiple of the number of directions!"
+
+        weights_num = int(len(_weights) / num_layers / num_directions)
+        if has_biases:
+            assert weights_num == 4, "The number of weights per layer is expected to be 4"
+        else:
+            assert weights_num == 2, "The number of weights per layer is expected to be 2"
+
+        X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X
+        # TODO (vvchernov): Which data type should be used? from input or weights?
+        # Instead of it _infer_type(X).checked_type.dtype can be used
+        X_dtype = input_types[0]
+        X_shape = _infer_shape(X)  # (seq_num, batch, feature_size)
+
+        hidden_size = int(_infer_shape(_weights[0])[0] / 3)
+        batch_size = X_shape[1]
+
+        # Initialize hidden states if not provided.
+        layers_h = []
+        hidden_layers_num = num_directions * num_layers
+        if hidden_state is None:
+            h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
+            for i in range(hidden_layers_num):
+                layers_h.append(h_0)
+        else:
+            layers_h = unbind(hidden_state, 0)
+
+        layer_weights_dicts = []
+        k = 0  # layer counter
+        if has_biases:
+            names = ["hidden_state", "w_inp", "w_hid", "b_inp", "b_hid"]
+            if bidirectional:
+                rsd = len(_weights) % (2 * weights_num)
+                assert rsd == 0, "got an incorrect number of GRU weights"
+                for i in range(0, len(_weights), 2 * weights_num):
+                    fw_tensors = [layers_h[2 * k], *_weights[i : i + 4]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    j = i + weights_num
+                    rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 4]]
+                    rev_weights_dict = dict(zip(names, rev_tensors))
+                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
+                    k += 1
+            else:
+                assert len(_weights) % weights_num == 0, "got an incorrect number of GRU weights"
+                for i in range(0, len(_weights), weights_num):
+                    fw_tensors = [layers_h[k], *_weights[i : i + 4]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    layer_weights_dicts.append([fw_weights_dict])
+                    k += 1
+        else:
+            names = ["hidden_state", "w_inp", "w_hid"]
+            if bidirectional:
+                rsd = len(_weights) % (2 * weights_num)
+                assert rsd == 0, "got an incorrect number of GRU weights"
+                for i in range(0, len(_weights), 2 * weights_num):
+                    fw_tensors = [layers_h[2 * k], *_weights[i : i + 2]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    j = i + weights_num
+                    rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 2]]
+                    rev_weights_dict = dict(zip(names, rev_tensors))
+                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
+                    k += 1
+            else:
+                assert len(_weights) % weights_num == 0, "got an incorrect number of GRU weights"
+                for i in range(0, len(_weights), weights_num):
+                    fw_tensors = [layers_h[k], *_weights[i : i + 2]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    layer_weights_dicts.append([fw_weights_dict])
+                    k += 1
+        assert (
+            len(layer_weights_dicts) == num_layers and k == num_layers
+        ), "For a stacked GRU the number of weight sets should equal the number of layers!"
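+        # A hedged sketch of the collected structure for a 2-layer bidirectional
+        # GRU with biases (illustration only, the names follow the zip above):
+        # layer_weights_dicts = [
+        #     [fw_weights_dict_0, rev_weights_dict_0],
+        #     [fw_weights_dict_1, rev_weights_dict_1],
+        # ]
+        # where every dict maps "hidden_state", "w_inp", "w_hid", "b_inp" and
+        # "b_hid" to the corresponding relay tensors.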
+
+        output, out_hidden_state = self.gru_layers(
+            X,
+            layer_weights_dicts,
+            bidirectional,
+            dropout_p=dropout_p,
+        )
+
+        # output shape = (seq_num, batch, hidden_size) or
+        # (seq_num, batch, 2*hidden_size) for bidirectional
+        if batch_first:
+            output = _op.transpose(output, (1, 0, 2))
+
+        return (output, out_hidden_state)
+
     def bidir_lstm_cell(
         self,
         input_seqs,
@@ -2792,6 +2978,7 @@ def create_convert_map(self):
             "aten::nll_loss": self.nll_loss,
             "aten::nll_loss2d": self.nll_loss,
             "aten::flip": self.flip,
+            "aten::gru": self.gru,
             "aten::lstm": self.lstm,
         }
 
diff --git a/tests/python/frontend/pytorch/test_lstms.py b/tests/python/frontend/pytorch/test_lstms.py
deleted file mode 100644
index 967245e1ef9d..000000000000
--- a/tests/python/frontend/pytorch/test_lstms.py
+++ /dev/null
@@ -1,363 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import tvm
-import tvm.testing
-import numpy as np
-import torch
-import onnx
-import io
-import sys
-import pytest
-
-from tvm import relay
-from tvm.contrib import graph_executor
-
-from torch import nn
-
-## Model parameters
-model_feature_size = 16
-model_hidden_size = 32
-model_num_layers = 2
-seqs_length = 2
-projection_size = 20
-batch_size = 2
-
-
-def check_torch_version_for_proj_in_lstm():
-    """
-    proj_size parameter is supported in torch.nn.LSTM layer started from 1.8.0 torch version
-    """
-    me = False
-
-    version = torch.__version__
-    major, minor, micro = version.split(".")
-
-    if int(major) > 1:
-        me = True
-    elif int(major) == 1:
-        if int(minor) >= 8:
-            me = True
-
-    return me
-
-
-class LSTM_Model(nn.Module):
-    def __init__(
-        self,
-        device,
-        batch_first=False,
-        layer_num=1,
-        bidirectional=False,
-        proj_size=0,
-        use_bias=True,
-        rnd_weights_init=False,
-    ):
-        super().__init__()
-
-        self.device = device
-        self.batch_first = batch_first
-        self.use_bias = use_bias
-
-        if check_torch_version_for_proj_in_lstm():
-            self.lstm = nn.LSTM(
-                input_size=model_feature_size,
-                hidden_size=model_hidden_size,
-                num_layers=layer_num,
-                bidirectional=bidirectional,
-                proj_size=proj_size,
-                batch_first=batch_first,
-                bias=use_bias,
-            ).to(device)
-        else:
-            if proj_size > 0:
-                print(
-                    "WARNING: projection is not supported for torch version less than 1.8.0! ",
-                    "LSTM was constructed without projection!",
-                )
-                # sys.exit()
-            self.lstm = nn.LSTM(
-                input_size=model_feature_size,
-                hidden_size=model_hidden_size,
-                num_layers=layer_num,
-                bidirectional=bidirectional,
-                batch_first=batch_first,
-                bias=use_bias,
-            ).to(device)
-
-        if rnd_weights_init:
-            self.gen_rnd_weights()
-
-    def forward(self, input, hidden_init=None):
-        """
-        Computes the output tensor after input inference along LSTM layer.
-
-        :param input: batch of data as a tensor of shape (seqs_length, batch_size, model_feature_size) or (batch_size, seqs_length, model_feature_size) if self.batch_first = True
-        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None.
-        :return: the output tensor of shape (batch_size, model_hidden_size)
-        """
-        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
-        # and the final cell state.
-        out, (hidden, cell) = self.lstm(input, hidden_init)
-
-        return out
-
-    def gen_rnd_weights(self):
-        """
-        Generate random weigths for the model with biases
-        Without projection:
-            For first weights group:
-                Wi (4*model_hidden_size, model_feature_size)
-                Wh (4*model_hidden_size, model_hidden_size)
-                Bi (4*model_hidden_size)
-                Bh (4*model_hidden_size)
-            For first bidirectional weights group:
-                Wi (4*model_hidden_size, model_feature_size)
-                Wh (4*model_hidden_size, model_hidden_size)
-                Bi (4*model_hidden_size)
-                Bh (4*model_hidden_size)
-            For other weights group:
-                Wi (4*model_hidden_size, model_hidden_size)
-                Wh (4*model_hidden_size, model_hidden_size)
-                Bi (4*model_hidden_size)
-                Bh (4*model_hidden_size)
-        With projection:
-            For first weights group:
-                Wi (4*model_hidden_size, model_feature_size)
-                Wh (4*model_hidden_size, proj_size)
-                Bi (4*model_hidden_size)
-                Bh (4*model_hidden_size)
-                P (proj_size, model_hidden_size)
-            For first bidirectional weights group:
-                Wi (4*model_hidden_size, model_feature_size)
-                Wh (4*model_hidden_size, proj_size)
-                Bi (4*model_hidden_size)
-                Bh (4*model_hidden_size)
-                P (proj_size, model_hidden_size)
-            For other weights group:
-                Wi (4*model_hidden_size, proj_size * num_directions)
-                Wh (4*model_hidden_size, proj_size)
-                Bi (4*model_hidden_size)
-                Bh (4*model_hidden_size)
-                P (proj_size, model_hidden_size)
-        For generation of random weigths for the model without biases Bi and Bh are skipped
-        """
-        for weight_group in self.lstm.all_weights:
-            for weight in weight_group:
-                weight.data = torch.rand(weight.shape)
-
-    def get_dummy_input(self):
-        shape = [seqs_length, batch_size, model_feature_size]
-        if self.batch_first:
-            shape = [batch_size, seqs_length, model_feature_size]
-        res = torch.rand(shape)
-
-        return res, shape
-
-
-def compare(input, gold_data, rtol=1e-5, atol=1e-5):
-    tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol)
-
-
-def check_lstm_with_type(
-    lstm_type, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0)
-):
-    has_proj = "p" in lstm_type
-
-    device = torch.device("cpu")
-    hidden_layers_num = 1
-    model = None
-    for batch_first in (True, False):
-        for use_bias in (True, False):
-            for rnd_weights in [True]:  # (True, False):
-                if lstm_type == "uni":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                elif lstm_type == "b":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        bidirectional=True,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                    hidden_layers_num = 2
-                elif lstm_type == "p":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        proj_size=projection_size,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                elif lstm_type == "s":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        layer_num=model_num_layers,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                    hidden_layers_num = model_num_layers
-                elif lstm_type == "sb":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        bidirectional=True,
-                        layer_num=model_num_layers,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                    hidden_layers_num = 2 * model_num_layers
-                elif lstm_type == "sp":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        layer_num=model_num_layers,
-                        proj_size=projection_size,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                    hidden_layers_num = model_num_layers
-                elif lstm_type == "bp":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        bidirectional=True,
-                        proj_size=projection_size,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                    hidden_layers_num = 2
-                elif lstm_type == "sbp":
-                    model = LSTM_Model(
-                        device,
-                        batch_first=batch_first,
-                        bidirectional=True,
-                        layer_num=model_num_layers,
-                        proj_size=projection_size,
-                        rnd_weights_init=rnd_weights,
-                        use_bias=use_bias,
-                    )
-                    hidden_layers_num = 2 * model_num_layers
-                else:
-                    print("WARNING: LSTM type {} is not supported here!".format(lstm_type))
-                    return
-
-                model.eval()
-
-                # Get golden output from original model
-                input_hidden_shape = (hidden_layers_num, batch_size, model_hidden_size)
-                input_hidden_shape_with_proj = (hidden_layers_num, batch_size, projection_size)
-                dummy_input, input_shape = model.get_dummy_input()
-                golden_output_batch = model.forward(dummy_input.to(device)).detach().cpu().numpy()
-
-                dtype = "float32"
-                h_zeros = np.zeros(input_hidden_shape, dtype=dtype)
-                if has_proj:
-                    h_zeros = np.zeros(input_hidden_shape_with_proj, dtype=dtype)
-                c_zeros = np.zeros(input_hidden_shape, dtype=dtype)
-
-                tvm_output = None
-                for format in ["ts"]:  # ["ts", "onnx"]:
-                    if format == "ts":
-                        # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
-                        traced_script_module = torch.jit.trace(model, dummy_input).eval()
-
-                        # Import model to Relay
-                        shape_list = [("input", input_shape)]
-                        mod, params = relay.frontend.from_pytorch(traced_script_module, shape_list)
-
-                        # Model compilation by tvm
-                        with tvm.transform.PassContext(opt_level=3):
-                            lib = relay.build(mod, target=target, params=params)
-                    elif format == "onnx":
-                        if has_proj:
-                            print(
-                                "WARNING: torch.onnx.export does not support conversion LSTM with projection "
-                                "from pytorch! TODO: waiting for the support and correct test after that."
-                            )
-                            continue
-                        onnx_io = io.BytesIO()
-                        with torch.no_grad():
-                            h0 = torch.rand(input_hidden_shape)
-                            if has_proj:
-                                h0 = torch.rand(input_hidden_shape_with_proj)
-                            c0 = torch.rand(input_hidden_shape)
-                            input_names = ["input", "h0", "c0"]
-
-                            # default export (without dynamic input)
-                            torch.onnx.export(
-                                model, (dummy_input, (h0, c0)), onnx_io, input_names=input_names
-                            )
-                        onnx_io.seek(0, 0)
-                        onnx_model = onnx.load_model(onnx_io)
-
-                        # Import model to Relay
-                        shape_dict = {
-                            "input": input_shape,
-                            "h0": input_hidden_shape,
-                            "c0": input_hidden_shape,
-                        }
-                        if has_proj:
-                            shape_dict = {
-                                "input": input_shape,
-                                "h0": input_hidden_shape_with_proj,
-                                "c0": input_hidden_shape,
-                            }
-                        mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
-
-                        # Model compilation by tvm
-                        with tvm.transform.PassContext(opt_level=1):
-                            lib = relay.build(mod, target=target, params=params)
-
-                # Inference of the model with given input data
-                m = graph_executor.GraphModule(lib["default"](dev))
-
-                # Set inputs
-                m.set_input(
-                    input=tvm.nd.array(dummy_input.numpy().astype(dtype)),
-                    h0=tvm.nd.array(h_zeros),
-                    c0=tvm.nd.array(c_zeros),
-                )
-                # Execute
-                m.run()
-                # Get outputs (converted to numpy array)
-                tvm_output = m.get_output(0).numpy()
-
-                compare(tvm_output, golden_output_batch)
-
-
-@tvm.testing.uses_gpu
-def test_lstms():
-    for target, dev in tvm.testing.enabled_targets():
-        check_lstm_with_type("uni", target, dev)
-        # check_lstm_with_type("p", target, dev)
-        check_lstm_with_type("s", target, dev)
-        check_lstm_with_type("b", target, dev)
-        # check_lstm_with_type("bp", target, dev)
-        # check_lstm_with_type("sp", target, dev)
-        check_lstm_with_type("sb", target, dev)
-        # check_lstm_with_type("sbp", target, dev)
-
-
-if __name__ == "__main__":
-    test_lstms()
diff --git a/tests/python/frontend/pytorch/test_rnns.py b/tests/python/frontend/pytorch/test_rnns.py
new file mode 100644
index 000000000000..b5784a6fe1e1
--- /dev/null
+++ b/tests/python/frontend/pytorch/test_rnns.py
@@ -0,0 +1,430 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+import torch
+import onnx
+import io
+import sys
+
+from tvm import relay
+from tvm.contrib import graph_executor
+
+from torch import nn
+
+## LSTM parameters
+lstm_feature_size = 16
+lstm_hidden_size = 32
+lstm_projection_size = 20
+
+## GRU parameters
+gru_feature_size = 8
+gru_hidden_size = 16
+
+num_layers = 2
+seqs_length = 2
+batch_size = 2
+
+
+class RNN_Model(nn.Module):
+    """
+    It is the base class for RNN layer classes.
+    It contains common fields and methods for the child classes.
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+        # model is defined in child class
+        self.model = None
+
+    def forward(self, input, hidden_init=None):
+        """
+        Computes the output tensor after input inference along RNN layer.
+
+        :param input: batch of data as a tensor of shape (seqs_length, batch_size, feature_size) or (batch_size, seqs_length, feature_size) if self.batch_first = True
+        :param hidden_init: initial hidden state(s) of the RNN as a tensor(s) of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None.
+        :return: the output tensor of shape (seqs_length, batch_size, num_directions * hidden_size)
+        """
+        if self.model is None:
+            raise NotImplementedError("self.model must be defined in subclasses!")
+        out, _ = self.model(input, hidden_init)
+
+        return out
+
+    def gen_rnd_weights(self):
+        """
+        Generate random weights for the model
+        """
+        if self.model is None:
+            raise NotImplementedError("self.model must be defined in subclasses!")
+        with torch.no_grad():
+            for weight_group in self.model.all_weights:
+                for weight in weight_group:
+                    weight.data = torch.rand(weight.shape)
+
+    def get_dummy_inputs(self):
+        raise NotImplementedError("subclasses must override get_dummy_inputs()!")
+
+    def get_input_names(self):
+        raise NotImplementedError("subclasses must override get_input_names()!")
+
+    def get_shape_desc(self, frontend_type):
+        raise NotImplementedError("subclasses must override get_shape_desc(frontend_type)!")
+
+    def get_tvm_inputs(self, dtype):
+        raise NotImplementedError("subclasses must override get_tvm_inputs(dtype)!")
+
+
+class GRU_Model(RNN_Model):
+    def __init__(
+        self,
+        seq_len=seqs_length,
+        batch_size=batch_size,
+        feature_size=gru_feature_size,
+        hidden_size=gru_hidden_size,
+        batch_first=False,
+        layer_num=1,
+        bidirectional=False,
+        use_bias=True,
+        rnd_weights_init=False,
+    ):
+        super().__init__()
+
+        # Shapes
+        self.shape = [seq_len, batch_size, feature_size]
+        if batch_first:
+            self.shape = [batch_size, seq_len, feature_size]
+        layers_num = 2 * layer_num if bidirectional else layer_num
+        self.h0_shape = [layers_num, batch_size, hidden_size]
+        # Dummy inputs
+        self.dummy_inputs = (torch.rand(self.shape), torch.zeros(self.h0_shape))
+
+        self.model = nn.GRU(
+            input_size=feature_size,
+            hidden_size=hidden_size,
+            num_layers=layer_num,
+            bidirectional=bidirectional,
+            batch_first=batch_first,
+            bias=use_bias,
+        )
+
+        if rnd_weights_init:
+            self.gen_rnd_weights()
+
+    def gen_rnd_weights(self):
+        """
+        Generate random weights for the model with biases
+        For first uni- and bidirectional weights group:
+            Wi (3*hidden_size, feature_size)
+            Wh (3*hidden_size, hidden_size)
+            Bi (3*hidden_size)
+            Bh (3*hidden_size)
+        For other weights group:
+            Wi (3*hidden_size, hidden_size)
+            Wh (3*hidden_size, hidden_size)
+            Bi (3*hidden_size)
+            Bh (3*hidden_size)
+        For generation of random weights for the model without biases the Bi and Bh weights are skipped
+        """
+        super().gen_rnd_weights()
+
+    def get_dummy_inputs(self):
+        return self.dummy_inputs
+
+    def get_input_names(self):
+        return ["input", "h0"]
+
+    def get_shape_desc(self, frontend_type):
+        shape_desc = None
+        if frontend_type == "pt":  # PyTorch
+            shape_desc = [("input", self.shape)]
+        elif frontend_type == "onnx":  # ONNX
+            shape_desc = {
+                "input": self.shape,
+                "h0": self.h0_shape,
+            }
+        return shape_desc
+
+    def get_tvm_inputs(self, dtype):
+        return {
+            "input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)),
+            "h0": tvm.nd.array(self.dummy_inputs[1].numpy().astype(dtype)),
+        }
+
+
+def check_torch_version_for_proj_in_lstm():
+    """
+    The proj_size parameter is supported by torch.nn.LSTM starting from torch version 1.8.0
+    """
+    me = False
+
+    version = torch.__version__
+    major, minor, micro = version.split(".")
+
+    if int(major) > 1:
+        me = True
+    elif int(major) == 1:
+        if int(minor) >= 8:
+            me = True
+
+    return me
+
+
+class LSTM_Model(RNN_Model):
+    def __init__(
+        self,
+        seq_len=seqs_length,
+        batch_size=batch_size,
+        feature_size=lstm_feature_size,
+        hidden_size=lstm_hidden_size,
+        batch_first=False,
+        layer_num=1,
+        bidirectional=False,
+        proj_size=0,
+        use_bias=True,
+        rnd_weights_init=False,
+    ):
+        super().__init__()
+
+        # Shapes
+        self.shape = [seq_len, batch_size, feature_size]
+        if batch_first:
+            self.shape = [batch_size, seq_len, feature_size]
+        layers_num = 2 * layer_num if bidirectional else layer_num
+        self.h0_shape = [layers_num, batch_size, hidden_size]
+        if proj_size > 0:
+            self.h0_shape = [layers_num, batch_size, proj_size]
+        self.c0_shape = [layers_num, batch_size, hidden_size]
+        # Dummy inputs
+        self.dummy_inputs = (
+            torch.rand(self.shape),
+            (torch.zeros(self.h0_shape), torch.zeros(self.c0_shape)),
+        )
+
+        if check_torch_version_for_proj_in_lstm():
+            self.model = nn.LSTM(
+                input_size=lstm_feature_size,
+                hidden_size=lstm_hidden_size,
+                num_layers=layer_num,
+                bidirectional=bidirectional,
+                proj_size=proj_size,
+                batch_first=batch_first,
+                bias=use_bias,
+            )
+        else:
+            if proj_size > 0:
+                print(
+                    "WARNING: projection is not supported for torch version less than 1.8.0! ",
+                    "LSTM was constructed without projection!",
+                )
+                # sys.exit()
+            self.model = nn.LSTM(
+                input_size=lstm_feature_size,
+                hidden_size=lstm_hidden_size,
+                num_layers=layer_num,
+                bidirectional=bidirectional,
+                batch_first=batch_first,
+                bias=use_bias,
+            )
+
+        if rnd_weights_init:
+            self.gen_rnd_weights()
+
+    def gen_rnd_weights(self):
+        """
+        Generate random weights for the model with biases
+        Without projection:
+            For first weights group:
+                Wi (4*lstm_hidden_size, lstm_feature_size)
+                Wh (4*lstm_hidden_size, lstm_hidden_size)
+                Bi (4*lstm_hidden_size)
+                Bh (4*lstm_hidden_size)
+            For first bidirectional weights group:
+                Wi (4*lstm_hidden_size, lstm_feature_size)
+                Wh (4*lstm_hidden_size, lstm_hidden_size)
+                Bi (4*lstm_hidden_size)
+                Bh (4*lstm_hidden_size)
+            For other weights group:
+                Wi (4*lstm_hidden_size, lstm_hidden_size)
+                Wh (4*lstm_hidden_size, lstm_hidden_size)
+                Bi (4*lstm_hidden_size)
+                Bh (4*lstm_hidden_size)
+        With projection:
+            For first weights group:
+                Wi (4*lstm_hidden_size, lstm_feature_size)
+                Wh (4*lstm_hidden_size, proj_size)
+                Bi (4*lstm_hidden_size)
+                Bh (4*lstm_hidden_size)
+                P (proj_size, lstm_hidden_size)
+            For first bidirectional weights group:
+                Wi (4*lstm_hidden_size, lstm_feature_size)
+                Wh (4*lstm_hidden_size, proj_size)
+                Bi (4*lstm_hidden_size)
+                Bh (4*lstm_hidden_size)
+                P (proj_size, lstm_hidden_size)
+            For other weights group:
+                Wi (4*lstm_hidden_size, proj_size * num_directions)
+                Wh (4*lstm_hidden_size, proj_size)
+                Bi (4*lstm_hidden_size)
+                Bh (4*lstm_hidden_size)
+                P (proj_size, lstm_hidden_size)
+        For generation of random weights for the model without biases Bi and Bh are skipped
+        """
+        super().gen_rnd_weights()
+
+    def get_dummy_inputs(self):
+        return self.dummy_inputs
+
+    def get_input_names(self):
+        return ["input", "h0", "c0"]
+
+    def get_shape_desc(self, frontend_type):
+        shape_desc = None
+        if frontend_type == "pt":  # PyTorch
+            shape_desc = [("input", self.shape)]
+        elif frontend_type == "onnx":  # ONNX
+            shape_desc = {
+                "input": self.shape,
+                "h0": self.h0_shape,
+                "c0": self.c0_shape,
+            }
+        return shape_desc
+
+    def get_tvm_inputs(self, dtype):
+        return {
+            "input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)),
+            "h0": tvm.nd.array(self.dummy_inputs[1][0].numpy().astype(dtype)),
+            "c0": tvm.nd.array(self.dummy_inputs[1][1].numpy().astype(dtype)),
+        }
+
+
+def compare(input, gold_data, rtol=1e-5, atol=1e-5):
+    tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol)
+
+
+def check_rnn(rnn_type, rnn_mod, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0)):
+    def get_model(
+        rnn_type,
+        rnn_mod,
+        args,
+    ):
+        # Fill args
+        if "b" in rnn_mod:
+            args["bidirectional"] = True
+        if "s" in rnn_mod:
+            args["layer_num"] = num_layers
+
+        if rnn_type == "GRU":
+            RNN_Model_selector = GRU_Model
+        elif rnn_type == "LSTM":
+            RNN_Model_selector = LSTM_Model
+            if "p" in rnn_mod:
+                args["proj_size"] = lstm_projection_size
+
+        return RNN_Model_selector(**args)
+
+    def get_onnx_model(model):
+        onnx_io = io.BytesIO()
+        with torch.no_grad():
+            input_names = model.get_input_names()
+            inputs = model.get_dummy_inputs()
+
+            # default export (without dynamic input)
+            torch.onnx.export(model, inputs, onnx_io, input_names=input_names)
+
+        onnx_io.seek(0, 0)
+        return onnx.load_model(onnx_io)
+
+    model = None
+    dtype = "float32"
+    device = torch.device("cpu")
+    for batch_first in (True, False):
+        for use_bias in (True, False):
+            for rnd_weights in [True]:  # (True, False):
+                model_inputs = {
+                    "batch_first": batch_first,
+                    "use_bias": use_bias,
+                    "rnd_weights_init": rnd_weights,
+                }
+                model = get_model(rnn_type, rnn_mod, model_inputs)
+                model.to(device)
+                model.eval()
+
+                # Get golden output from original model
+                dummy_inputs = model.get_dummy_inputs()
+                golden_output = model.forward(dummy_inputs[0].to(device)).detach().cpu().numpy()
+
+                tvm_output = None
+                for format in ["pt"]:  # ["pt", "onnx"]:
+                    shape_desc = model.get_shape_desc(format)
+                    if format == "pt":
+                        # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
+                        traced_script_module = torch.jit.trace(model, dummy_inputs[0]).eval()
+
+                        # Import model to Relay
+                        mod, params = relay.frontend.from_pytorch(traced_script_module, shape_desc)
+                    elif format == "onnx":
+                        try:
+                            onnx_model = get_onnx_model(model)
+                        except Exception:  # torch.onnx.export can fail for unsupported configurations
+                            print(
+                                "WARNING: torch.onnx.export does not support conversion of LSTM with projection "
+                                "from pytorch! TODO: waiting for the support and correct test after that."
+                            )
+                            continue
+
+                        # Import model to Relay
+                        mod, params = relay.frontend.from_onnx(onnx_model, shape_desc)
+
+                    # Model compilation by tvm
+                    with tvm.transform.PassContext(opt_level=3):
+                        lib = relay.build(mod, target=target, params=params)
+
+                    # Inference of the model with given input data
+                    m = graph_executor.GraphModule(lib["default"](dev))
+
+                    # Set inputs
+                    tvm_inputs = model.get_tvm_inputs(dtype)
+                    m.set_input(**tvm_inputs)
+                    # Execute
+                    m.run()
+                    # Get outputs (converted to numpy array)
+                    tvm_output = m.get_output(0).numpy()
+
+                    compare(tvm_output, golden_output)
+
+
+@tvm.testing.uses_gpu
+def test_rnns():
+    for target, dev in tvm.testing.enabled_targets():
+        # RNN types: GRU, LSTM
+        # GRU modifications: unidirectional, stacked, bidirectional, stacked bidirectional
+        for mod_type in ["uni", "s", "b", "sb"]:
+            check_rnn("GRU", mod_type, target, dev)
+        # LSTM modifications: unidirectional, stacked, bidirectional, stacked bidirectional,
+        # and all these types with projection ("p", "sp", "bp", "sbp")
+        # The latter are skipped for test acceleration
+        for mod_type in ["uni", "s", "b", "sb"]:
+            check_rnn("LSTM", mod_type, target, dev)
+
+
+if __name__ == "__main__":
+    test_rnns()
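+
+# A minimal usage sketch (an illustrative assumption mirroring check_rnn above,
+# not executed by the test suite): converting a GRU through the new aten::gru
+# path looks like
+#
+#   model = GRU_Model(bidirectional=True, rnd_weights_init=True)
+#   model.eval()
+#   inp = model.get_dummy_inputs()[0]
+#   traced = torch.jit.trace(model, inp).eval()
+#   mod, params = relay.frontend.from_pytorch(traced, model.get_shape_desc("pt"))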