From c65712426aa8fe18a6b40b71a9a6b2c9681e9ec2 Mon Sep 17 00:00:00 2001 From: erogol Date: Sat, 14 Nov 2020 13:01:10 +0100 Subject: [PATCH] change noise scheduling for wavegrad. Compute beta values externally to enable better flexibility --- TTS/bin/tune_wavegrad.py | 10 ++++++---- TTS/vocoder/configs/wavegrad_libritts.json | 2 ++ TTS/vocoder/models/wavegrad.py | 11 ++++------- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py index 375e1f1c7..fde521c56 100644 --- a/TTS/bin/tune_wavegrad.py +++ b/TTS/bin/tune_wavegrad.py @@ -59,11 +59,13 @@ # setup optimization parameters base_values = sorted(np.random.uniform(high=10, size=args.search_depth)) +exponents = 10 ** np.linspace(-6, -2, num=args.num_iter) best_error = float('inf') best_schedule = None total_search_iter = len(base_values)**args.num_iter for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter): - model.compute_noise_level(num_steps=args.num_iter, min_val=1e-6, max_val=1e-1, base_vals=base) + beta = exponents * base + model.compute_noise_level(beta) for data in loader: mel, audio = data y_hat = model.inference(mel.cuda() if args.use_cuda else mel) @@ -78,11 +80,11 @@ mel_hat.append(torch.from_numpy(m)) mel_hat = torch.stack(mel_hat) - mse = torch.sum((mel - mel_hat) ** 2) + mse = torch.sum((mel - mel_hat) ** 2).mean() if mse.item() < best_error: best_error = mse.item() - best_schedule = {'num_steps': args.num_iter, 'min_val':1e-6, 'max_val':1e-1, 'base_vals':base} - print(" > Found a better schedule.") + best_schedule = {'beta': beta} + print(f" > Found a better schedule. - MSE: {mse.item()}") np.save(args.output_path, best_schedule) diff --git a/TTS/vocoder/configs/wavegrad_libritts.json b/TTS/vocoder/configs/wavegrad_libritts.json index 57c267091..a271ce33b 100644 --- a/TTS/vocoder/configs/wavegrad_libritts.json +++ b/TTS/vocoder/configs/wavegrad_libritts.json @@ -72,6 +72,8 @@ // TRAINING "batch_size": 96, // Batch size for training. + + // NOISE SCHEDULE PARAMS - Only effective at training time. "train_noise_schedule":{ "min_val": 1e-6, "max_val": 1e-2, diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index f60873951..f9bcdb85e 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -79,8 +79,8 @@ def forward(self, x, spectrogram, noise_scale): return x def load_noise_schedule(self, path): - sched = np.load(path, allow_pickle=True).item() - self.compute_noise_level(**sched) + beta = np.load(path, allow_pickle=True).item()['beta'] + self.compute_noise_level(beta) @torch.no_grad() def inference(self, x, y_n=None): @@ -113,16 +113,13 @@ def compute_y_n(self, y_0): noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0] - def compute_noise_level(self, num_steps, min_val, max_val, base_vals=None): + def compute_noise_level(self, beta): """Compute noise schedule parameters""" - beta = np.linspace(min_val, max_val, num_steps) - if base_vals is not None: - beta *= base_vals + self.num_steps = len(beta) alpha = 1 - beta alpha_hat = np.cumprod(alpha) noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0) - self.num_steps = num_steps # pylint: disable=not-callable self.beta = torch.tensor(beta.astype(np.float32)) self.alpha = torch.tensor(alpha.astype(np.float32))