diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 6edcce372ec7c..379a3633df334 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -41,7 +41,6 @@
 from . import qnn_torch
 from .common import AttrCvt, get_relay_op
 from .common import infer_value as _infer_value
-from .common import infer_type as _infer_type
 from .common import infer_shape as _infer_shape
 from .common import infer_value_simulated as _infer_value_simulated
 from .common import try_infer_value
@@ -2359,7 +2358,6 @@ def bidir_lstm_cell(self, input_seq, hidden_pair, weights_pair, has_proj=False):
         fw_outputs = self.lstm_cell(input_seq, hidden_pair[0], weights_pair[0], has_proj)
 
         rev_input_seq = []
-        _op.reverse_sequence
         seq_len = len(input_seq)
         for i in range(seq_len):
             rev_input_seq.append(input_seq[seq_len - 1 - i])  # [seq_num, (batch, hidden_size)]
@@ -2374,13 +2372,13 @@
         return final_outputs, (fw_outputs[1], rev_outputs[1])
 
     def lstm_layers(
-        self, input, hiddens, weights, bidirectional, dtype, dropout_p=0.0, has_proj=False
+        self, input_data, hiddens, weights, bidirectional, dtype, dropout_p=0.0, has_proj=False
     ):
         hidden_layers_num = len(hiddens)
         assert len(weights) == hidden_layers_num
 
         # split input sequence to samples set
-        input_seqs = self.unbind((input, 0), dtype)  # [seq_num, (batch, feature_size)]
+        input_seqs = self.unbind((input_data, 0), dtype)  # [seq_num, (batch, feature_size)]
         output_hiddens = []
         for k in range(hidden_layers_num):
             hiddens_input = hiddens[k]
@@ -2393,10 +2391,13 @@
             )
             output_hiddens.append(outputs[1])
 
-            # input_seqs shape = [seq_num, (batch, feature_size)] or [seq_num, (batch, 2*feature_size)] for bidirectional
+            # input_seqs shape = [seq_num, (batch, feature_size)] or
+            # [seq_num, (batch, 2*feature_size)] for bidirectional
             input_seqs = outputs[0]
 
-            # TODO (vvchernov): in pytorch implementation train is also checked (see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/src/ATen/native/RNN.cpp#L1054)
+            # TODO (vvchernov): in pytorch implementation train is also checked
+            # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
+            # /aten/src/ATen/native/RNN.cpp#L1054
             if dropout_p != 0 and k < hidden_layers_num - 1:
                 # for input in input_seqs:
                 #     input = _op.dropout(input, dropout_p)
@@ -2412,9 +2413,14 @@
         return _op.stack(input_seqs, 0), final_hiddens
 
     def lstm(self, inputs, input_types):
-        # Description of LSTM in pytorch: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
-        # https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/src/ATen/native/RNN.cpp#L1396 (projection is unsupported) and dependencies were used
-        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L1483 (projection is supported) and dependencies were used
+        """
+        Description of LSTM in pytorch:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
+        Native implementation for torch version less than 1.8.0 (projection is unsupported):
+        https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/ \
+        src/ATen/native/RNN.cpp#L1396
+        Native implementation for torch version from 1.8.0 and higher (projection is supported):
+        https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L1483
+        """
         # TODO (vvchernov): support dropout
         assert len(inputs) == 9, "Input of size 9 is expected"
         # Unpack inputs, note that if optional and not provided then value will be None.
@@ -2425,7 +2431,8 @@ def lstm(self, inputs, input_types):
         assert len(hidden_states) == 2, "lstm expects two hidden states"
         h_0 = hidden_states[0]
         c_0 = hidden_states[1]
-        # H0 shape (hidden_layers_num, batch, proj_size) if projection else (hidden_layers_num, batch, hidden_size)
+        # H0 shape (hidden_layers_num, batch, proj_size) if projection
+        # else (hidden_layers_num, batch, hidden_size)
         # C0 shape (hidden_layers_num, batch, hidden_size)
 
         _weights = inputs[2]
@@ -2514,7 +2521,7 @@ def lstm(self, inputs, input_types):
                     fw_weights = []
                     rev_weights = []
                     for j in range(weights_num + 2):
-                        if j == 2 or j == 3:
+                        if j in (2, 3):
                             fw_weights.append(None)
                             rev_weights.append(None)
                         else:
@@ -2526,7 +2533,7 @@ def lstm(self, inputs, input_types):
             for i in range(0, len(_weights), weights_num):
                 fw_weights = []
                 for j in range(weights_num + 2):
-                    if j == 2 or j == 3:
+                    if j in (2, 3):
                         fw_weights.append(None)
                     else:
                         fw_weights.append(_weights[i + j])
@@ -2536,7 +2543,8 @@ def lstm(self, inputs, input_types):
         ), "For stacked LSTM number of weights tuples should be the same as number of layers!"
 
         X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X
-        # TODO (vvchernov): Which data type should be used? from input or weights (use _weights[0])? Also _infer_type(X).checked_type.dtype can be used
+        # TODO (vvchernov): Which data type should be used? from input or weights?
+        # Instead of it _infer_type(X).checked_type.dtype can be used
         X_dtype = input_types[0]
         X_shape = _infer_shape(X)  # (seq_num, batch, feature_size)
 
@@ -2582,7 +2590,8 @@ def lstm(self, inputs, input_types):
                 has_proj=has_proj,
             )
 
-        # output shape = (seq_num, batch, hidden_size) or (seq_num, batch, 2*feature_size) for bidirectional
+        # output shape = (seq_num, batch, hidden_size) or
+        # (seq_num, batch, 2*feature_size) for bidirectional
         output = outputs[0]
         hy = []
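For reviewers who want to exercise the `aten::lstm` converter touched by this patch end to end, a minimal driver follows. It is a sketch, not part of the patch: the LSTM configuration, tensor shapes, the `"input"` name in the shape list, and the `llvm` target are illustrative assumptions; `torch.jit.trace` and `relay.frontend.from_pytorch` are the real entry points.

```python
# Minimal sketch of driving the LSTM conversion path changed in this diff.
# Assumptions (not taken from the patch): shapes, layer config, the "input"
# name, and the llvm target are illustrative; projection (proj_size) would
# additionally require torch >= 1.8.0 per the docstring added above.
import torch
import tvm
from tvm import relay

seq_len, batch, feature_size, hidden_size = 5, 3, 16, 32

model = torch.nn.LSTM(
    input_size=feature_size,
    hidden_size=hidden_size,
    num_layers=2,
    bidirectional=True,  # exercises bidir_lstm_cell and the stacked-weights path
).eval()

inp = torch.randn(seq_len, batch, feature_size)
traced = torch.jit.trace(model, inp)  # hidden states default to zeros

# shape_list maps graph input names (our choice) to their shapes
mod, params = relay.frontend.from_pytorch(traced, [("input", (seq_len, batch, feature_size))])

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)
```

With `bidirectional=True` the bidirectional branch of the weights unpacking is taken; dropping it exercises the plain stacked path with the `if j in (2, 3)` bias placeholders.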