import json
import os
import torch
from torch import nn
from asteroid import torch_utils
from asteroid_filterbanks import Encoder, Decoder, FreeFB
from asteroid.masknn.recurrent import SingleRNN
from asteroid.engine.optimizers import make_optimizer
from asteroid.masknn.norms import GlobLN


class TasNet(nn.Module):
"""Some kind of TasNet, but not the original one
Differences:
- Overlap-add support (strided convolutions)
- No frame-wise normalization on the wavs
- GlobLN as bottleneck layer.
- No skip connection.
Args:
fb_conf (dict): see local/conf.yml
mask_conf (dict): see local/conf.yml
"""

    def __init__(self, fb_conf, mask_conf):
        super().__init__()
        self.n_src = mask_conf["n_src"]
        self.n_filters = fb_conf["n_filters"]
        # Create the TasNet encoders and decoder (nn.Conv1d would work as well).
        self.encoder_sig = Encoder(FreeFB(**fb_conf))
        self.encoder_relu = Encoder(FreeFB(**fb_conf))
        self.decoder = Decoder(FreeFB(**fb_conf))
        self.bn_layer = GlobLN(fb_conf["n_filters"])
# Create TasNet masker
self.masker = nn.Sequential(
SingleRNN(
"lstm",
fb_conf["n_filters"],
hidden_size=mask_conf["n_units"],
n_layers=mask_conf["n_layers"],
bidirectional=True,
dropout=mask_conf["dropout"],
),
nn.Linear(2 * mask_conf["n_units"], self.n_src * self.n_filters),
nn.Sigmoid(),
)

    def forward(self, x):
        batch_size = x.shape[0]
        if len(x.shape) == 2:
            # Add a channel dimension: (batch, time) -> (batch, 1, time).
            x = x.unsqueeze(1)
        # Learned time-frequency representation: (batch, n_filters, n_frames).
        tf_rep = self.encode(x)
        to_sep = self.bn_layer(tf_rep)
        # The RNN masker runs over frames, hence the transposes around it.
        est_masks = self.masker(to_sep.transpose(-1, -2)).transpose(-1, -2)
        est_masks = est_masks.view(batch_size, self.n_src, self.n_filters, -1)
        # One mask per source, applied to the mixture representation.
        masked_tf_rep = tf_rep.unsqueeze(1) * est_masks
        return torch_utils.pad_x_to_y(self.decoder(masked_tf_rep), x)

    def encode(self, x):
        # Gated encoder: a ReLU branch multiplied element-wise by a sigmoid gate.
        relu_out = torch.relu(self.encoder_relu(x))
        sig_out = torch.sigmoid(self.encoder_sig(x))
        return sig_out * relu_out
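
# Minimal usage sketch (illustrative, not part of the original recipe): the
# config values below are assumptions standing in for local/conf.yml, and the
# shapes follow forward() above.
#
#     fb_conf = {"n_filters": 512, "kernel_size": 40, "stride": 20}
#     mask_conf = {"n_src": 2, "n_units": 500, "n_layers": 2, "dropout": 0.3}
#     model = TasNet(fb_conf, mask_conf)
#     mixture = torch.randn(4, 32000)  # (batch, time)
#     est_sources = model(mixture)     # (batch, n_src, time)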


def make_model_and_optimizer(conf):
    """Define the model and optimizer from a config dictionary.

    Args:
        conf: Dictionary containing the output of hierarchical argparse.

    Returns:
        model, optimizer.

    The main goal of this function is to make reloading for resuming
    and evaluation very simple.
    """
model = TasNet(conf["filterbank"], conf["masknet"])
# Define optimizer of this model
optimizer = make_optimizer(model.parameters(), **conf["optim"])
return model, optimizer
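
# Example (hedged): the keys below mirror what this function reads from
# `conf`; the actual values live in local/conf.yml and may differ.
#
#     conf = {
#         "filterbank": {"n_filters": 512, "kernel_size": 40, "stride": 20},
#         "masknet": {"n_src": 2, "n_units": 500, "n_layers": 2, "dropout": 0.3},
#         "optim": {"optimizer": "adam", "lr": 1e-3},
#     }
#     model, optimizer = make_model_and_optimizer(conf)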


def load_best_model(train_conf, exp_dir):
    """Load the best model after training.

    Args:
        train_conf (dict): dictionary as expected by `make_model_and_optimizer`.
        exp_dir (str): Experiment directory. Expects to find
            `'best_k_models.json'` there.

    Returns:
        nn.Module: the best pretrained model, according to the val_loss.
    """
# Create the model from recipe-local function
model, _ = make_model_and_optimizer(train_conf)
    # `best_k_models.json` maps checkpoint paths to their validation losses;
    # keep the path with the lowest loss.
    with open(os.path.join(exp_dir, "best_k_models.json"), "r") as f:
        best_k = json.load(f)
    best_model_path = min(best_k, key=best_k.get)
# Load checkpoint
checkpoint = torch.load(best_model_path, map_location="cpu")
# Load state_dict into model.
model = torch_utils.load_state_dict_in(checkpoint["state_dict"], model)
model.eval()
return model
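
# End-to-end sketch (paths and tensors below are placeholders, for
# illustration only):
#
#     train_conf = ...  # parsed from local/conf.yml
#     model = load_best_model(train_conf, exp_dir="exp/train_tasnet")
#     with torch.no_grad():
#         est_sources = model(mixture)  # mixture: (batch, time) tensor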