[Frontend, pytorch] Vc/pytorch lstm #8447

Merged
11 commits merged on Jul 20, 2021
fix pytorch bidirectional lstm. update test comment
Valery Chernov committed Jul 13, 2021
commit 367c5feebd5eb5aeb6134673719b604e692b9062
17 changes: 11 additions & 6 deletions python/tvm/relay/frontend/pytorch.py
@@ -2377,26 +2377,31 @@ def bidir_lstm_cell(self, input_seq, hidden_pair, weights_pair):
             rev_input_seq.append(input_seq[seq_len - 1 - i]) # [seq_num, (batch, hidden_size)]
         rev_outputs = self.lstm_cell(rev_input_seq, hidden_pair[1], weights_pair[1])

-        return _op.concatenate([_op.stack(fw_outputs[0], 0), _op.stack(rev_outputs[0], 0)], -1), (fw_outputs[1], rev_outputs[1])
+        final_outputs = [] # [seq_num, (batch, 2 * hidden_size)]
+        for j in range(seq_len):
+            final_outputs.append(_op.concatenate([ fw_outputs[0][j], rev_outputs[0][seq_len - 1 - j] ], -1))
+
+        return final_outputs, (fw_outputs[1], rev_outputs[1])

     def lstm_layers(self, input, hiddens, weights, bidirectional, dtype, dropout_p = 0.0):
         hidden_layers_num = len(hiddens)
         assert len(weights) == hidden_layers_num

         # split input sequence to samples set
-        input = self.unbind((input, 0), dtype) # [seq_num, (batch, feature_size)]
+        input_seqs = self.unbind((input, 0), dtype) # [seq_num, (batch, feature_size)]
         output_hiddens = []
         for k in range(hidden_layers_num):
             hiddens_input = hiddens[k]
             weights_input = weights[k]

-            outputs = self.bidir_lstm_cell(input, hiddens_input, weights_input) if bidirectional else self.lstm_cell(input, hiddens_input, weights_input)
+            outputs = self.bidir_lstm_cell(input_seqs, hiddens_input, weights_input) if bidirectional else self.lstm_cell(input_seqs, hiddens_input, weights_input)

             output_hiddens.append(outputs[1])
-            input = outputs[0] # [seq_num, (batch, feature_size)] or [seq_num, (batch, 2*feature_size)] for bidirectional
+            input_seqs = outputs[0] # [seq_num, (batch, feature_size)] or [seq_num, (batch, 2*feature_size)] for bidirectional

             if dropout_p != 0 and k < hidden_layers_num - 1: # TODO (vvchernov): in pytorch implementation train is also checked
-                # input = _op.dropout(input, dropout_p)
+                # for input in input_seqs:
+                # input = _op.dropout(input, dropout_p)
                 raise NotImplementedError("Dropout for LSTM has not been supported yet!")
         final_hiddens = []
         if bidirectional:
@@ -2406,7 +2411,7 @@ def lstm_layers(self, input, hiddens, weights, bidirectional, dtype, dropout_p =
         else:
             final_hiddens = output_hiddens

-        return _op.stack(input, 0), final_hiddens
+        return _op.stack(input_seqs, 0), final_hiddens

     def lstm(self, inputs, input_types):
         # Description of LSTM in pytorch: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
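
Note on the `bidir_lstm_cell` change: `rev_outputs[0][i]` is produced from `input_seq[seq_len - 1 - i]`, so the backward state that belongs to original timestep `j` is `rev_outputs[0][seq_len - 1 - j]`. Stacking the reversed list as-is (the old one-line return) paired timestep `j` with the backward state for timestep `seq_len - 1 - j`, whereas `torch.nn.LSTM(bidirectional=True)` concatenates the forward and backward hidden states of the same input position at every timestep. A minimal NumPy sketch of the re-indexing (toy random data standing in for the relay expressions, not part of the commit):

```python
# Sketch only: illustrates the per-timestep pairing introduced by this commit.
import numpy as np

seq_len, batch, hidden = 4, 2, 3
fw_outputs = [np.random.rand(batch, hidden) for _ in range(seq_len)]   # forward pass, t = 0..seq_len-1
rev_outputs = [np.random.rand(batch, hidden) for _ in range(seq_len)]  # pass over the reversed sequence

# Old behaviour: stack both lists as-is, which pairs timestep j with timestep seq_len - 1 - j.
old = np.concatenate([np.stack(fw_outputs, 0), np.stack(rev_outputs, 0)], -1)

# Fixed behaviour: re-reverse the backward outputs before the per-timestep concatenation.
fixed = np.stack(
    [np.concatenate([fw_outputs[j], rev_outputs[seq_len - 1 - j]], -1) for j in range(seq_len)], 0
)  # (seq_len, batch, 2 * hidden)

# The backward half of timestep j now holds the state computed for input position j.
for j in range(seq_len):
    assert np.array_equal(fixed[j, :, hidden:], rev_outputs[seq_len - 1 - j])
```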
3 changes: 2 additions & 1 deletion tests/python/frontend/pytorch/test_lstms.py
@@ -71,7 +71,8 @@ def __init__(self, device, batch_first = False, layer_num = 1, bidirectional = F
                                 proj_size=proj_size,
                                 batch_first=batch_first).to(device)
         else:
-            print('WARNING: projection is not supported for torch version less than 1.8.0!')
+            if proj_size > 0:
+                print('WARNING: projection is not supported for torch version less than 1.8.0!')
             self.lstm = nn.LSTM(input_size=model_feature_size,
                                 hidden_size=model_hidden_size,
                                 num_layers=layer_num,
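
For context on the test tweak: `proj_size` is only accepted by `torch.nn.LSTM` from torch 1.8.0 onward, so the `else` branch (taken on older torch) builds the layer without projection and now warns only when a projection was actually requested. A hedged sketch of the same version gate as a standalone helper (hypothetical helper, not taken from `test_lstms.py`):

```python
# Sketch only: version-gated LSTM construction, assuming torch.__version__ starts with "major.minor".
import torch
import torch.nn as nn

def make_lstm(input_size, hidden_size, num_layers=1, bidirectional=False, proj_size=0):
    major, minor = (int(v) for v in torch.__version__.split(".")[:2])
    if (major, minor) >= (1, 8):
        # torch >= 1.8.0 understands the proj_size argument
        return nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                       num_layers=num_layers, bidirectional=bidirectional, proj_size=proj_size)
    if proj_size > 0:
        print('WARNING: projection is not supported for torch version less than 1.8.0!')
    return nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                   num_layers=num_layers, bidirectional=bidirectional)
```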