-
Notifications
You must be signed in to change notification settings - Fork 59
/
lion.py
91 lines (81 loc) · 4.33 KB
/
lion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
from models.vae_adain import Model as VAE
from models.latent_points_ada_localprior import PVCNN2Prior as LocalPrior
from utils.diffusion_pvd import DiffusionDiscretized
from utils.vis_helper import plot_points
from utils.model_helper import import_model
from diffusers import DDPMScheduler
import torch
from matplotlib import pyplot as plt
class LION(object):
def __init__(self, cfg):
self.vae = VAE(cfg).cuda()
GlobalPrior = import_model(cfg.latent_pts.style_prior)
global_prior = GlobalPrior(cfg.sde, cfg.latent_pts.style_dim, cfg).cuda()
local_prior = LocalPrior(cfg.sde, cfg.shapelatent.latent_dim, cfg).cuda()
self.priors = torch.nn.ModuleList([global_prior, local_prior])
self.scheduler = DDPMScheduler(clip_sample=False,
beta_start=cfg.ddpm.beta_1, beta_end=cfg.ddpm.beta_T, beta_schedule=cfg.ddpm.sched_mode,
num_train_timesteps=cfg.ddpm.num_steps, variance_type=cfg.ddpm.model_var_type)
self.diffusion = DiffusionDiscretized(None, None, cfg)
# self.load_model(cfg)
def load_model(self, model_path):
# model_path = cfg.ckpt.path
ckpt = torch.load(model_path)
self.priors.load_state_dict(ckpt['dae_state_dict'])
self.vae.load_state_dict(ckpt['vae_state_dict'])
print(f'INFO finish loading from {model_path}')
@torch.no_grad()
def sample(self, num_samples=10, clip_feat=None, save_img=False):
self.scheduler.set_timesteps(1000, device='cuda')
timesteps = self.scheduler.timesteps
latent_shape = self.vae.latent_shape()
global_prior, local_prior = self.priors[0], self.priors[1]
assert(not local_prior.mixed_prediction and not global_prior.mixed_prediction)
sampled_list = []
output_dict = {}
# start sample global prior
x_T_shape = [num_samples] + latent_shape[0]
x_noisy = torch.randn(size=x_T_shape, device='cuda')
condition_input = None
for i, t in enumerate(timesteps):
t_tensor = torch.ones(num_samples, dtype=torch.int64, device='cuda') * (t+1)
noise_pred = global_prior(x=x_noisy, t=t_tensor.float(),
condition_input=condition_input, clip_feat=clip_feat)
x_noisy = self.scheduler.step(noise_pred, t, x_noisy).prev_sample
sampled_list.append(x_noisy)
output_dict['z_global'] = x_noisy
condition_input = x_noisy
condition_input = self.vae.global2style(condition_input)
# start sample local prior
x_T_shape = [num_samples] + latent_shape[1]
x_noisy = torch.randn(size=x_T_shape, device='cuda')
for i, t in enumerate(timesteps):
t_tensor = torch.ones(num_samples, dtype=torch.int64, device='cuda') * (t+1)
noise_pred = local_prior(x=x_noisy, t=t_tensor.float(),
condition_input=condition_input, clip_feat=clip_feat)
x_noisy = self.scheduler.step(noise_pred, t, x_noisy).prev_sample
sampled_list.append(x_noisy)
output_dict['z_local'] = x_noisy
# decode the latent
output = self.vae.sample(num_samples=num_samples, decomposed_eps=sampled_list)
if save_img:
out_name = plot_points(output, "/tmp/tmp.png")
print(f'INFO save plot image at {out_name}')
output_dict['points'] = output
return output_dict
def get_mixing_component(self, noise_pred, t):
# usage:
# if global_prior.mixed_prediction:
# mixing_component = self.get_mixing_component(noise_pred, t)
# coeff = torch.sigmoid(global_prior.mixing_logit)
# noise_pred = (1 - coeff) * mixing_component + coeff * noise_pred
alpha_bar = self.scheduler.alphas_cumprod[t]
one_minus_alpha_bars_sqrt = np.sqrt(1.0 - alpha_bar)
return noise_pred * one_minus_alpha_bars_sqrt