diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
index 17f2a78e88..7d3d2a6531 100644
--- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
+++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -18,9 +18,6 @@ def __call__(self, num_inference_steps=2000, generator=None):
 
         model = self.model.to(device)
 
-        # TODO(Patrick) move to scheduler config
-        n_steps = 1
-
         x = torch.randn(*shape) * self.scheduler.config.sigma_max
         x = x.to(device)
 
@@ -30,7 +27,7 @@ def __call__(self, num_inference_steps=2000, generator=None):
         for i, t in enumerate(self.scheduler.timesteps):
             sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
 
-            for _ in range(n_steps):
+            for _ in range(self.scheduler.correct_steps):
                 with torch.no_grad():
                     result = self.model(x, sigma_t)
diff --git a/src/diffusers/pipelines/score_sde_vp/pipeline_score_sde_vp.py b/src/diffusers/pipelines/score_sde_vp/pipeline_score_sde_vp.py
index 29551d9a6e..b9cf0884ea 100644
--- a/src/diffusers/pipelines/score_sde_vp/pipeline_score_sde_vp.py
+++ b/src/diffusers/pipelines/score_sde_vp/pipeline_score_sde_vp.py
@@ -27,6 +27,7 @@ def __call__(self, num_inference_steps=1000, generator=None):
             t = t * torch.ones(shape[0], device=device)
             scaled_t = t * (num_inference_steps - 1)
 
+            # TODO add corrector
            with torch.no_grad():
                 result = model(x, scaled_t)
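Note: the change above moves the corrector-loop count out of the pipeline and into the scheduler config (`correct_steps`). A minimal sketch of the resulting predictor-corrector sampling loop, assuming `model` is any score network and `scheduler` is a `ScoreSdeVeScheduler`; the standalone `sample` function is illustrative, not part of this PR:

```python
import torch

def sample(model, scheduler, shape, device, num_inference_steps=2000):
    scheduler.set_timesteps(num_inference_steps)
    scheduler.set_sigmas(num_inference_steps)
    x = torch.randn(*shape, device=device) * scheduler.config.sigma_max

    for i, t in enumerate(scheduler.timesteps):
        sigma_t = scheduler.sigmas[i] * torch.ones(shape[0], device=device)

        # corrector: `correct_steps` Langevin-style updates at the current noise level
        for _ in range(scheduler.correct_steps):
            with torch.no_grad():
                score = model(x, sigma_t)
            x = scheduler.step_correct(score, x)

        # predictor: one reverse-SDE step toward the next (smaller) noise level
        with torch.no_grad():
            score = model(x, sigma_t)
        x, x_mean = scheduler.step_pred(score, x, t)

    return x_mean
```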
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index f59b75eea1..fbd1fbba13 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -51,7 +51,7 @@ def alpha_bar(time_step):
 class DDIMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
+        num_train_timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
@@ -62,7 +62,7 @@ def __init__(
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
+            num_train_timesteps=num_train_timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
@@ -72,13 +72,13 @@ def __init__(
         )
 
         if beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
-            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=np.float32) ** 2
+            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
@@ -88,10 +88,7 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, self.config.timesteps)[::-1].copy()
-
-        self.tensor_format = tensor_format
-        self.set_format(tensor_format=tensor_format)
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
 
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -131,7 +128,7 @@ def step(
         # - pred_prev_sample -> "x_t-1"
 
         # 1. get previous step value (=t-1)
-        prev_timestep = timestep - self.config.timesteps // self.num_inference_steps
+        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
 
         # 2. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -183,4 +180,4 @@ def add_noise(self, original_samples, noise, timesteps):
         return noisy_samples
 
     def __len__(self):
-        return self.config.timesteps
+        return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 27940a5693..d62cc608af 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -50,7 +50,7 @@ def alpha_bar(time_step):
 class DDPMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
+        num_train_timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
@@ -62,7 +62,7 @@ def __init__(
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
+            num_train_timesteps=num_train_timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
@@ -75,10 +75,10 @@ def __init__(
         if trained_betas is not None:
             self.betas = np.asarray(trained_betas)
         elif beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
@@ -160,4 +160,4 @@ def add_noise(self, original_samples, noise, timesteps):
         return noisy_samples
 
     def __len__(self):
-        return self.config.timesteps
+        return self.config.num_train_timesteps
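The rename itself is mechanical, but it is worth spelling out what the config value controls. A short standalone sketch, using only the default values from the diffs above, of the linear beta schedule and the timestep-stride arithmetic that `step` performs:

```python
import numpy as np

# Linear beta schedule over `num_train_timesteps` training steps, plus the
# cumulative alpha products that DDPM/DDIM steps index into.
num_train_timesteps = 1000
beta_start, beta_end = 0.0001, 0.02

betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)

# During inference with fewer steps, timesteps stride through the training range,
# which is why `step` computes `timestep - num_train_timesteps // num_inference_steps`.
num_inference_steps = 50
prev_timestep = 999 - num_train_timesteps // num_inference_steps  # -> 979
```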
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index 7df1c3bbb1..3b889d0ac2 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -50,7 +50,7 @@ def alpha_bar(time_step):
 class PNDMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
+        num_train_timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
@@ -58,17 +58,17 @@ def __init__(
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
+            num_train_timesteps=num_train_timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
         )
 
         if beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
@@ -96,10 +96,12 @@ def get_prk_time_steps(self, num_inference_steps):
         if num_inference_steps in self.prk_time_steps:
             return self.prk_time_steps[num_inference_steps]
 
-        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
+        inference_step_times = list(
+            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
+        )
 
         prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
-            np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
+            np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
         )
         self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
 
@@ -109,7 +111,9 @@ def get_time_steps(self, num_inference_steps):
         if num_inference_steps in self.time_steps:
             return self.time_steps[num_inference_steps]
 
-        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
+        inference_step_times = list(
+            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
+        )
 
         self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
         return self.time_steps[num_inference_steps]
@@ -135,6 +139,10 @@ def step_prk(
         sample: Union[torch.FloatTensor, np.ndarray],
         num_inference_steps,
     ):
+        """
+        Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate
+        the solution to the differential equation.
+        """
         t = timestep
         prk_time_steps = self.get_prk_time_steps(num_inference_steps)
 
@@ -165,6 +173,10 @@ def step_plms(
         sample: Union[torch.FloatTensor, np.ndarray],
         num_inference_steps,
     ):
+        """
+        Step function propagating the sample with the linear multi-step method. This requires only one forward pass
+        per step and reuses the model outputs stored from previous steps to approximate the solution.
+        """
         t = timestep
         if len(self.ets) < 3:
             raise ValueError(
@@ -221,4 +233,4 @@ def get_prev_sample(self, sample, t_orig, t_orig_prev, model_output):
         return prev_sample
 
     def __len__(self):
-        return self.config.timesteps
+        return self.config.num_train_timesteps
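To make the grid arithmetic above concrete, here is the construction evaluated standalone. Assumed values are the defaults with 50 inference steps; variable names mirror the method bodies:

```python
import numpy as np

num_train_timesteps, num_inference_steps, pndm_order = 1000, 50, 4

inference_step_times = list(range(0, num_train_timesteps, num_train_timesteps // num_inference_steps))
# -> [0, 20, 40, ..., 980]

# Runge-Kutta warm-up: interleave half-strides between the last `pndm_order` times
prk_time_steps = np.array(inference_step_times[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, num_train_timesteps // num_inference_steps // 2]), pndm_order
)
prk_time_steps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))

# The plain linear multi-step phase then walks the remaining times in reverse
plms_time_steps = list(reversed(inference_step_times[:-3]))
```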
+ """ + + def __init__( + self, + num_train_timesteps=2000, + snr=0.15, + sigma_min=0.01, + sigma_max=1348, + sampling_eps=1e-5, + correct_steps=1, + tensor_format="pt", + ): super().__init__() self.register_to_config( + num_train_timesteps=num_train_timesteps, snr=snr, sigma_min=sigma_min, sigma_max=sigma_max, sampling_eps=sampling_eps, + correct_steps=correct_steps, ) self.sigmas = None self.discrete_sigmas = None self.timesteps = None + # TODO - update step to be torch-independant + self.set_format(tensor_format=tensor_format) + def set_timesteps(self, num_inference_steps): - self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + self.timesteps = np.linspace(1, self.config.sampling_eps, num_inference_steps) + elif tensor_format == "pt": + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + else: + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") def set_sigmas(self, num_inference_steps): if self.timesteps is None: self.set_timesteps(num_inference_steps) - self.discrete_sigmas = torch.exp( - torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) - ) - self.sigmas = torch.tensor( - [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps] - ) - - def step_pred(self, result, x, t): + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + self.discrete_sigmas = np.exp( + np.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) + ) + self.sigmas = np.array( + [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps] + ) + elif tensor_format == "pt": + self.discrete_sigmas = torch.exp( + torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) + ) + self.sigmas = torch.tensor( + [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps] + ) + else: + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def get_adjacent_sigma(self, timesteps, t): + tensor_format = getattr(self, "tensor_format", "pt") + if tensor_format == "np": + return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1]) + elif tensor_format == "pt": + return torch.where( + timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device) + ) + + raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def step_pred(self, score, x, t): + """ + Predict the sample at the previous timestep by reversing the SDE. 
+ """ # TODO(Patrick) better comments + non-PyTorch - t = t * torch.ones(x.shape[0], device=x.device) - timestep = (t * (len(self.timesteps) - 1)).long() - - sigma = self.discrete_sigmas.to(t.device)[timestep] - adjacent_sigma = torch.where( - timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device) - ) - f = torch.zeros_like(x) - G = torch.sqrt(sigma**2 - adjacent_sigma**2) + t = self.repeat_scalar(t, x.shape[0]) + timesteps = self.long((t * (len(self.timesteps) - 1))) + + sigma = self.discrete_sigmas[timesteps] + adjacent_sigma = self.get_adjacent_sigma(timesteps, t) + drift = self.zeros_like(x) + diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 + + # equation 6 in the paper: the score modeled by the network is grad_x log pt(x) + # also equation 47 shows the analog from SDE models to ancestral sampling methods + drift = drift - diffusion[:, None, None, None] ** 2 * score + + # equation 6: sample noise for the diffusion term of + noise = self.randn_like(x) + x_mean = x - drift # subtract because `dt` is a small negative timestep + # TODO is the variable diffusion the correct scaling term for the noise? + x = x_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g + return x, x_mean - f = f - G[:, None, None, None] ** 2 * result + def step_correct(self, score, x): + """ + Correct the predicted sample based on the output score of the network. This is often run repeatedly after + making the prediction for the previous timestep. + """ + # TODO(Patrick) non-PyTorch - z = torch.randn_like(x) - x_mean = x - f - x = x_mean + G[:, None, None, None] * z - return x, x_mean + # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" + # sample noise for correction + noise = self.randn_like(x) - def step_correct(self, result, x): - # TODO(Patrick) better comments + non-PyTorch - noise = torch.randn_like(x) - grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean() - noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() + # compute step size from the score, the noise, and the snr + grad_norm = self.norm(score) + noise_norm = self.norm(noise) step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 - step_size = step_size * torch.ones(x.shape[0], device=x.device) - x_mean = x + step_size[:, None, None, None] * result + step_size = self.repeat_scalar(step_size, x.shape[0]) # * self.ones(x.shape[0], device=x.device) - x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise + # compute corrected sample: score term and noise term + x_mean = x + step_size[:, None, None, None] * score + x = x_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise return x + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index dda32a2742..08f1c2af0c 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -24,9 +24,10 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): - def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): + def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): super().__init__() self.register_to_config( + num_train_timesteps=num_train_timesteps, beta_min=beta_min, beta_max=beta_max, sampling_eps=sampling_eps, @@ -39,14 +40,14 @@ def __init__(self, beta_min=0.1, 
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index 7c5972434b..01040b2100 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -53,12 +53,22 @@ def log(self, tensor):
 
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
 
+    def long(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+
+        if tensor_format == "np":
+            return np.int64(tensor)
+        elif tensor_format == "pt":
+            return tensor.long()
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
     def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
         """
         Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
 
         Args:
-            timesteps: an array or tensor of values to extract.
+            values: an array or tensor of values to extract.
             broadcast_array: an array with a larger shape of K dimensions with the batch dimension equal to the
                 length of timesteps.
 
         Returns:
@@ -74,3 +84,39 @@ def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array:
             values = values.to(broadcast_array.device)
 
         return values
+
+    def norm(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.linalg.norm(tensor)
+        elif tensor_format == "pt":
+            return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def randn_like(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.random.randn(*np.shape(tensor))
+        elif tensor_format == "pt":
+            return torch.randn_like(tensor)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def repeat_scalar(self, tensor, count):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.repeat(tensor, count)
+        elif tensor_format == "pt":
+            return torch.repeat_interleave(tensor, count)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def zeros_like(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.zeros_like(tensor)
+        elif tensor_format == "pt":
+            return torch.zeros_like(tensor)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
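These helpers let one scheduler implementation serve both NumPy and PyTorch callers. A toy standalone version of the dispatch pattern (a free function for illustration, not the mixin itself):

```python
import numpy as np
import torch

def randn_like(tensor, tensor_format="pt"):
    # same call site works on NumPy arrays or torch tensors depending on the format flag
    if tensor_format == "np":
        return np.random.randn(*np.shape(tensor))
    return torch.randn_like(tensor)

x_np = np.zeros((2, 3))
x_pt = torch.zeros(2, 3)
assert randn_like(x_np, "np").shape == (2, 3)
assert randn_like(x_pt).shape == (2, 3)
```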
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index e473492e60..3cd98622fe 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -1087,11 +1087,16 @@ def test_score_sde_ve_pipeline(self):
         image = sde_ve(num_inference_steps=2)
 
         if model.device.type == "cpu":
-            expected_image_sum = 3384805632.0
-            expected_image_mean = 1076.000732421875
+            # patrick's cpu
+            expected_image_sum = 3384805888.0
+            expected_image_mean = 1076.00085
+
+            # m1 mbp
+            # expected_image_sum = 3384805376.0
+            # expected_image_mean = 1076.000610351562
         else:
             expected_image_sum = 3382849024.0
-            expected_image_mean = 1075.3787841796875
+            expected_image_mean = 1075.3788
 
         assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
         assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
@@ -1109,6 +1114,10 @@ def test_score_sde_vp_pipeline(self):
         expected_image_sum = 4183.2012
         expected_image_mean = 1.3617
 
+        # on m1 mbp
+        # expected_image_sum = 4318.6729
+        # expected_image_mean = 1.4058
+
         assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
         assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
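A note on the hard-coded references above: the tests fingerprint a generated image by its absolute sum and mean, with per-device expected values because different BLAS backends (x86 CPU, Apple M1, CUDA) round slightly differently. A sketch of the pattern, with the reference numbers computed on the fly purely for illustration:

```python
import torch

torch.manual_seed(0)
image = torch.randn(1, 3, 32, 32)  # stand-in for a pipeline output

expected_image_sum = image.abs().sum().item()    # in the real tests, recorded per device
expected_image_mean = image.abs().mean().item()

assert (image.abs().sum() - expected_image_sum).abs().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().item() < 1e-4
```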
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 0603fa5ddd..e14d196d63 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -12,15 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import tempfile
 import unittest
 
 import numpy as np
 import torch
 
-from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
+from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler, ScoreSdeVeScheduler
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -208,7 +207,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
     def get_scheduler_config(self, **kwargs):
         config = {
-            "timesteps": 1000,
+            "num_train_timesteps": 1000,
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
@@ -221,7 +220,7 @@ def get_scheduler_config(self, **kwargs):
 
     def test_timesteps(self):
         for timesteps in [1, 5, 100, 1000]:
-            self.check_over_configs(timesteps=timesteps)
+            self.check_over_configs(num_train_timesteps=timesteps)
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -288,7 +287,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
     def get_scheduler_config(self, **kwargs):
         config = {
-            "timesteps": 1000,
+            "num_train_timesteps": 1000,
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
@@ -300,7 +299,7 @@ def get_scheduler_config(self, **kwargs):
 
     def test_timesteps(self):
         for timesteps in [100, 500, 1000]:
-            self.check_over_configs(timesteps=timesteps)
+            self.check_over_configs(num_train_timesteps=timesteps)
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -367,7 +366,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
     def get_scheduler_config(self, **kwargs):
         config = {
-            "timesteps": 1000,
+            "num_train_timesteps": 1000,
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
@@ -431,11 +430,11 @@ def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
 
     def test_timesteps(self):
         for timesteps in [100, 1000]:
-            self.check_over_configs(timesteps=timesteps)
+            self.check_over_configs(num_train_timesteps=timesteps)
 
     def test_timesteps_pmls(self):
         for timesteps in [100, 1000]:
-            self.check_over_configs_pmls(timesteps=timesteps)
+            self.check_over_configs_pmls(num_train_timesteps=timesteps)
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
@@ -507,3 +506,131 @@ def test_full_loop_no_noise(self):
 
     assert abs(result_sum.item() - 199.1169) < 1e-2
     assert abs(result_mean.item() - 0.2593) < 1e-3
+
+
+class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (ScoreSdeVeScheduler,)
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 2000,
+            "snr": 0.15,
+            "sigma_min": 0.01,
+            "sigma_max": 1348,
+            "sampling_eps": 1e-5,
+            "tensor_format": "np",  # TODO add test for tensor formats
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def check_over_configs(self, time_step=0, **config):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = 10
+
+        for scheduler_class in self.scheduler_classes:
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_config(tmpdirname)
+
+            # step_pred reads self.timesteps and self.discrete_sigmas, so initialize both schedulers
+            scheduler.set_sigmas(num_inference_steps)
+            new_scheduler.set_sigmas(num_inference_steps)
+
+            # step_pred and step_correct draw random noise internally, so seed before each call
+            np.random.seed(0)
+            output, _ = scheduler.step_pred(residual, sample, time_step, **kwargs)
+            np.random.seed(0)
+            new_output, _ = new_scheduler.step_pred(residual, sample, time_step, **kwargs)
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+            np.random.seed(0)
+            output = scheduler.step_correct(residual, sample, **kwargs)
+            np.random.seed(0)
+            new_output = new_scheduler.step_correct(residual, sample, **kwargs)
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler corrections are not identical"
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
+        kwargs = dict(self.forward_default_kwargs)
+        kwargs.update(forward_kwargs)
+        num_inference_steps = 10
+
+        for scheduler_class in self.scheduler_classes:
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_config(tmpdirname)
+
+            scheduler.set_sigmas(num_inference_steps)
+            new_scheduler.set_sigmas(num_inference_steps)
+
+            np.random.seed(0)
+            output, _ = scheduler.step_pred(residual, sample, time_step, **kwargs)
+            np.random.seed(0)
+            new_output, _ = new_scheduler.step_pred(residual, sample, time_step, **kwargs)
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+            np.random.seed(0)
+            output = scheduler.step_correct(residual, sample, **kwargs)
+            np.random.seed(0)
+            new_output = new_scheduler.step_correct(residual, sample, **kwargs)
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler corrections are not identical"
+
+    def test_timesteps(self):
+        for timesteps in [10, 100, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_sigmas(self):
+        for sigma_min, sigma_max in zip([0.0001, 0.001, 0.01], [1, 100, 1000]):
+            self.check_over_configs(sigma_min=sigma_min, sigma_max=sigma_max)
+
+    def test_time_indices(self):
+        # continuous VE-SDE times live in (sampling_eps, 1], so test fractional values
+        for t in [0.1, 0.5, 1.0]:
+            self.check_over_forward(time_step=t)
+
+    def test_full_loop_no_noise(self):
+        np.random.seed(0)
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 3
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+
+        scheduler.set_sigmas(num_inference_steps)
+
+        for i, t in enumerate(scheduler.timesteps):
+            sigma_t = scheduler.sigmas[i]
+
+            for _ in range(scheduler.correct_steps):
+                with torch.no_grad():
+                    result = model(sample, sigma_t)
+                sample = scheduler.step_correct(result, sample)
+
+            with torch.no_grad():
+                result = model(sample, sigma_t)
+
+            sample, sample_mean = scheduler.step_pred(result, sample, t)
+
+        result_sum = np.sum(np.abs(sample))
+        result_mean = np.mean(np.abs(sample))
+
+        assert abs(result_sum.item() - 10629923278.7104) < 1e-2
+        assert abs(result_mean.item() - 13841045.9358) < 1e-3
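For completeness, the loop exercised by `test_full_loop_no_noise` can also be run standalone. A sketch in "np" mode with a dummy constant score in place of a trained model; the `0.1 * ones` score is arbitrary, chosen only to keep the norm in `step_correct` nonzero:

```python
import numpy as np

from diffusers import ScoreSdeVeScheduler

scheduler = ScoreSdeVeScheduler(num_train_timesteps=2000, tensor_format="np")
num_inference_steps = 3
scheduler.set_sigmas(num_inference_steps)

np.random.seed(0)
sample = np.random.randn(1, 3, 32, 32) * scheduler.config.sigma_max

for i, t in enumerate(scheduler.timesteps):
    sigma_t = scheduler.sigmas[i]

    # corrector passes at the current noise level
    for _ in range(scheduler.correct_steps):
        score = 0.1 * np.ones_like(sample)  # dummy score; a real model would predict this
        sample = scheduler.step_correct(score, sample)

    # predictor step toward the next noise level
    score = 0.1 * np.ones_like(sample)
    sample, sample_mean = scheduler.step_pred(score, sample, t)
```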