-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathhelpers.py
105 lines (81 loc) · 3.55 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_packed_sequence
class RecurrentHelper:
    """Mixin of utilities for working with recurrent-layer outputs.

    Provides: extraction of the last valid timestep from padded outputs,
    splitting bidirectional outputs into directions, re-padding a
    PackedSequence to a fixed max length, projecting RNN outputs to a
    vocabulary, and length-sorting helpers for PackedSequence usage.
    """

    @staticmethod
    def last_by_index(outputs, lengths):
        """Gather the output at the last *valid* timestep of each sequence.

        Args:
            outputs: padded outputs, shape (batch, max_length, hidden).
            lengths: per-sequence lengths, shape (batch,), integer tensor.

        Returns:
            Tensor of shape (batch, hidden) holding, for each sequence,
            the output at timestep ``length - 1``.
        """
        # Build a (batch, 1, hidden) index that points at timestep
        # length-1 for every hidden unit of every sequence.
        idx = (lengths - 1).view(-1, 1).expand(outputs.size(0),
                                               outputs.size(2)).unsqueeze(1)
        # BUGFIX: squeeze only dim 1 (the gathered timestep). A bare
        # .squeeze() also dropped the batch dim when batch_size == 1
        # (or the hidden dim when hidden == 1), which broke the
        # torch.cat in last_timestep and any caller expecting 2D.
        return outputs.gather(1, idx).squeeze(1)

    def last_timestep(self, outputs, lengths, bi=False):
        """Return the last relevant output per sequence.

        Args:
            outputs: padded RNN outputs, (batch, max_length, hidden)
                where hidden is 2x the direction size if ``bi``.
            lengths: per-sequence lengths, shape (batch,).
            bi: True if ``outputs`` come from a bidirectional RNN.

        Returns:
            (batch, hidden) tensor. For ``bi``, the forward direction's
            last valid step concatenated with the backward direction's
            step 0.

        NOTE(review): taking step 0 of the backward direction is the
        backward pass over the full sequence only when the input was
        packed; on plain padded input step 0 has also seen padding —
        confirm callers always pack.
        """
        if bi:
            forward, backward = self.split_directions(outputs)
            last_forward = self.last_by_index(forward, lengths)
            last_backward = backward[:, 0, :]
            return torch.cat((last_forward, last_backward), dim=-1)
        else:
            return self.last_by_index(outputs, lengths)

    @staticmethod
    def split_directions(outputs):
        """Split bidirectional outputs into (forward, backward) halves.

        The last dimension is assumed to be [forward ; backward], as
        produced by torch's bidirectional RNNs with batch_first=True.
        """
        direction_size = outputs.size(-1) // 2  # // : dims are ints >= 0
        forward = outputs[:, :, :direction_size]
        backward = outputs[:, :, direction_size:]
        return forward, backward

    def pad_outputs(self, out_packed, max_length):
        """Unpack a PackedSequence and zero-pad it up to ``max_length``.

        pad_packed_sequence only pads to the longest sequence in the
        batch, so we pad the time dimension the rest of the way to keep
        a fixed (batch, max_length, hidden) shape across batches.
        """
        out_unpacked, _lengths = pad_packed_sequence(out_packed,
                                                     batch_first=True)
        # F.pad's last-dim-first pair ordering: (0, 0) leaves the hidden
        # dim alone, (0, pad_length) appends zeros on the time dim.
        pad_length = max_length - out_unpacked.size(1)
        out_unpacked = F.pad(out_unpacked, (0, 0, 0, pad_length))
        return out_unpacked

    @staticmethod
    def project2vocab(output, projection):
        """Apply ``projection`` to every timestep of ``output``.

        Args:
            output: (batch, max_length, hidden) RNN outputs.
            projection: callable mapping (N, hidden) -> (N, vocab),
                e.g. an nn.Linear decoder.

        Returns:
            (batch, max_length, vocab) tensor of projected outputs.
        """
        # Flatten batch and time so the projection sees one big matrix
        # of all timesteps: (batch*max_length, hidden).
        flat_output = output.contiguous().view(output.size(0) * output.size(1),
                                               output.size(2))
        decoded_flat = projection(flat_output)
        # Restore the (batch, max_length, vocab) view.
        decoded = decoded_flat.view(output.size(0), output.size(1),
                                    decoded_flat.size(1))
        return decoded

    @staticmethod
    def sort_by(lengths):
        """
        Sort batch data and labels by length (descending).

        Useful for variable length inputs, for utilizing PackedSequences

        Args:
            lengths (nn.Tensor): tensor containing the lengths for the data

        Returns:
            - sorted lengths Tensor (descending)
            - sort (callable) which will sort a given iterable
                according to lengths
            - unsort (callable) which will revert a given iterable to its
                original order
        """
        batch_size = lengths.size(0)
        # torch.sort is ascending; reverse_idx flips to the descending
        # order that pack_padded_sequence expects.
        sorted_lengths, sorted_idx = lengths.sort()
        _, original_idx = sorted_idx.sort(0, descending=True)
        reverse_idx = torch.linspace(batch_size - 1, 0, batch_size).long()
        if lengths.data.is_cuda:
            reverse_idx = reverse_idx.cuda()
        sorted_lengths = sorted_lengths[reverse_idx]

        def sort(iterable):
            if iterable is None:
                return None
            # NOTE(review): 1-D tensors are deliberately passed through
            # unchanged — presumably only >=2-D batch data is reordered;
            # 1-D per-sample labels would NOT be aligned. Confirm callers.
            if len(iterable.shape) > 1:
                return iterable[sorted_idx][reverse_idx]
            else:
                return iterable

        def unsort(iterable):
            if iterable is None:
                return None
            if len(iterable.shape) > 1:
                # Undo the descending flip, map back through the inverse
                # permutation, then undo the ascending flip.
                return iterable[reverse_idx][original_idx][reverse_idx]
            else:
                return iterable

        return sorted_lengths, sort, unsort