Tune bert params & scaled pe
leng-yue committed Sep 22, 2023
1 parent 37f95e0 · commit b79c02f
Showing 5 changed files with 19 additions and 21 deletions.
configs/_base_/schedulers/warmup_cosine.py (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@
 
 optimizer = dict(
     type="AdamW",
-    lr=1.0,
+    lr=2e-5,
     weight_decay=1e-2,
     betas=(0.9, 0.98),
     eps=1e-9,
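Note on the change: lr=1.0 is the value often seen when a LambdaLR-style scheduler returns absolute learning rates, while 2e-5 reads as a conventional fine-tuning rate for a BERT-sized text encoder. A minimal sketch of how a dict like this might be turned into an optimizer plus a warmup-cosine schedule; the builder function and warmup length are assumptions, not the repository's code:

import math
import torch

def build_optimizer_and_scheduler(params, warmup_steps=1000, max_steps=2_000_000):
    # Hyperparameters mirror the config hunk above; warmup_steps is assumed.
    optimizer = torch.optim.AdamW(
        params, lr=2e-5, weight_decay=1e-2, betas=(0.9, 0.98), eps=1e-9
    )

    def lr_lambda(step):
        # Linear warmup, then cosine decay to zero (assumed schedule shape).
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler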
configs/_base_/trainers/base.py (2 changes: 1 addition & 1 deletion)
@@ -15,7 +15,7 @@
     check_val_every_n_epoch=None,
     max_steps=2_000_000,
     # Warning: If you are training the model with fs2 (and see nan), you should either use bf16 or fp32
-    precision="32",
+    precision="bf16-mixed",
    accumulate_grad_batches=1,
     callbacks=[
         ModelCheckpoint(
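The precision string follows PyTorch Lightning's Trainer API; "bf16-mixed" runs autocast in bfloat16 while keeping fp32 master weights, which is consistent with the warning in the comment above. A minimal sketch of a Trainer built from these settings (the import path and the checkpoint arguments are assumptions):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Mirrors the trainer config above; the repo may import from `lightning.pytorch` instead.
trainer = pl.Trainer(
    check_val_every_n_epoch=None,
    max_steps=2_000_000,
    precision="bf16-mixed",  # bfloat16 autocast with fp32 master weights
    accumulate_grad_batches=1,
    callbacks=[
        ModelCheckpoint(save_top_k=-1, every_n_train_steps=10_000),  # assumed values
    ],
)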
configs/tts_baseline.py (7 changes: 4 additions & 3 deletions)
@@ -89,11 +89,9 @@
         pretrained=True,
     ),
     duration_predictor=dict(
-        type="TransformerEncoder",
+        type="NaiveProjectionEncoder",
         input_size=bert_dim,
         output_size=1,
-        hidden_size=bert_dim,
-        num_layers=1,
     ),
     vocoder=dict(
         type="ADaMoSHiFiGANV1",
@@ -125,6 +123,9 @@
     train=dict(
         batch_size=8,
     ),
+    valid=dict(
+        batch_size=8,
+    ),
 )
 
 preprocessing = dict(
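Here the duration predictor drops from a single-layer Transformer encoder to a plain projection of the BERT features, and a valid split with the same batch size is added. A sketch of what a projection encoder of this shape could look like; the class body and the bert_dim value are assumptions, only the config keys (input_size, output_size) come from the diff:

import torch
from torch import nn

class NaiveProjectionEncoder(nn.Module):
    """Assumed minimal form: one linear layer, no attention or hidden layers."""

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.proj = nn.Linear(input_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, input_size) -> (batch, output_size)
        return self.proj(x)

# Matching the config keys above (bert_dim=768 is an assumed value):
# duration_predictor = NaiveProjectionEncoder(input_size=768, output_size=1)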
fish_diffusion/archs/diffsinger/grad_tts.py (18 changes: 4 additions & 14 deletions)
@@ -55,17 +55,7 @@ def forward_features(
         phones2mel=None,
         energy=None,
     ):
-        src_masks = (
-            self.get_mask_from_lengths(contents_lens, contents_max_len)
-            if contents_lens is not None
-            else None
-        )
-
-        mel_masks = (
-            self.get_mask_from_lengths(mel_lens, mel_max_len)
-            if mel_lens is not None
-            else None
-        )
+        src_masks = self.get_mask_from_lengths(contents_lens, contents_max_len)
 
         features = self.text_encoder.bert.embeddings.word_embeddings(contents)[
             :, 0, :, :
@@ -93,15 +83,15 @@ def forward_features(
         )
 
         # Predict durations
-        durations = self.duration_predictor(features, src_masks_float)
-        log_durations = durations[:, 0, 0]
+        log_durations = self.duration_predictor(features[:, 0, :])[..., 0]
         duration_loss = F.smooth_l1_loss(log_durations, torch.log(mel_lens.float()))
 
         if self.training is False:
             mel_lens = torch.round(torch.exp(torch.clamp(log_durations, 1, 8))).long()
             mel_lens = torch.clamp(mel_lens, 10, 2048)
             mel_max_len = torch.max(mel_lens).item()
-            mel_masks = self.get_mask_from_lengths(mel_lens, mel_max_len)
+
+        mel_masks = self.get_mask_from_lengths(mel_lens, mel_max_len)
 
         return dict(
             features=features,
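The rewritten path predicts a single utterance-level log mel-length from the first feature position instead of per-token durations, and the mel mask is now rebuilt after the length is (re)estimated. A condensed, self-contained sketch of that logic, assuming features has shape (batch, time, dim) and the predictor maps dim to 1:

import torch
import torch.nn.functional as F

def predict_mel_length(duration_predictor, features, mel_lens, training):
    # Use the first position of the BERT features as an utterance-level summary.
    log_durations = duration_predictor(features[:, 0, :])[..., 0]  # (batch,)
    duration_loss = F.smooth_l1_loss(log_durations, torch.log(mel_lens.float()))

    if not training:
        # At inference, the predicted log-length replaces the ground-truth mel length.
        mel_lens = torch.round(torch.exp(torch.clamp(log_durations, 1, 8))).long()
        mel_lens = torch.clamp(mel_lens, 10, 2048)

    mel_max_len = int(torch.max(mel_lens).item())
    return mel_lens, mel_max_len, duration_loss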
fish_diffusion/modules/convnext.py (11 changes: 9 additions & 2 deletions)
@@ -108,6 +108,8 @@ def __init__(
 
         self.diffusion_step_projection = nn.Conv1d(dim, dim, 1)
         self.register_buffer("positional_embedding", self.get_embedding(dim))
+        self.position_scale_query = nn.Parameter(torch.ones(1))
+        self.position_scale_key = nn.Parameter(torch.ones(1))
 
     def get_embedding(self, embedding_dim, num_embeddings=4096):
         half_dim = embedding_dim // 2
@@ -130,8 +132,13 @@ def forward(self, x, condition, diffusion_step, x_masks=None, cond_masks=None):
         x = x.transpose(1, 2)
         condition = condition.transpose(1, 2)
 
-        x = x + self.positional_embedding[: x.size(1)][None]
-        condition = condition + self.positional_embedding[: condition.size(1)][None]
+        # self.get_embedding(dim)
+        x = x + self.positional_embedding[: x.size(1)][None] * self.position_scale_query
+        condition = (
+            condition
+            + self.positional_embedding[: condition.size(1)][None]
+            * self.position_scale_key
+        )
 
         return (
             super()
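This is the "scaled pe" half of the commit: the fixed sinusoidal table is multiplied by separate learnable scalars for the diffusion input (query side) and the condition (key side), so the model can learn how strongly each stream relies on absolute position. A standalone sketch of the idea; the module and helper names are illustrative, not the repository's:

import math
import torch
from torch import nn

def sinusoidal_embedding(dim: int, num_embeddings: int = 4096) -> torch.Tensor:
    # Standard sinusoidal table of shape (num_embeddings, dim); dim assumed even.
    half_dim = dim // 2
    freqs = torch.exp(torch.arange(half_dim).float() * -(math.log(10000.0) / (half_dim - 1)))
    angles = torch.arange(num_embeddings).float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

class ScaledPositionalEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.register_buffer("table", sinusoidal_embedding(dim))
        self.scale_query = nn.Parameter(torch.ones(1))
        self.scale_key = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor, condition: torch.Tensor):
        # x, condition: (batch, time, dim); each stream gets its own learned scale.
        x = x + self.table[: x.size(1)][None] * self.scale_query
        condition = condition + self.table[: condition.size(1)][None] * self.scale_key
        return x, condition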
