diff --git a/monai/networks/schedulers/scheduler.py b/monai/networks/schedulers/scheduler.py index acdccc60de..71f4f082c0 100644 --- a/monai/networks/schedulers/scheduler.py +++ b/monai/networks/schedulers/scheduler.py @@ -105,9 +105,11 @@ def _cosine_beta(num_train_timesteps: int, s: float = 8e-3): x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1) alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 alphas_cumprod /= alphas_cumprod[0].item() - alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999) - betas = 1.0 - alphas - return betas, alphas, alphas_cumprod[:-1] + betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + betas = torch.clip(betas, 0.0, 0.999) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + return betas, alphas, alphas_cumprod class Scheduler(nn.Module):