hubconf.py (forked from sigsep/open-unmix-pytorch)

import torch.hub

import utils

# Optional list of dependencies required by the package
dependencies = ['torch', 'numpy']
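
# Note on the hub mechanics: torch.hub looks for a `hubconf.py` at the
# repository root; every top-level callable defined in it becomes a loadable
# entry point, and the `dependencies` list above is checked before loading.
# A minimal call-site sketch (assuming the upstream GitHub path as the repo
# reference):
#
#   model = torch.hub.load('sigsep/open-unmix-pytorch', 'umxhq', target='vocals')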


def umxhq(target='vocals', device='cpu', pretrained=True, *args, **kwargs):
    """
    Open Unmix 2-channel/stereo BiLSTM model trained on MUSDB18-HQ

    Args:
        target (str): select the target source to be separated.
            Supported targets are ['vocals', 'drums', 'bass', 'other']
        device (str): selects the device to be used for inference
        pretrained (bool): If True, returns a model pre-trained on MUSDB18-HQ
    """
    from model import OpenUnmix

    # set urls for the pre-trained weights of each target
    target_urls = {
        'bass': 'https://zenodo.org/api/files/1c8f83c5-33a5-4f59-b109-721fdd234875/bass-8d85a5bd.pth',
        'drums': 'https://zenodo.org/api/files/1c8f83c5-33a5-4f59-b109-721fdd234875/drums-9619578f.pth',
        'other': 'https://zenodo.org/api/files/1c8f83c5-33a5-4f59-b109-721fdd234875/other-b52fbbf7.pth',
        'vocals': 'https://zenodo.org/api/files/1c8f83c5-33a5-4f59-b109-721fdd234875/vocals-b62c91ce.pth'
    }

    # determine the maximum frequency bin for a 16 kHz bandwidth model
    max_bin = utils.bandwidth_to_max_bin(
        rate=44100,
        n_fft=4096,
        bandwidth=16000
    )

    # instantiate the Open Unmix model
    unmix = OpenUnmix(
        n_fft=4096,
        n_hop=1024,
        nb_channels=2,
        hidden_size=512,
        max_bin=max_bin
    )

    if pretrained:
        # download and load the weights for the selected target
        state_dict = torch.hub.load_state_dict_from_url(
            target_urls[target],
            map_location=device
        )
        unmix.load_state_dict(state_dict)
        # enable centering of the STFT to minimize reconstruction error
        unmix.stft.center = True
        unmix.eval()

    return unmix.to(device)
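

# A minimal usage sketch for the entry point above, assuming (as in the
# upstream repository) that the model's forward pass accepts waveform
# tensors of shape (nb_samples, nb_channels, nb_timesteps) and returns the
# estimated magnitude spectrogram of the target:
#
#   import torch
#   model = umxhq(target='vocals', device='cpu')
#   audio = torch.rand(1, 2, 44100)  # one second of random stereo audio
#   with torch.no_grad():
#       estimate = model(audio)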


def umx(target='vocals', device='cpu', pretrained=True, *args, **kwargs):
    """
    Open Unmix 2-channel/stereo BiLSTM model trained on MUSDB18

    Args:
        target (str): select the target source to be separated.
            Supported targets are ['vocals', 'drums', 'bass', 'other']
        device (str): selects the device to be used for inference
        pretrained (bool): If True, returns a model pre-trained on MUSDB18
    """
    from model import OpenUnmix

    # set urls for the pre-trained weights of each target
    target_urls = {
        'bass': 'https://zenodo.org/api/files/d6105b95-8c52-430c-84ce-bd14b803faaf/bass-646024d3.pth',
        'drums': 'https://zenodo.org/api/files/d6105b95-8c52-430c-84ce-bd14b803faaf/drums-5a48008b.pth',
        'other': 'https://zenodo.org/api/files/d6105b95-8c52-430c-84ce-bd14b803faaf/other-f8e132cc.pth',
        'vocals': 'https://zenodo.org/api/files/d6105b95-8c52-430c-84ce-bd14b803faaf/vocals-c8df74a5.pth'
    }

    # determine the maximum frequency bin for a 16 kHz bandwidth model
    max_bin = utils.bandwidth_to_max_bin(
        rate=44100,
        n_fft=4096,
        bandwidth=16000
    )

    # instantiate the Open Unmix model
    unmix = OpenUnmix(
        n_fft=4096,
        n_hop=1024,
        nb_channels=2,
        hidden_size=512,
        max_bin=max_bin
    )

    if pretrained:
        # download and load the weights for the selected target
        state_dict = torch.hub.load_state_dict_from_url(
            target_urls[target],
            map_location=device
        )
        unmix.load_state_dict(state_dict)
        # enable centering of the STFT to minimize reconstruction error
        unmix.stft.center = True
        unmix.eval()

    return unmix.to(device)
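

# A self-contained smoke test, as a hedged sketch: it assumes the local
# `model` and `utils` modules from this repository are importable and that
# the weight download succeeds. It loads the MUSDB18 vocals model and runs
# dummy audio through it.
if __name__ == '__main__':
    import torch

    separator = umx(target='vocals', device='cpu', pretrained=True)
    # two seconds of random stereo audio: (nb_samples, nb_channels, nb_timesteps)
    dummy_audio = torch.rand(1, 2, 2 * 44100)
    with torch.no_grad():
        estimate = separator(dummy_audio)
    print(estimate.shape)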