# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean up a bit

import numpy as np
import torch
...


class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
-    def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"):
+    """
+    The variance exploding stochastic differential equation (SDE) scheduler.
+
+    :param num_train_timesteps: number of diffusion steps used to train the model.
+    :param snr: coefficient weighting the step from the score sample (from the network) to the random noise.
+    :param sigma_min: initial noise scale for the sigma sequence in the sampling procedure. The minimum sigma should
+        mirror the distribution of the data.
+    :param sigma_max: final (largest) noise scale for the sigma sequence.
+    :param sampling_eps: the end value of sampling, where timesteps decrease progressively from 1 to epsilon.
+    :param correct_steps: number of correction steps performed on a produced sample.
+    :param tensor_format: "np" or "pt" for the expected format of samples passed to the Scheduler.
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps=2000,
+        snr=0.15,
+        sigma_min=0.01,
+        sigma_max=1348,
+        sampling_eps=1e-5,
+        correct_steps=1,
+        tensor_format="pt",
+    ):
        super().__init__()
        self.register_to_config(
+            num_train_timesteps=num_train_timesteps,
            snr=snr,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            sampling_eps=sampling_eps,
+            correct_steps=correct_steps,
        )

        self.sigmas = None
        self.discrete_sigmas = None
        self.timesteps = None

+        # TODO - update step to be torch-independent
+        self.set_format(tensor_format=tensor_format)
+
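For orientation, here is a minimal sketch of how this scheduler would be constructed and prepared for sampling; the import path is an assumption based on the surrounding library, not something this diff shows:

    # hypothetical usage sketch; the import path is assumed
    from diffusers import ScoreSdeVeScheduler

    scheduler = ScoreSdeVeScheduler(snr=0.15, sigma_min=0.01, sigma_max=1348, tensor_format="pt")
    scheduler.set_sigmas(num_inference_steps=1000)  # calls set_timesteps internally if needed
    print(scheduler.sigmas.shape)  # torch.Size([1000])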
    def set_timesteps(self, num_inference_steps):
-        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            self.timesteps = np.linspace(1, self.config.sampling_eps, num_inference_steps)
+        elif tensor_format == "pt":
+            self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+        else:
+            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
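Concretely, `timesteps` is just a uniform grid running from 1 down to `sampling_eps`; a quick worked example:

    scheduler.set_timesteps(num_inference_steps=3)
    # scheduler.timesteps => approximately [1.0, 0.5, 1e-05]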
    def set_sigmas(self, num_inference_steps):
        if self.timesteps is None:
            self.set_timesteps(num_inference_steps)

-        self.discrete_sigmas = torch.exp(
-            torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
-        )
-        self.sigmas = torch.tensor(
-            [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
-        )
-
-    def step_pred(self, result, x, t):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            self.discrete_sigmas = np.exp(
+                np.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
+            )
+            self.sigmas = np.array(
+                [self.config.sigma_min * (self.config.sigma_max / self.config.sigma_min) ** t for t in self.timesteps]
+            )
+        elif tensor_format == "pt":
+            self.discrete_sigmas = torch.exp(
+                torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
+            )
+            self.sigmas = torch.tensor(
+                [self.config.sigma_min * (self.config.sigma_max / self.config.sigma_min) ** t for t in self.timesteps]
+            )
+        else:
+            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
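Both branches encode the same geometric noise schedule: `discrete_sigmas` is log-uniform between `sigma_min` and `sigma_max`, while the continuous-time noise scale is

    sigma(t) = sigma_min * (sigma_max / sigma_min) ** t

so that sigma(0) = sigma_min and sigma(1) = sigma_max. Since `timesteps` decreases from 1 toward `sampling_eps`, `sigmas` starts near `sigma_max` and anneals down toward `sigma_min`.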
+    def get_adjacent_sigma(self, timesteps, t):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
+        elif tensor_format == "pt":
+            return torch.where(
+                timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device)
+            )
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
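Put differently, for a timestep index i this helper returns the previous (smaller) noise scale, with zero at the lower boundary:

    adjacent_sigma_i = discrete_sigmas[i - 1] if i > 0 else 0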
+    def step_pred(self, score, x, t):
+        """
+        Predict the sample at the previous timestep by reversing the SDE.
+        """
        # TODO(Patrick) better comments + non-PyTorch
-        t = t * torch.ones(x.shape[0], device=x.device)
-        timestep = (t * (len(self.timesteps) - 1)).long()
-
-        sigma = self.discrete_sigmas.to(t.device)[timestep]
-        adjacent_sigma = torch.where(
-            timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device)
-        )
-        f = torch.zeros_like(x)
-        G = torch.sqrt(sigma**2 - adjacent_sigma**2)
+        t = self.repeat_scalar(t, x.shape[0])
+        timesteps = self.long(t * (len(self.timesteps) - 1))
+
+        sigma = self.discrete_sigmas[timesteps]
+        adjacent_sigma = self.get_adjacent_sigma(timesteps, t)
+        drift = self.zeros_like(x)
+        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
+
+        # equation 6 in the paper: the score modeled by the network is grad_x log pt(x)
+        # also equation 47 shows the analog from SDE models to ancestral sampling methods
+        drift = drift - diffusion[:, None, None, None] ** 2 * score
+
+        # equation 6: sample noise for the diffusion term
+        noise = self.randn_like(x)
+        x_mean = x - drift  # subtract because `dt` is a small negative timestep
+        # TODO is the variable diffusion the correct scaling term for the noise?
+        x = x_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
+        return x, x_mean

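Reading the predictor back into math (a paraphrase of the code above, which matches the reverse-diffusion predictor from the score-SDE paper): with z ~ N(0, I),

    x_mean = x + (sigma_i**2 - sigma_{i-1}**2) * score
    x      = x_mean + sqrt(sigma_i**2 - sigma_{i-1}**2) * z

so the drift moves the sample along the score, and a scaled Gaussian term re-injects noise for the next step.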
-        f = f - G[:, None, None, None] ** 2 * result
+    def step_correct(self, score, x):
+        """
+        Correct the predicted sample based on the output score of the network. This is often run repeatedly after
+        making the prediction for the previous timestep.
+        """
+        # TODO(Patrick) non-PyTorch

-        z = torch.randn_like(x)
-        x_mean = x - f
-        x = x_mean + G[:, None, None, None] * z
-        return x, x_mean
+        # For small batch sizes, the paper suggests replacing norm(z) with sqrt(d), where d is the dimension of z
+        # sample noise for correction
+        noise = self.randn_like(x)

-    def step_correct(self, result, x):
-        # TODO(Patrick) better comments + non-PyTorch
-        noise = torch.randn_like(x)
-        grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
-        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
+        # compute step size from the score, the noise, and the snr
+        grad_norm = self.norm(score)
+        noise_norm = self.norm(noise)
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
-        step_size = step_size * torch.ones(x.shape[0], device=x.device)
-        x_mean = x + step_size[:, None, None, None] * result
+        step_size = self.repeat_scalar(step_size, x.shape[0])

-        x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
+        # compute corrected sample: score term and noise term
+        x_mean = x + step_size[:, None, None, None] * score
+        x = x_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise

        return x
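This corrector is one step of annealed Langevin dynamics. With eps the step size computed above from the target signal-to-noise ratio snr, the update is (again a paraphrase of the code):

    eps    = 2 * (snr * ||z|| / ||score||)**2
    x_mean = x + eps * score(x, sigma)
    x      = x_mean + sqrt(2 * eps) * z,    z ~ N(0, I)

which is the familiar Langevin step x + eps * grad_x log p(x) + sqrt(2 * eps) * z.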
+
+    def __len__(self):
+        return self.config.num_train_timesteps
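Putting the pieces together, a predictor-corrector sampling loop over this scheduler could look like the sketch below. The score-network call signature `model(sample, sigma)` and the image shape are assumptions for illustration; neither is fixed by this diff:

    import torch

    # `model` is a hypothetical score network mapping (sample, sigma) -> score
    scheduler = ScoreSdeVeScheduler(tensor_format="pt")
    num_inference_steps = 1000
    scheduler.set_sigmas(num_inference_steps)  # also fills scheduler.timesteps

    # start from pure noise at the largest noise scale (the VE-SDE prior)
    sample = torch.randn(1, 3, 64, 64) * scheduler.config.sigma_max

    for i, t in enumerate(scheduler.timesteps):
        sigma_t = scheduler.sigmas[i]
        # corrector: `correct_steps` rounds of Langevin correction at the current noise level
        for _ in range(scheduler.config.correct_steps):
            score = model(sample, sigma_t)
            sample = scheduler.step_correct(score, sample)
        # predictor: one reverse-SDE step toward the previous (smaller) noise level
        score = model(sample, sigma_t)
        sample, sample_mean = scheduler.step_pred(score, sample, t)

    # `sample_mean`, the noise-free mean of the final step, is typically taken as the output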