-
Notifications
You must be signed in to change notification settings - Fork 5
/
test.py
119 lines (89 loc) · 3.43 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import torchvision
def test_drwkv_model():
    """Report compute cost and size of every variant registered in ``DRWKV_models``.

    For each (name, constructor) pair this builds the model with
    ``img_size=32`` on the GPU, profiles one forward pass with
    ``thop.profile`` on a random ``(1, 3, 32, 32)`` image plus a random
    integer timestep in ``[1, 100)``, and prints:

    * the variant name,
    * twice thop's operation count, scaled by 1e-9 (printed as "FLOPs = ...G"),
    * the parameter count in millions.

    Requires a CUDA device, the project-local ``models_drwkv`` module and ``thop``.
    """
    from models_drwkv import DRWKV_models
    from thop import profile

    for variant_name, build_variant in DRWKV_models.items():
        print(variant_name)
        net = build_variant(img_size=32).cuda()
        dummy_image = torch.randn(1, 3, 32, 32).cuda()
        dummy_timestep = torch.randint(1, 100, (1,)).cuda()
        # NOTE(review): per thop's docs the first return value counts
        # multiply-accumulates (MACs); the x2 below converts MACs to FLOPs.
        macs, _unused_params = profile(net, inputs=(dummy_image, dummy_timestep))
        print(f"FLOPs = {macs * 2 / 1000**3!s}G")
        n_params = sum(p.numel() for p in net.parameters())
        print(n_params / 1000000.0, "M")
def test_cifar10():
    """Smoke-test that the CIFAR-10 train and test splits load from local disk.

    Nothing is downloaded (``download=False``), so both splits must already
    exist under the data root. Prints the train-split summary and the first
    test sample.
    """
    data_root = "/TrainData/Multimodal/zhengcong.fei/dis/data"
    # Train split is constructed first, then the test split.
    train_split, test_split = (
        torchvision.datasets.CIFAR10(root=data_root, train=is_train, download=False)
        for is_train in (True, False)
    )
    print(train_split)
    print(test_split[0])
def test_imagenet1k():
    """Smoke-test reading the per-class-folder ImageNet train directory.

    Builds a ``torchvision.datasets.ImageFolder`` over the local train
    directory and prints its first sample.
    """
    import torchvision.datasets as datasets

    train_dir = '/maindata/data/shared/multimodal/public/dataset_img_only/imagenet/data/train'
    train_set = datasets.ImageFolder(train_dir)
    print(train_set[0])
def imagenet_formation(
    data_path='/maindata/data/shared/multimodal/public/dataset_img_only/imagenet/data/train_org',
    target_path='/maindata/data/shared/multimodal/public/dataset_img_only/imagenet/data/train',
):
    """Reorganise a flat image directory into one sub-folder per class.

    The class of each entry in ``data_path`` is the part of its name before
    the first ``_`` (presumably ImageNet-style ``<wnid>_<id>.JPEG`` names;
    confirm for other datasets). Every entry is *moved* (not copied) to
    ``target_path/<class>/<entry>``, producing the layout that
    ``torchvision.datasets.ImageFolder`` reads.

    Prints the number of entries found and the number of distinct classes.

    Args:
        data_path: flat source directory. Defaults to the original
            hard-coded ImageNet ``train_org`` path.
        target_path: destination root. Per-class folders are created as
            needed, and existing ones are reused. Defaults to the original
            hard-coded ImageNet ``train`` path.
    """
    import os
    import shutil

    try:
        from tqdm import tqdm
    except ImportError:  # The progress bar is cosmetic; don't fail without it.
        def tqdm(iterable, **_kwargs):
            return iterable

    img_list = os.listdir(data_path)
    print(len(img_list))

    # Compute each entry's class prefix once; reused by the move loop below.
    class_of = {img: img.split('_')[0] for img in img_list}
    class_list = set(class_of.values())
    print(len(class_list))

    for class_name in class_list:
        # exist_ok avoids the exists()/makedirs() race and makes re-runs safe.
        os.makedirs(os.path.join(target_path, class_name), exist_ok=True)

    for img in tqdm(img_list):
        shutil.move(
            os.path.join(data_path, img),
            os.path.join(target_path, class_of[img], img),
        )
def test_celeba():
    """Smoke-test loading the local CelebA copy with ``datasets.load_dataset``.

    Prints the field names of the first example in the ``'train'`` split.
    """
    from datasets import load_dataset

    celeba = load_dataset("/TrainData/Multimodal/zhengcong.fei/dis/data/CelebA")
    first_example = celeba['train'][0]
    # Earlier exploration converted first_example['image'] to RGB via
    # .convert('RGB'); re-add that if downstream code needs 3-channel images.
    print(first_example.keys())
def test_fid_score():
    """Compute the FID between two sample directories and report it.

    Compares the conditional and unconditional CIFAR-10 "small" result folders
    using the project's ``tools.fid_score.calculate_fid_given_paths``.
    NOTE(review): presumably both folders hold generated images; confirm.

    Returns:
        Whatever ``calculate_fid_given_paths`` returns (the FID score). It is
        also printed so the run produces a visible result.
    """
    from tools.fid_score import calculate_fid_given_paths

    path1 = '/TrainData/Multimodal/zhengcong.fei/dis/results/cond_cifar10_small/his'
    path2 = '/TrainData/Multimodal/zhengcong.fei/dis/results/uncond_cifar10_small/his'
    fid = calculate_fid_given_paths((path1, path2))
    # Without this, the score is computed and silently thrown away.
    print(f"FID = {fid}")
    return fid
def test_vae():
    """Smoke-test that the local pretrained ``AutoencoderKL`` checkpoint loads.

    Succeeds silently if ``from_pretrained`` does not raise. The loaded model
    is not used further.
    """
    from diffusers.models import AutoencoderKL

    checkpoint_dir = '/TrainData/Multimodal/zhengcong.fei/dis/vae'
    AutoencoderKL.from_pretrained(checkpoint_dir)
def test_rwkv():
    """Smoke-test one forward pass of ``DiffRWKVModel`` with its default config.

    Feeds a random ``(1, 3, 64, 64)`` image and a random integer timestep in
    ``[1, 100)`` on the GPU. Succeeds if the call does not raise; the output
    is not inspected.
    """
    from models_drwkv import DiffRWKVModel

    # Model is built before drawing inputs, so RNG consumption order is kept.
    net = DiffRWKVModel().cuda()
    # NOTE(review): assumes the default config accepts 64x64 RGB input — confirm.
    noisy_image = torch.randn(1, 3, 64, 64).cuda()
    timestep = torch.randint(1, 100, (1,)).cuda()
    net(noisy_image, timesteps=timestep)
if __name__ == "__main__":
    # Only run when executed as a script, so importing this module does not
    # trigger GPU work. Uncomment whichever smoke test you need.
    test_drwkv_model()
    # test_cifar10()
    # test_imagenet1k()
    # test_celeba()
    # test_fid_score()
    # test_vae()
    # test_rwkv()
    # imagenet_formation()  # destructive: moves files on disk