Skip to content

Commit

Permalink
Add qformer solution
Browse files Browse the repository at this point in the history
  • Loading branch information
leng-yue committed Sep 27, 2023
1 parent 82b3fe5 commit a592056
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 63 deletions.
25 changes: 8 additions & 17 deletions configs/tts_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
type="NaiveTTSDataset",
path=str(speaker_path),
speaker_id=speaker_name,
cache_list=True,
)
)

Expand Down Expand Up @@ -65,23 +66,13 @@
s=0.008,
noise_loss="smoothed-l1",
denoiser=dict(
# type="TransformerDecoderDenoiser",
# dim=512,
# mlp_factor=4,
# mel_channels=mel_channels,
# condition_dim=bert_dim,
# num_layers=40,
# gradient_checkpointing=gradient_checkpointing,
type="ConvNextDenoiser",
dim=384,
mlp_factor=4,
type="WaveNetDenoiser",
mel_channels=mel_channels,
condition_dim=bert_dim,
num_layers=20,
d_encoder=bert_dim,
residual_channels=512,
residual_layers=20,
dilation_cycle=4,
gradient_checkpointing=gradient_checkpointing,
cross_attention=True,
cross_every_n_layers=10,
use_linear_bias=True,
),
sampler_interval=10,
spec_min=[-5],
Expand Down Expand Up @@ -149,9 +140,9 @@
lambda_func = LambdaWarmUpCosineScheduler(
warm_up_steps=10000,
val_final=1e-5,
val_base=1e-4,
val_base=4e-5,
val_start=0,
max_decay_steps=300000,
max_decay_steps=1000000,
)

optimizer = dict(
Expand Down
42 changes: 28 additions & 14 deletions fish_diffusion/archs/diffsinger/grad_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self, model_config):
self.text_encoder = ENCODERS.build(model_config.text_encoder)
self.diffusion = DIFFUSIONS.build(model_config.diffusion)
self.duration_predictor = ENCODERS.build(model_config.duration_predictor)
self.bert_query = nn.Parameter(torch.randn(1, 1, self.text_encoder.output_size))

if getattr(model_config, "speaker_encoder", None):
self.speaker_encoder = ENCODERS.build(model_config.speaker_encoder)
Expand Down Expand Up @@ -55,11 +56,29 @@ def forward_features(
phones2mel=None,
energy=None,
):
src_masks = self.get_mask_from_lengths(contents_lens, contents_max_len)
if self.training is False:
# Random 20% size change
mel_lens = torch.round(
mel_lens * (0.8 + 0.4 * torch.rand(1, device=mel_lens.device))
).long()
mel_max_len = torch.max(mel_lens).item()

features = self.text_encoder.bert.embeddings.word_embeddings(contents)[
# Build text features
text_features = self.text_encoder.bert.embeddings.word_embeddings(contents)[
:, 0, :, :
]
text_masks = self.get_mask_from_lengths(contents_lens, contents_max_len)

# Build Query features
query_lengths = (mel_lens / 10).ceil().long() # 512 * 10 / 44100 = 0.116 sec
max_query_length = torch.max(query_lengths).item()
mel_queries = self.bert_query.expand(mel.shape[0], max_query_length, -1)
mel_masks = self.get_mask_from_lengths(query_lengths, max_query_length)

# Concatenate text and query features
# This will waste memory, is there any better way?
features = torch.cat([text_features, mel_queries], dim=1)
src_masks = torch.cat([text_masks, mel_masks], dim=1)

if speakers.ndim in [2, 3] and torch.is_floating_point(speakers):
speaker_embed = speakers
Expand All @@ -82,26 +101,21 @@ def forward_features(
output_hidden_states=True,
)

# Predict durations
log_durations = self.duration_predictor(features[:, 0, :])[..., 0]
duration_loss = F.smooth_l1_loss(log_durations, torch.log(mel_lens.float()))
# Let's extract query features
features = features[:, -max_query_length:, :]

if self.training is False:
mel_lens = torch.round(torch.exp(torch.clamp(log_durations, 1, 8))).long()
mel_lens = torch.clamp(mel_lens, 10, 2048)
mel_max_len = torch.max(mel_lens).item()
# Repeat to match mel length
features = features.repeat_interleave(10, dim=1)

# Truncate to match mel length
features = features[:, :mel_max_len, :]
mel_masks = self.get_mask_from_lengths(mel_lens, mel_max_len)

return dict(
features=features,
x_masks=mel_masks,
x_lens=mel_lens,
cond_masks=src_masks,
loss=duration_loss,
metrics={
"duration_loss": duration_loss,
},
cond_masks=mel_masks,
)

def forward(
Expand Down
34 changes: 31 additions & 3 deletions fish_diffusion/datasets/naive.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import platform
import warnings
from pathlib import Path
from typing import Optional

Expand All @@ -18,9 +20,35 @@ class NaiveDataset(Dataset):

collating_pipeline = []

def __init__(self, path="dataset", speaker_id=0):
self.paths = list_files(path, {".npy"}, recursive=True, sort=True)
self.dataset_path = Path(path)
def __init__(
self, path: str = "dataset", speaker_id: int = 0, cache_list: bool = False
) -> None:
path = Path(path)
if cache_list and platform.system() != "Linux":
warnings.warn(
"Caching npy list is only supported on Linux, "
+ "since it uses `mtime` to check if the file is updated."
+ "Please use `cache_list=False` on other platforms."
)
cache_list = False

self.paths = None
cache_file = path / "filelist.cache"
if (
cache_list
and cache_file.exists()
and path.stat().st_mtime < cache_file.stat().st_mtime
):
self.paths = open(cache_file).read().splitlines()

if self.paths is None:
self.paths = list_files(path, {".npy"}, recursive=True, sort=True)

if cache_list:
with open(cache_file, "w") as f:
f.write("\n".join([str(x) for x in self.paths]))

self.dataset_path = path
self.speaker_id = speaker_id

assert len(self.paths) > 0, f"No files found in {path}, please check your path."
Expand Down
134 changes: 106 additions & 28 deletions fish_diffusion/modules/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,84 @@ def forward(self, x, diffusion_step, conditioner, x_masks=None, cond_masks=None)
return x[:, None] if use_4_dim else x


class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=4096, base=10000, device=None):
super().__init__()

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False)

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype(),
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
)

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
"cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
)
self.register_buffer(
"sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)

@staticmethod
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

@staticmethod
def apply_rotary_pos_emb(q, k, cos, sin, position_ids_q=None, position_ids_k=None):
if position_ids_q is None:
position_ids_q = torch.arange(q.size(1), dtype=torch.long, device=q.device)
position_ids_q = position_ids_q[None]

if position_ids_k is None:
position_ids_k = torch.arange(k.size(1), dtype=torch.long, device=k.device)
position_ids_k = position_ids_k[None]

# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]

q_embed = (q * cos[position_ids_q].unsqueeze(1)) + (
RotaryEmbedding.rotate_half(q) * sin[position_ids_q].unsqueeze(1)
)
k_embed = (k * cos[position_ids_k].unsqueeze(1)) + (
RotaryEmbedding.rotate_half(k) * sin[position_ids_k].unsqueeze(1)
)

q_embed, k_embed = q_embed[0].contiguous(), k_embed[0].contiguous()

return q_embed, k_embed


class TransformerDecoderDenoiser(nn.Module):
def __init__(
self,
Expand All @@ -270,6 +348,7 @@ def __init__(
condition_dim=256,
num_layers=12,
gradient_checkpointing=False,
post_block_num=4,
):
super().__init__()

Expand All @@ -290,9 +369,7 @@ def __init__(
nn.Conv1d(dim * mlp_factor, dim, 1),
)

self.register_buffer("positional_embedding", self.get_embedding(dim))
self.position_scale_query = nn.Parameter(torch.ones(1))
self.position_scale_key = nn.Parameter(torch.ones(1))
self.rotary_emb = RotaryEmbedding(dim)

self.layers = nn.ModuleList(
[
Expand All @@ -302,32 +379,29 @@ def __init__(
dim_feedforward=dim * mlp_factor,
activation="gelu",
batch_first=True,
norm_first=True,
layer_norm_eps=1e-6,
)
for _ in range(num_layers)
]
)

self.output_projection = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size=1),
nn.GELU(),
# Should be a strong postnet, but no condition is needed
self.post_blocks = nn.ModuleList([])
for _ in range(post_block_num):
self.post_blocks.append(
ConvNeXtBlock(
dim=dim,
intermediate_dim=dim * mlp_factor,
)
)

self.post_blocks.append(
nn.Conv1d(dim, mel_channels, kernel_size=1),
)

self.gradient_checkpointing = gradient_checkpointing

def get_embedding(self, embedding_dim, num_embeddings=4096):
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
1
) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
num_embeddings, -1
)

return emb

def forward(self, x, diffusion_step, conditioner, x_masks=None, cond_masks=None):
"""
Expand All @@ -342,16 +416,13 @@ def forward(self, x, diffusion_step, conditioner, x_masks=None, cond_masks=None)
assert x.dim() == 3, f"mel must be 3 dim tensor, but got {x.dim()}"

x = self.input_projection(x).transpose(1, 2) # x [B, T, residual_channel]
x_pos = self.positional_embedding[None, : x.size(1)] * self.position_scale_query
x = x + x_pos

condition = self.condition_projection(conditioner).transpose(1, 2)
diffusion_step = self.diffusion_embedding(diffusion_step).unsqueeze(1)
condition_pos = (
self.positional_embedding[None, : condition.size(1)]
* self.position_scale_key
)
condition = condition + condition_pos + diffusion_step
condition = condition + diffusion_step

# Apply positional encoding to both x and condition
cos, sin = self.rotary_emb(condition, seq_len=max(condition.size(1), x.size(1)))
x, condition = RotaryEmbedding.apply_rotary_pos_emb(x, condition, cos, sin)

if x_masks is not None:
x = x.masked_fill(x_masks[..., None], 0.0)
Expand All @@ -377,8 +448,15 @@ def forward(self, x, diffusion_step, conditioner, x_masks=None, cond_masks=None)
memory_key_padding_mask=cond_masks,
)

diffusion_step = diffusion_step.transpose(1, 2)
x = x.transpose(1, 2)
x = self.output_projection(x) # [B, 128, T]

for layer in self.post_blocks:
if isinstance(layer, ConvNeXtBlock):
x = layer(x, x_masks=x_masks, diffusion_step=diffusion_step)
else:
x = layer(x)

if x_masks is not None:
x = x.masked_fill(x_masks[:, None], 0.0)

Expand Down
11 changes: 10 additions & 1 deletion fish_diffusion/modules/wavenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __init__(
)
nn.init.zeros_(self.output_projection.conv.weight)

def forward(self, x, diffusion_step, conditioner):
def forward(self, x, diffusion_step, conditioner, x_masks=None, cond_masks=None):
"""
:param x: [B, M, T]
Expand All @@ -214,6 +214,12 @@ def forward(self, x, diffusion_step, conditioner):
diffusion_step = self.diffusion_embedding(diffusion_step)
diffusion_step = self.mlp(diffusion_step)

if x_masks is not None:
x = x.masked_fill(x_masks[:, None], 0.0)

if cond_masks is not None:
conditioner = conditioner.masked_fill(cond_masks[:, None], 0.0)

skip = []
for layer in self.residual_layers:
x, skip_connection = layer(x, conditioner, diffusion_step)
Expand All @@ -224,4 +230,7 @@ def forward(self, x, diffusion_step, conditioner):
x = F.relu(x)
x = self.output_projection(x) # [B, 128, T]

if x_masks is not None:
x = x.masked_fill(x_masks[:, None], 0.0)

return x[:, None] if use_4_dim else x
9 changes: 9 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import numpy as np
from transformers import AutoTokenizer

t = AutoTokenizer.from_pretrained("bert-base-cased")
d = np.load(
"dataset/LibriTTS/train-clean-100/26/495/26_495_000004_000000.0.data.npy",
allow_pickle=True,
).item()
print(t.decode(d["contents"][0]))

0 comments on commit a592056

Please sign in to comment.