-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvae_generations.py
126 lines (105 loc) · 5.2 KB
/
vae_generations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Given a trained TreeVAE model, this script generates reconstructions and samples for each leaf in the learned tree.
"""
import os
import yaml
import torch
import argparse
import numpy as np
import matplotlib.pyplot as plt
from utils.data_utils import get_data, get_gen
from utils.data_utils import get_data, get_gen
from utils.model_utils import construct_tree_fromnpy
from utils.utils import display_image
from models.model import TreeVAE
def _parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: Optional list of argument strings; ``None`` parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config_name``, ``seed``, ``mode`` and ``model_name``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_name', type=str,
                        choices=['mnist', 'fmnist', 'news20', 'omniglot', 'cifar10', 'cifar100', 'celeba', 'cubicc'],
                        help='the override file name for config.yml', default='cifar10')
    parser.add_argument('--seed', type=int, help='random seed', default=42)
    # Restrict --mode to the two supported values and make it mandatory.
    # Previously an unknown or missing mode loaded the whole model and then silently did nothing.
    parser.add_argument('--mode', type=str, choices=['vae_recons', 'vae_samples'], required=True,
                        help='evaluation mode: vae_recons or vae_samples')
    # Required: the checkpoint path is built by string concatenation with this value,
    # so omitting it used to fail later with an opaque TypeError.
    parser.add_argument('--model_name', type=str, required=True,
                        help='path to the pretrained TreeVAE model')
    return parser.parse_args(argv)


def _load_model(checkpoint_path, dataset):
    """Load the config, the test data loader and the trained TreeVAE from ``checkpoint_path``.

    Args:
        checkpoint_path: Directory containing ``config.yaml``, ``data_tree.npy`` and ``model_weights.pt``.
        dataset: Config name; ``"cubicc"`` forces a training batch size of 32.

    Returns:
        ``(model, gen_test, device)``: the model moved to ``device`` and switched to eval mode,
        plus the non-shuffled test loader built by ``get_gen``.
    """
    with open(checkpoint_path + "/config.yaml", 'r') as stream:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects; only load trusted checkpoints.
        configs = yaml.load(stream, Loader=yaml.Loader)
    print(configs)
    if dataset == "cubicc":
        configs['training']['batch_size'] = 32
    _, _, testset = get_data(configs)
    gen_test = get_gen(testset, configs, validation=True, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = TreeVAE(**configs['training'])
    # NOTE(review): allow_pickle=True unpickles the file contents; only load trusted checkpoints.
    data_tree = np.load(checkpoint_path + '/data_tree.npy', allow_pickle=True)
    # Rebuild the learned tree structure before loading weights, so that the state-dict keys match (strict=True).
    model = construct_tree_fromnpy(model, data_tree, configs)
    model.load_state_dict(torch.load(checkpoint_path + '/model_weights.pt', map_location=device), strict=True)
    model.to(device)
    model.eval()
    return model, gen_test, device


def _make_leaf_dir(img_save_path, leaf):
    """Create (if needed) and return the output directory ``<img_save_path>/img_cluster_<leaf>``."""
    leaf_dir = os.path.join(img_save_path, f"img_cluster_{leaf}")
    os.makedirs(leaf_dir, exist_ok=True)
    return leaf_dir


def _save_leaf_image(image, prob, leaf, out_dir, batch_idx, sample_idx):
    """Plot one image titled with its leaf index and probability, and save it as a PNG in ``out_dir``."""
    fig, ax = plt.subplots(1, 1, figsize=(2, 2))
    ax.imshow(display_image(image), cmap=plt.get_cmap('gray'))
    ax.set_title(f"L{leaf}: " + "p=%.2f" % torch.round(prob, decimals=2))
    ax.axis('off')
    # The file name embeds str(prob) verbatim, to stay identical to previously generated outputs.
    fig.savefig(os.path.join(out_dir, f"output__{0}_{batch_idx}_{sample_idx}_{prob}.png"))
    # Close this specific figure so that figures do not accumulate across thousands of saves.
    plt.close(fig)


def _save_reconstructions(model, gen_test, device, img_save_path):
    """Reconstruct every test batch and save each leaf's reconstruction of each image separately."""
    for j, (x, _) in enumerate(gen_test):
        x = x.to(device)
        with torch.no_grad():
            res = model.compute_reconstruction(x)
        recons, nodes = res[0], res[1]
        # One directory per leaf; every test image gets one PNG per leaf.
        for c in range(len(nodes)):
            leaf_dir = _make_leaf_dir(img_save_path, c)
            for i in range(x.shape[0]):
                prob = nodes[c]['prob'][i].cpu()
                # .cpu() matches the samples path; it is a no-op for tensors already on the CPU.
                _save_leaf_image(recons[c][i].cpu(), prob, c, leaf_dir, j, i)


def _save_samples(model, gen_test, device, img_save_path):
    """Draw new samples from every leaf and save them as PNGs.

    The test loader is only iterated to decide how many samples to draw: one per test example, in batches.
    """
    for j, (x, _) in enumerate(gen_test):
        n_samples = x.shape[0]
        # Pure generation: no gradients are needed, so skip building autograd graphs.
        with torch.no_grad():
            reconstructions, p_c_z = model.generate_images(n_samples, device)
        for c in range(len(reconstructions)):
            leaf_dir = _make_leaf_dir(img_save_path, c)
            for i in range(n_samples):
                prob = p_c_z[i][c].cpu().detach()
                _save_leaf_image(reconstructions[c][i].cpu().detach(), prob, c, leaf_dir, j, i)


def vae_recons():
    """Generate per-leaf TreeVAE reconstructions or samples and save them as PNG files.

    Command-line options (parsed by ``_parse_args``):

    - ``--config_name``: selects the dataset or config.
    - ``--seed``: seeds torch and numpy.
    - ``--mode``: ``vae_recons`` reconstructs the test set; ``vae_samples`` draws one new sample
      per test example.
    - ``--model_name``: appended to ``models/experiments/<config_name>`` to locate the checkpoint.

    Images are written under
    ``../results_ICLR/<config_name>/cond_on_path/ddim/seed_1/vae/``, in ``recons_all_leaves/`` or
    ``sample_all_leaves/``, one ``img_cluster_<leaf>/`` directory per leaf.
    """
    args = _parse_args()
    # Set the seeds.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    dataset = args.config_name
    # NOTE(review): plain concatenation, so model_name presumably carries its own leading '/'; confirm.
    checkpoint_path = 'models/experiments/' + dataset + args.model_name
    model, gen_test, device = _load_model(checkpoint_path, dataset)

    # NOTE(review): 'seed_1' is hard-coded and does not follow --seed.
    vae_save_path = f"../results_ICLR/{dataset}/cond_on_path/ddim/seed_1/vae"
    if args.mode == 'vae_recons':
        _save_reconstructions(model, gen_test, device, os.path.join(vae_save_path, "recons_all_leaves"))
    elif args.mode == 'vae_samples':
        _save_samples(model, gen_test, device, os.path.join(vae_save_path, "sample_all_leaves"))
if __name__ == "__main__":
    # Script entry point: parse the CLI options and write the requested per-leaf images.
    vae_recons()