-
Notifications
You must be signed in to change notification settings - Fork 0
/
forward_process.py
95 lines (55 loc) · 2 KB
/
forward_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import matplotlib.pyplot as plt
import random
import matplotlib.pyplot as plt
import numpy
import torch
from torch.distributions.uniform import Uniform
import torch
import torch.distributions as dist
import torch
from torch.distributions.uniform import Uniform
import torch
import torch
import torch.distributions as dist
import torch
import torch.distributions as dist
import torch
import torch.distributions as dist
import torch
from torch.distributions.uniform import Uniform
import torch
import torch.distributions as dist
import torch
import torch
from torch.distributions.uniform import Uniform
def forward_diffusion_sample_chebyshev(x_0, t, constant_dict, config):
    """
    Noise a batch of inputs to timestep ``t`` using Chebyshev-transformed noise.

    Args:
        x_0: Clean input batch. The per-timestep coefficients are reshaped to
            ``(len(t), 1, ..., 1)``, so x_0's leading dimension is presumably
            the batch dimension matching ``t`` — confirm against callers.
        t: 1-D integer tensor of timesteps, one per sample.
        constant_dict: Mapping holding the precomputed schedules under the keys
            ``'sqrt_alphas_cumprod'`` and ``'sqrt_one_minus_alphas_cumprod'``.
        config: Config object; ``config.model.device`` selects the device of
            both returned tensors.

    Returns:
        Tuple ``(x, noise)``: the noised batch and the raw U(-1, 1) noise
        sample, both on ``config.model.device``.
    """
    sqrt_alphas_cumprod = constant_dict['sqrt_alphas_cumprod']
    sqrt_one_minus_alphas_cumprod = constant_dict['sqrt_one_minus_alphas_cumprod']
    device = config.model.device

    # Sample U(-1, 1) noise with the same shape as x_0 (drawn on the CPU).
    noise = Uniform(-1, 1).sample(x_0.shape)
    # First-order Chebyshev polynomial: T_1(u) = cos(arccos(u)) = u, so this is
    # numerically the identity on [-1, 1]; it is where a higher-order
    # polynomial would be substituted.
    chebyshev_noise = torch.cos(torch.acos(noise))

    # get_index_from_list takes (vals, t, x_shape); passing `config` as a
    # fourth argument raised a TypeError on every call.
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )

    x_0 = x_0.to(device)
    # NOTE(review): the noise term is scaled by x_0 (multiplicative noise),
    # unlike the additive DDPM form sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps.
    # Kept as written — confirm this is intended.
    x = (
        sqrt_alphas_cumprod_t.to(device) * x_0
        + sqrt_one_minus_alphas_cumprod_t.to(device) * (x_0 * chebyshev_noise.to(device))
    )
    # The raw uniform sample (not chebyshev_noise) is returned as the target.
    return x, noise.to(device)
def get_index_from_list(vals, t, x_shape):
    """
    Pick ``vals[t[i]]`` for each sample and reshape it for broadcasting.

    Args:
        vals: 1-D tensor of per-timestep values (e.g. a noise schedule).
        t: 1-D integer tensor of timesteps, one entry per batch sample.
        x_shape: Shape of the tensor the result will be broadcast against;
            only its length (number of dimensions) is used.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape) - 1``
        trailing singleton dimensions, placed on ``t``'s device.
    """
    batch_size = t.shape[0]
    # gather() needs an int64 index on the same device as `vals`. The previous
    # `t.cpu()` broke when `vals` lived on a GPU or when `t` was not int64.
    index = t.to(device=vals.device, dtype=torch.long)
    out = vals.gather(-1, index)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)