-
Notifications
You must be signed in to change notification settings - Fork 35
/
fast_decoders.py
75 lines (55 loc) · 2.75 KB
/
fast_decoders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torch.nn as nn
import torch.nn.functional as F
from onmt_modules.misc import sequence_mask
class DecodeFunc_Sp(object):
    """
    Decoding functions

    Wraps a ``decoder`` / ``postnet`` pair (both supplied by the caller on
    every call) for one of two output types:

    * ``'Sp'`` reads ``hparams.dim_freq`` and ``hparams.dec_steps_sp``
    * ``'Tx'`` reads ``hparams.dim_code`` and ``hparams.dec_steps_tx``

    The postnet output is always split along its last axis: channel 0 is the
    gate (stop) value and the remaining channels are the frame itself.
    """

    def __init__(self, hparams, type_out):
        """Read the settings for ``type_out`` from ``hparams``.

        Args:
            hparams: object exposing ``gate_threshold`` plus the per-type
                attributes listed in the class docstring.
            type_out: ``'Sp'`` or ``'Tx'``.

        Raises:
            ValueError: if ``type_out`` is neither ``'Sp'`` nor ``'Tx'``.
        """
        if type_out == 'Sp':
            self.dim_freq = hparams.dim_freq
            self.max_decoder_steps = hparams.dec_steps_sp
        elif type_out == 'Tx':
            self.dim_freq = hparams.dim_code
            self.max_decoder_steps = hparams.dec_steps_tx
        else:
            raise ValueError(
                f"type_out must be 'Sp' or 'Tx', got {type_out!r}")

        self.gate_threshold = hparams.gate_threshold
        self.type_out = type_out

    def _stop_flags(self, gate):
        """Quantise raw gate values into rounded stop flags.

        Computes ``round(sigmoid(gate) - gate_threshold + 0.5)``; for a
        threshold strictly between 0 and 1 this is 1 where the sigmoid of the
        gate lies above the threshold and 0 where it lies below.
        """
        return (torch.sigmoid(gate) - self.gate_threshold + 0.5).round()

    def __call__(self, tgt, memory_bank, memory_lengths, decoder, postnet):
        """Run ``decoder`` once over the full target and split the result.

        ``decoder`` is called with ``step=None``; the second element of its
        return value is ignored.

        Returns:
            ``(spect, gate)``: channels ``1:`` and channel ``:1`` of the
            last axis of ``postnet(dec_outs)``.
        """
        dec_outs, _ = decoder(tgt, memory_bank, step=None,
                              memory_lengths=memory_lengths)
        spect_gate = postnet(dec_outs)
        spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]
        return spect, gate

    def infer(self, tgt_real, memory_bank, memory_lengths, decoder, postnet):
        """Decode step by step, feeding each predicted frame back in.

        Decoding runs for at most ``max_decoder_steps`` steps and ends early
        as soon as every batch element raises its stop flag on the same step.
        A warning is printed only when the step budget is used up without
        that happening.

        Args:
            tgt_real: unused; kept so the call signature stays unchanged.
            memory_bank: tensor whose dimension 1 is the batch size; passed
                through to ``decoder``.
            memory_lengths: passed through to ``decoder``.
            decoder: callable invoked as ``decoder(prev_frame, memory_bank,
                t, memory_lengths=..., tgt_words=...)``; only the first
                element of its return value is used.
            postnet: callable mapping the decoder output to a 3-D tensor
                holding the gate in channel 0 and the frame in the rest.

        Returns:
            ``(spect_outputs, len_spect, gate_outputs)`` where
            ``spect_outputs`` is ``(max_decoder_steps, B, dim_freq)``,
            ``gate_outputs`` is ``(max_decoder_steps, B, 1)`` and
            ``len_spect`` holds, per batch element, the number of steps
            before its first stop flag. Rows past an early exit keep their
            zero initialisation.
        """
        B = memory_bank.size(1)
        device = memory_bank.device

        spect_outputs = torch.zeros(
            (self.max_decoder_steps, B, self.dim_freq),
            dtype=torch.float, device=device)
        gate_outputs = torch.zeros(
            (self.max_decoder_steps, B, 1),
            dtype=torch.float, device=device)
        # Stop flags of the previous step, handed to the decoder as tgt_words.
        tgt_words = torch.zeros([B, 1],
                                dtype=torch.float, device=device)
        # The first decoder input is an all-zero frame.
        current_pred = torch.zeros([1, B, self.dim_freq],
                                   dtype=torch.float, device=device)

        for t in range(self.max_decoder_steps):
            dec_outs, _ = decoder(current_pred,
                                  memory_bank, t,
                                  memory_lengths=memory_lengths,
                                  tgt_words=tgt_words)
            spect_gate = postnet(dec_outs)
            spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]

            spect_outputs[t:t+1] = spect
            gate_outputs[t:t+1] = gate

            stop = self._stop_flags(gate)
            # .data detaches the fed-back frame from the autograd graph.
            current_pred = spect.data
            tgt_words = stop.squeeze(-1).t()

            # Check for completion first so that finishing exactly on the
            # last permitted step does not trigger a spurious warning.
            if (stop == 1).all():
                break
            if t == self.max_decoder_steps - 1:
                print(f"Warning! {self.type_out} reached max decoder steps")

        stop_quant = self._stop_flags(gate_outputs.data).squeeze(-1)
        # The running sum stays at 0 until the first stop flag, so counting
        # the zeros gives the number of frames that precede it.
        len_spect = (stop_quant.cumsum(dim=0) == 0).sum(dim=0)

        return spect_outputs, len_spect, gate_outputs