-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathgetmodel.py
81 lines (72 loc) · 1.94 KB
/
getmodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch.nn.functional as F
from torchaudio.models import ConvTasNet
from losses import ScaleInvariantSDRLoss
from models import UNet, UNetDNP, TransUNet
# Default constructor keyword arguments, keyed by model name.
# get_model() falls back to the matching entry when the caller passes no
# (or empty) parameters, and forwards it to the model class as **kwargs.
default_params = {
    "UNet": {
        "n_channels": 1,
        "n_class": 2,
        "unet_depth": 6,
        "unet_scale_factor": 16,
    },
    "UNetDNP": {
        "n_channels": 1,
        "n_class": 2,
        "unet_depth": 6,
        "n_filters": 16,
    },
    # Keyword arguments for torchaudio.models.ConvTasNet.
    "ConvTasNet": {
        "num_sources": 2,
        "enc_kernel_size": 16,
        "enc_num_feats": 128,
        "msk_kernel_size": 3,
        "msk_num_feats": 32,
        "msk_num_hidden_feats": 128,
        "msk_num_layers": 8,
        "msk_num_stacks": 3,
    },
    "TransUNet": {
        "img_dim": 256,
        "in_channels": 1,
        "classes": 2,
        "vit_blocks": 6,  # author's note: 12 (presumably an alternative setting)
        "vit_heads": 4,
        "vit_dim_linear_mhsa_block": 128,  # author's note: 1024 (presumably an alternative setting)
        "apply_masks": True,
    },
    # Placeholder entry: no defaults are defined for SepFormer.
    "SepFormer": {},
}
def get_model(name, parameters=None):
    """Build a model by name together with its training configuration.

    Args:
        name: Model identifier. Supported values are "UNet", "UNetDNP",
            "ConvTasNet" and "TransUNet".
        parameters: Keyword arguments forwarded to the model constructor.
            When omitted or empty, ``default_params[name]`` is used instead.

    Returns:
        A dict with the keys:
            "model": the constructed model instance.
            "data_mode": "amplitude" (UNet, TransUNet) or "time"
                (UNetDNP, ConvTasNet).
            "loss_fn": ``F.mse_loss`` for the "amplitude" models,
                ``ScaleInvariantSDRLoss`` for the "time" models (returned
                as referenced; it is not called here).
            "loss_mode": "min" when paired with ``F.mse_loss``, "max" when
                paired with ``ScaleInvariantSDRLoss``.

    Raises:
        KeyError: If ``parameters`` is omitted/empty and ``name`` has no
            entry in ``default_params``.
        ValueError: If ``name`` is not one of the supported models. This
            includes "SepFormer", which has a ``default_params`` entry but
            cannot be constructed here.
    """
    if not parameters:
        # Deliberately a truthiness test: an explicitly passed empty dict
        # also falls back to the defaults, not only None.
        parameters = default_params[name]

    if name == "UNet":
        model = UNet(**parameters)
        data_mode = "amplitude"
        loss_fn = F.mse_loss
        loss_mode = "min"
    elif name == "UNetDNP":
        model = UNetDNP(**parameters)
        data_mode = "time"
        loss_fn = ScaleInvariantSDRLoss
        loss_mode = "max"
    elif name == "ConvTasNet":
        model = ConvTasNet(**parameters)
        data_mode = "time"
        loss_fn = ScaleInvariantSDRLoss
        loss_mode = "max"
    elif name == "TransUNet":
        model = TransUNet(**parameters)
        data_mode = "amplitude"
        loss_fn = F.mse_loss
        loss_mode = "min"
    else:
        # NOTE: "SepFormer" is listed in default_params but no SepFormer model
        # is wired up in this function, so it is rejected like any other
        # unknown name. Without this branch the return statement below would
        # fail with UnboundLocalError.
        raise ValueError(
            f"Unsupported model name: {name!r}. "
            "Expected one of: 'UNet', 'UNetDNP', 'ConvTasNet', 'TransUNet'."
        )

    return {
        "model": model,
        "data_mode": data_mode,
        "loss_fn": loss_fn,
        "loss_mode": loss_mode,
    }