Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
hrnoh24 committed Jul 30, 2024
2 parents 6caf053 + de58b77 commit 6e890c9
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 14 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,7 @@ python src/train.py trainer.max_epochs=20 data.batch_size=64

https://github.com/lucidrains/audiolm-pytorch/tree/main
https://github.com/facebookresearch/encodec/tree/main
https://github.com/jaywalnut310/vits/tree/main
https://github.com/jaywalnut310/vits/tree/main
https://github.com/wesbz/SoundStream/blob/main/main.py
https://github.com/bshall/soft-vc
https://github.com/descriptinc/melgan-neurips/tree/master
14 changes: 7 additions & 7 deletions configs/callbacks/default.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
defaults:
- model_checkpoint
- early_stopping
# - early_stopping
- model_summary
- rich_progress_bar
- _self_

model_checkpoint:
dirpath: ${paths.output_dir}/checkpoints
filename: "epoch_{epoch:03d}"
monitor: "val/acc"
mode: "max"
monitor: "val/recon_loss"
mode: "min"
save_last: True
auto_insert_metric_name: False

early_stopping:
monitor: "val/acc"
patience: 100
mode: "max"
# early_stopping:
# monitor: "val/acc"
# patience: 100
# mode: "max"

model_summary:
max_depth: -1
2 changes: 1 addition & 1 deletion configs/experiment/streamvc_v1.0.0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ data:

trainer:
min_epochs: 10
max_epochs: 10
max_epochs: 500
3 changes: 2 additions & 1 deletion src/models/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .gan_losses import discriminator_loss, feature_loss, generator_loss
from .reconstruction_loss import spectral_reconstruction_loss
# from .reconstruction_loss import spectral_reconstruction_loss
from .mel_loss import ReconstructionLoss
81 changes: 81 additions & 0 deletions src/models/losses/mel_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from librosa.filters import mel

class Audio2Mel(nn.Module):
    """Turn a waveform into a log10 mel spectrogram.

    The Hann window and the mel filterbank are built once in ``__init__``
    and stored as buffers, so they follow the module across ``.to(device)``
    calls and are saved in the ``state_dict``.

    Args:
        n_fft: FFT size used by the STFT and by the mel filterbank.
        hop_length: Hop between successive STFT frames, in samples.
        win_length: Length of the Hann analysis window, in samples.
        sampling_rate: Sample rate passed to the mel filterbank builder.
        n_mel_channels: Number of mel bands.
        mel_fmin: Lowest frequency of the mel filterbank.
        mel_fmax: Highest frequency of the mel filterbank; ``None`` is
            passed straight through to ``librosa.filters.mel``.
    """

    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        sampling_rate=22050,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=None,
    ):
        super().__init__()
        ##############################################
        # FFT Parameters                             #
        ##############################################
        window = torch.hann_window(win_length).float()
        mel_basis = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", window)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels

    def forward(self, audio):
        """Compute the log10 mel spectrogram of ``audio``.

        Args:
            audio: Waveform tensor. The last dimension is padded and
                dimension 1 is squeezed away, so this presumably expects
                ``(batch, 1, samples)`` -- confirm against the caller.

        Returns:
            ``log10`` of the mel-projected STFT magnitude, with the mel
            energies clamped to a minimum of ``1e-5`` before the logarithm.
        """
        # Pad manually (reflect) instead of using center=True in the STFT.
        p = (self.n_fft - self.hop_length) // 2
        audio = F.pad(audio, (p, p), "reflect").squeeze(1)
        spectrum = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=False,
            return_complex=True,
        )
        # abs() on the complex STFT gives the same magnitude as
        # sqrt(re**2 + im**2) but has a finite (zero) gradient at bins that
        # are exactly zero; the explicit sqrt form yields NaN gradients
        # there (e.g. on silent or zero-padded audio).
        magnitude = spectrum.abs()
        mel_output = torch.matmul(self.mel_basis, magnitude)
        log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5))
        return log_mel_spec

class ReconstructionLoss(nn.Module):
    """L1 distance between the log-mel spectrograms of two waveforms.

    Both inputs are run through the same ``Audio2Mel`` front end (stored as
    ``self.fft``) and the mean absolute difference of the results is
    returned. All spectrogram settings are forwarded to ``Audio2Mel``; any
    extra positional or keyword arguments go to ``nn.Module.__init__``.
    """

    def __init__(self,
                 n_fft=1024,
                 hop_length=256,
                 win_length=1024,
                 sampling_rate=16000,
                 n_mel_channels=80,
                 mel_fmin=0.0,
                 mel_fmax=None,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Collect the spectrogram settings once and hand them over together.
        mel_settings = {
            "n_fft": n_fft,
            "hop_length": hop_length,
            "win_length": win_length,
            "sampling_rate": sampling_rate,
            "n_mel_channels": n_mel_channels,
            "mel_fmin": mel_fmin,
            "mel_fmax": mel_fmax,
        }
        self.fft = Audio2Mel(**mel_settings)

    def forward(self, x, G_x):
        """Return the L1 loss between the transforms of ``x`` and ``G_x``.

        Args:
            x: Reference waveform batch.
            G_x: Generated waveform batch to compare against ``x``.
        """
        target_spec = self.fft(x)
        generated_spec = self.fft(G_x)
        return F.l1_loss(target_spec, generated_spec)
2 changes: 1 addition & 1 deletion src/models/losses/reconstruction_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def spectral_reconstruction_loss(x, G_x, sr=16000, eps=1e-4, device="cpu"):
sample_rate=sr,
n_fft=s,
hop_length=s // 4,
n_mels=8,
n_mels=64,
wkwargs={"device": device},
).to(device)
S_x = melspec(x)
Expand Down
8 changes: 5 additions & 3 deletions src/models/streamvc_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
generator_loss,
discriminator_loss,
feature_loss,
spectral_reconstruction_loss,
# spectral_reconstruction_loss,
ReconstructionLoss
)


Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(
self.scheduler_d = scheduler_d

self.criterion = torch.nn.CrossEntropyLoss()
self.reconstruction_loss = ReconstructionLoss()

def forward(
self, x: torch.Tensor, pitch: torch.Tensor, energy: torch.Tensor
Expand Down Expand Up @@ -133,7 +135,7 @@ def training_step(

loss_fm = feature_loss(fmap_r, fmap_g)
loss_gen, _ = generator_loss(y_d_hat_g)
loss_recon = spectral_reconstruction_loss(y, y_hat)
loss_recon = self.reconstruction_loss(y, y_hat)
loss_content = self.criterion(logits.transpose(1, 2), labels)
loss_all = 100 * loss_fm + loss_gen + loss_recon + loss_content
self.manual_backward(loss_all)
Expand Down Expand Up @@ -173,7 +175,7 @@ def validation_step(

loss_fm = feature_loss(fmap_r, fmap_g)
loss_gen, _ = generator_loss(y_d_hat_g)
loss_recon = spectral_reconstruction_loss(y, y_hat)
loss_recon = self.reconstruction_loss(y, y_hat)
loss_content = self.criterion(logits.transpose(1, 2), labels)
loss_all = 100 * loss_fm + loss_gen + loss_recon + loss_content

Expand Down

0 comments on commit 6e890c9

Please sign in to comment.