-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathmel.py
76 lines (65 loc) · 2.84 KB
/
mel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# MIT License
# Copyright (c) 2019 Descript Inc.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
from librosa.filters import mel as librosa_mel_fn
class Audio2Mel(nn.Module):
    """Turn a raw waveform into a log10-scaled mel spectrogram.

    The Hann window and the mel filterbank are built once at construction
    time and stored as buffers, so they move with the module across devices
    (``.to(device)`` / ``.cuda()``).

    Args:
        n_fft: FFT size used by the STFT.
        hop_length: Number of samples between successive STFT frames.
        win_length: Length of the Hann analysis window.
        sampling_rate: Sampling rate handed to the librosa mel filterbank.
        n_mel_channels: Number of mel bands to produce.
        mel_fmin: Lowest frequency of the mel filterbank.
        mel_fmax: Highest frequency of the mel filterbank.
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int,
                 sampling_rate: int, n_mel_channels: int,
                 mel_fmin: float, mel_fmax: float) -> None:
        super().__init__()
        window = torch.hann_window(win_length).float()
        # Pass everything by keyword: librosa >= 0.10 made these arguments
        # keyword-only, and the same names are accepted by older releases.
        mel_basis = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", window)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """Compute the log10 mel spectrogram of ``audio``.

        Args:
            audio: Waveform tensor. Dimension 1 is squeezed away after
                padding, so this presumably expects ``(batch, 1, samples)``
                -- confirm against callers.

        Returns:
            ``log10`` of the mel-projected STFT magnitude, with the magnitude
            floored at ``1e-5`` before the logarithm (so the minimum possible
            value is ``-5``).
        """
        # Manual reflect padding replaces torch.stft's own centering
        # (center=False below).
        pad = (self.n_fft - self.hop_length) // 2
        audio = F.pad(audio, (pad, pad), "reflect").squeeze(1)
        # return_complex=True is mandatory for real inputs on current PyTorch;
        # the legacy real/imag trailing-axis output has been removed.
        spectrum = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=False,
            return_complex=True,
        )
        # |z| equals the L2 norm over (real, imag) used by the legacy layout.
        magnitude = spectrum.abs()
        mel_output = torch.matmul(self.mel_basis, magnitude)
        # Clamp before the log so silent frames do not produce -inf.
        log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5))
        return log_mel_spec
if __name__ == '__main__':
    # Smoke test: run Audio2Mel on a librosa example clip and print the
    # input/output shapes and the value range of the log-mel spectrogram.

    # Use the GPU when one is present, but do not crash on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # librosa.util.example_audio_file() was removed in librosa 0.10;
    # librosa.example("vibeace") serves the same recording on newer releases.
    if hasattr(librosa, "example"):
        filename = librosa.example("vibeace")
    else:
        filename = librosa.util.example_audio_file()

    y, sr = librosa.load(filename, sr=22050)
    y = y[:163840]
    # Shape the mono signal as (1, 1, samples) before feeding the module.
    y = torch.from_numpy(y).view(1, 1, -1).to(device)
    print(y.shape)

    audio2mel = Audio2Mel(1024, 256, 1024, 22050, 80, 70.0, 7600.0).to(device)
    mel = audio2mel(y)
    print(mel.shape)
    print(torch.min(mel))
    print(torch.max(mel))