inference.py
#!/usr/bin/env python
# coding: utf-8

# In[1]:
import sys
sys.path.append('waveglow/')  # lets WaveGlow's internal modules resolve when the checkpoint is unpickled

import numpy as np
import torch
import matplotlib.pyplot as plt
import IPython.display as ipd
import soundfile as sf

from hparams import hparams
from tacotron.tacotron import Tacotron
from text import text_to_sequence
from text.symbols import symbols
from waveglow.denoiser import Denoiser
# In[2]:
hparams['sample_rate'] = 22050
def plot_data(data, figsize=(16, 4)):
    fig, axes = plt.subplots(1, len(data), figsize=figsize)
    for i in range(len(data)):
        axes[i].imshow(data[i], aspect='auto', origin='lower',
                       interpolation='none')
    plt.show()
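
# plot_data expects a sequence of 2-D arrays and draws them side by side.
# Illustrative call (not part of the original script):
#   plot_data((np.random.rand(80, 100), np.random.rand(80, 100)))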
# In[3]:
checkpoint_path = './checkpoint_path/checkpoint_30000'
waveglow_path = 'waveglow/waveglow_256channels.pt'
num_speakers = 2
speaker_id = 0
text = "설문을 만들었습니다."
# In[4]:
model = Tacotron(hparams, len(symbols), num_speakers=num_speakers)
model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
model = model.cuda().eval().half()
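
# If the checkpoint was saved from a different device, torch.load accepts a
# map_location argument to remap it, e.g.
#   torch.load(checkpoint_path, map_location='cuda:0')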
# In[7]:
waveglow = torch.load(waveglow_path)['model']
waveglow.cuda().eval().half()
# Older WaveGlow checkpoints predate the padding_mode attribute on Conv layers;
# set it so the model runs under newer PyTorch versions.
for m in waveglow.modules():
    if 'Conv' in str(type(m)):
        setattr(m, 'padding_mode', 'zeros')
# Keep the invertible 1x1 convolutions in fp32 for numerical stability.
for k in waveglow.convinv:
    k.float()
denoiser = Denoiser(waveglow)
# In[ ]:
sequence = np.array(text_to_sequence(text))[None, :]
sequence = torch.from_numpy(sequence).cuda().long()
speaker_id = np.array([speaker_id]).reshape(1, -1)  # shape (1, 1): one utterance, one speaker
speaker_id = torch.from_numpy(speaker_id).cuda().long()

mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence, speaker_id)
plot_data((mel_outputs.float().data.cpu().numpy()[0],
           mel_outputs_postnet.float().data.cpu().numpy()[0],
           alignments.float().data.cpu().numpy()[0].T))
with torch.no_grad():
    audio = waveglow.infer(mel_outputs_postnet, sigma=0.666)

sf.write('tone_440.wav', audio[0].float().data.cpu().numpy(),
         hparams['sample_rate'], format='WAV', endian='LITTLE',
         subtype='PCM_16')  # written as 16-bit PCM so the file is not corrupted
print(np.shape(audio))
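
# In[ ]:
# The Denoiser built above is never applied in the original script. As a sketch
# (following NVIDIA's WaveGlow usage; the call signature is assumed to match
# this repo's copy), it can remove the characteristic WaveGlow bias noise:
audio_denoised = denoiser(audio, strength=0.01)[:, 0]
sf.write('output_denoised.wav', audio_denoised[0].float().cpu().numpy(),
         hparams['sample_rate'], format='WAV', subtype='PCM_16')

# In a notebook, the synthesized audio can also be played inline:
ipd.Audio(audio[0].float().data.cpu().numpy(), rate=hparams['sample_rate'])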