
Commit

change noise scheduling for wavegrad. Compute beta values externally to enable better flexibility
erogol committed Nov 14, 2020
1 parent 5a59467 commit c657124
Showing 3 changed files with 12 additions and 11 deletions.
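In short, compute_noise_level() no longer builds the beta schedule from (num_steps, min_val, max_val, base_vals); callers now compute the beta array themselves and pass it in. A minimal sketch of the new call pattern, assuming an already-loaded Wavegrad model instance named model (the variable name and schedule values are illustrative, not part of the commit):

import numpy as np

# Build the beta (noise variance) schedule externally, e.g. a simple linear ramp.
num_steps = 50                              # illustrative
beta = np.linspace(1e-6, 1e-2, num_steps)

# Pass the precomputed schedule instead of the old keyword arguments.
model.compute_noise_level(beta)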
10 changes: 6 additions & 4 deletions TTS/bin/tune_wavegrad.py
@@ -59,11 +59,13 @@

 # setup optimization parameters
 base_values = sorted(np.random.uniform(high=10, size=args.search_depth))
+exponents = 10 ** np.linspace(-6, -2, num=args.num_iter)
 best_error = float('inf')
 best_schedule = None
 total_search_iter = len(base_values)**args.num_iter
 for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
-    model.compute_noise_level(num_steps=args.num_iter, min_val=1e-6, max_val=1e-1, base_vals=base)
+    beta = exponents * base
+    model.compute_noise_level(beta)
     for data in loader:
         mel, audio = data
         y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
@@ -78,11 +80,11 @@
             mel_hat.append(torch.from_numpy(m))
 
         mel_hat = torch.stack(mel_hat)
-        mse = torch.sum((mel - mel_hat) ** 2)
+        mse = torch.sum((mel - mel_hat) ** 2).mean()
         if mse.item() < best_error:
             best_error = mse.item()
-            best_schedule = {'num_steps': args.num_iter, 'min_val':1e-6, 'max_val':1e-1, 'base_vals':base}
-            print(" > Found a better schedule.")
+            best_schedule = {'beta': beta}
+            print(f" > Found a better schedule. - MSE: {mse.item()}")
             np.save(args.output_path, best_schedule)
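To make the search loop above concrete, each candidate schedule is just the fixed exponent ramp scaled element-wise by one tuple of base values; a small worked example with illustrative numbers (args.num_iter = 5 and an arbitrary base tuple, neither taken from an actual run):

import numpy as np

num_iter = 5                                # illustrative; normally args.num_iter
exponents = 10 ** np.linspace(-6, -2, num=num_iter)
# exponents == [1e-06, 1e-05, 1e-04, 1e-03, 1e-02]

base = np.array([2.0, 1.5, 3.0, 0.5, 4.0])  # one point of the cartesian product
beta = exponents * base
# beta == [2e-06, 1.5e-05, 3e-04, 5e-04, 4e-02]  -> one candidate noise schedule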


2 changes: 2 additions & 0 deletions TTS/vocoder/configs/wavegrad_libritts.json
@@ -72,6 +72,8 @@

 // TRAINING
 "batch_size": 96, // Batch size for training.
+
+// NOISE SCHEDULE PARAMS - Only effective at training time.
 "train_noise_schedule":{
     "min_val": 1e-6,
     "max_val": 1e-2,
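Since the model no longer expands these values itself, the training code presumably turns this config block into a beta array before calling compute_noise_level. A sketch of that mapping, mirroring the linspace logic removed from wavegrad.py below (the helper name and the num_steps value are hypothetical, not from this commit):

import numpy as np

def schedule_from_config(cfg):
    # Linear ramp from min_val to max_val over num_steps diffusion steps.
    return np.linspace(cfg["min_val"], cfg["max_val"], cfg["num_steps"])

train_noise_schedule = {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}  # num_steps illustrative
beta = schedule_from_config(train_noise_schedule)
model.compute_noise_level(beta)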
11 changes: 4 additions & 7 deletions TTS/vocoder/models/wavegrad.py
@@ -79,8 +79,8 @@ def forward(self, x, spectrogram, noise_scale):
         return x
 
     def load_noise_schedule(self, path):
-        sched = np.load(path, allow_pickle=True).item()
-        self.compute_noise_level(**sched)
+        beta = np.load(path, allow_pickle=True).item()['beta']
+        self.compute_noise_level(beta)
 
     @torch.no_grad()
     def inference(self, x, y_n=None):
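For completeness, a usage sketch showing how a schedule tuned and saved by tune_wavegrad.py would be consumed through this method (the file name and the mel variable are illustrative):

# The .npy file holds a dict like {'beta': <np.ndarray>}, per the tuner change above.
model.load_noise_schedule("best_noise_schedule.npy")  # illustrative path
y_hat = model.inference(mel)                          # mel: a mel-spectrogram tensor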
@@ -113,16 +113,13 @@ def compute_y_n(self, y_0):
         noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise
         return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
 
-    def compute_noise_level(self, num_steps, min_val, max_val, base_vals=None):
+    def compute_noise_level(self, beta):
         """Compute noise schedule parameters"""
-        beta = np.linspace(min_val, max_val, num_steps)
-        if base_vals is not None:
-            beta *= base_vals
+        self.num_steps = len(beta)
         alpha = 1 - beta
         alpha_hat = np.cumprod(alpha)
         noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
 
-        self.num_steps = num_steps
         # pylint: disable=not-callable
         self.beta = torch.tensor(beta.astype(np.float32))
         self.alpha = torch.tensor(alpha.astype(np.float32))
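For reference, the schedule bookkeeping in compute_noise_level follows the usual diffusion relations: alpha = 1 - beta, alpha_hat is the cumulative product of the alphas, and the per-step noise level is sqrt(alpha_hat) with a leading 1.0 for the clean signal. A tiny numeric check, independent of the model class:

import numpy as np

beta = np.array([0.1, 0.2, 0.3])           # toy 3-step schedule
alpha = 1 - beta                           # [0.9, 0.8, 0.7]
alpha_hat = np.cumprod(alpha)              # [0.9, 0.72, 0.504]
noise_level = np.concatenate([[1.0], alpha_hat ** 0.5])
# noise_level ~= [1.0, 0.9487, 0.8485, 0.7099]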
