Skip to content

Commit

Permalink
remove weight_norm, add torch hub
Browse files Browse the repository at this point in the history
  • Loading branch information
seungwonpark committed Oct 28, 2019
1 parent 1725a75 commit 36d5071
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 13 deletions.
20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Unofficial PyTorch implementation of [MelGAN vocoder](https://arxiv.org/abs/1910

- MelGAN is lighter, faster, and better at generalizing to unseen speakers than [WaveGlow](https://github.com/NVIDIA/waveglow).
- This repository use identical mel-spectrogram function from [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2), so this can be directly used to convert output from NVIDIA's tacotron2 into raw-audio.
- TODO: Planning to publish pretrained model via [PyTorch Hub](https://pytorch.org/hub).
- Pretrained model on LJSpeech-1.1 via [PyTorch Hub](https://pytorch.org/hub).

![](./assets/gd.png)

Expand All @@ -27,6 +27,24 @@ pip install -r requirements.txt
- `python trainer.py -c [config yaml file] -n [name of the run]`
- `tensorboard --logdir logs/`

## Pretrained model

Try with Google Colab:

```python
import torch
vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
vocoder.eval()
mel = torch.randn(1, 80, 234) # use your own mel-spectrogram here

if torch.cuda.is_available():
vocoder = vocoder.cuda()
mel = mel.cuda()

with torch.no_grad():
audio = vocoder(mel)
```

## Inference

- `python inference.py -p [checkpoint path] -i [input mel path]`
Expand Down
40 changes: 40 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
dependencies = ['torch']
from model.generator import Generator

model_params = {
'nvidia_tacotron2_LJ11_epoch3200': {
'mel_channel': 80,
'model_url': '',
},
}


def melgan(model_name='nvidia_tacotron2_LJ11_epoch3200', pretrained=True, progress=True):
params = model_params[model_name]
model = Generator(params['mel_channel'])

if pretrained:
state_dict = torch.hub.load_state_dict_from_url(params['model_url'],
progress=progress)
model.load_state_dict(state_dict['model_g'])

model.eval(inference=True)

return model


if __name__ == '__main__':
vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
mel = torch.randn(1, 80, 234) # use your own mel-spectrogram here

print('Input mel-spectrogram shape: {}'.format(mel.shape))

if torch.cuda.is_available():
print('Moving data & model to GPU')
vocoder = vocoder.cuda()
mel = mel.cuda()

with torch.no_grad():
audio = vocoder.inference(mel)

print('Output audio shape: {}'.format(audio.shape))
14 changes: 2 additions & 12 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def main(args):

model = Generator(hp.audio.n_mel_channels).cuda()
model.load_state_dict(checkpoint['model_g'])
model.eval()
model.eval(inference=False)

with torch.no_grad():
for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))):
Expand All @@ -29,17 +29,7 @@ def main(args):
mel = mel.unsqueeze(0)
mel = mel.cuda()

# pad input mel with zeros to cut artifact
# see https://github.com/seungwonpark/melgan/issues/8
zero = torch.full((1, hp.audio.n_mel_channels, 10), -11.5129).cuda()
mel = torch.cat((mel, zero), axis=2)

audio = model(mel)
audio = audio.squeeze() # collapse all dimension except time axis
audio = audio[:-(hp.audio.hop_length*10)]
audio = MAX_WAV_VALUE * audio
audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE)
audio = audio.short()
audio = model.inference(hp, mel)
audio = audio.cpu().detach().numpy()

out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
Expand Down
33 changes: 33 additions & 0 deletions model/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
from .res_stack import ResStack
#from res_stack import ResStack

MAX_WAV_VALUE = 32768.0


class Generator(nn.Module):
def __init__(self, mel_channel):
super(Generator, self).__init__()
self.mel_channel = mel_channel

self.generator = nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1, padding=3)),
Expand Down Expand Up @@ -42,6 +45,36 @@ def forward(self, mel):
mel = (mel + 5.0) / 5.0 # roughly normalize spectrogram
return self.generator(mel)

def eval(self, inference=False):
super(Generator, self).eval()

# don't remove weight norm while validation in training loop
if inference:
self.remove_weight_norm()

def remove_weight_norm(self):
for idx, layer in enumerate(self.generator):
if len(layer.state_dict()) != 0:
try:
nn.utils.remove_weight_norm(layer)
except:
layer.remove_weight_norm()

def inference(self, hp, mel):
# pad input mel with zeros to cut artifact
# see https://github.com/seungwonpark/melgan/issues/8
zero = torch.full((1, hp.audio.n_mel_channels, 10), -11.5129).to(mel.device)
mel = torch.cat((mel, zero), axis=2)

audio = self.forward(mel)
audio = audio.squeeze() # collapse all dimension except time axis
audio = audio[:-(hp.audio.hop_length*10)]
audio = MAX_WAV_VALUE * audio
audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
audio = audio.short()

return audio


'''
to run this, fix
Expand Down
5 changes: 5 additions & 0 deletions model/res_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,8 @@ def forward(self, x):
for layer in self.layers:
x = x + layer(x)
return x

def remove_weight_norm(self):
for layer in self.layers:
nn.utils.remove_weight_norm(layer[1])
nn.utils.remove_weight_norm(layer[3])

0 comments on commit 36d5071

Please sign in to comment.