# datasets.py
import glob
import os

import numpy as np
import PIL.Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class CelebA(Dataset):
    """CelebA dataset of real faces."""

    def __init__(self, dataset_path, img_size_sr, **kwargs):
        super().__init__()
        self.data = glob.glob(os.path.join(dataset_path, "*.*"))
        assert len(self.data) > 0, "Can't find real face data, please check dataset_path."
        self.transform = transforms.Compose([
            transforms.Resize(320),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Resize((img_size_sr, img_size_sr)),
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        X = PIL.Image.open(self.data[index])
        X = self.transform(X)
        return X
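
# Hedged shape-check sketch (not part of the original file): the CelebA pipeline
# above resizes the shorter side to 320, center-crops to 256, converts to a
# tensor scaled to [-1, 1], randomly flips, then resizes to img_size_sr. The
# path and size below are illustrative assumptions, not values from this repo.
def _example_celeba_item():
    dataset = CelebA(dataset_path="/path/to/celeba", img_size_sr=64)
    x = dataset[0]
    print(x.shape)  # expected: torch.Size([3, 64, 64]) for an RGB input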
class AAHQ(Dataset):
    """AAHQ dataset of style (anime) faces paired with precomputed style codes."""

    def __init__(self, dataset_path2, dataset_path3, img_size, **kwargs):
        super().__init__()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
            transforms.Resize((img_size, img_size)),
        ])
        # Note: glob order is not guaranteed; __getitem__ pairs the i-th image
        # with the i-th row of the style-code file.
        self.style_face = glob.glob(os.path.join(dataset_path2, "*.*"))
        self.style_code = np.loadtxt(dataset_path3, delimiter=",")
        self.style_code = torch.tensor(self.style_code, dtype=torch.float32)  # [23567, 512]
        assert len(self.style_face) > 0, "Can't find style face data, please check dataset_path."
        assert len(self.style_code) > 0, "Can't find style code data, please check dataset_path."

    def __getitem__(self, index):
        style_face = PIL.Image.open(self.style_face[index])
        style_face = self.transform(style_face)
        style_code = self.style_code[index]
        return style_face, style_code

    def __len__(self):
        return len(self.style_face)
class face2anime_artnerf(Dataset):
    """Loader of real faces, style faces, and precomputed style codes."""

    def __init__(self, dataset_path1, dataset_path2, dataset_path3, img_size_sr, **kwargs):
        super().__init__()
        self.transform1 = transforms.Compose([
            transforms.Resize(320),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Resize((img_size_sr, img_size_sr)),
        ])
        self.transform2 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
            transforms.Resize((img_size_sr, img_size_sr)),
        ])
        self.real_face = glob.glob(os.path.join(dataset_path1, "*.*"))
        # Note: glob order is not guaranteed; __getitem__ pairs the i-th style
        # face with the i-th row of the style-code file.
        self.style_face = glob.glob(os.path.join(dataset_path2, "*.*"))
        self.style_code = np.loadtxt(dataset_path3, delimiter=",")
        self.style_code = torch.tensor(self.style_code, dtype=torch.float32)  # [23567, 512]
        assert len(self.real_face) > 0, "Can't find real face data, please check dataset_path."
        assert len(self.style_face) > 0, "Can't find style face data, please check dataset_path."
        assert len(self.style_code) > 0, "Can't find style code data, please check dataset_path."

    def __getitem__(self, index):
        real_face = PIL.Image.open(self.real_face[index])
        style_face = PIL.Image.open(self.style_face[index])
        style_code = self.style_code[index]
        real_face = self.transform1(real_face)
        style_face = self.transform2(style_face)
        return real_face, style_face, style_code

    def __len__(self):
        return len(self.real_face)
class styleFace(Dataset):
    """Style (anime) faces only."""

    def __init__(self, dataset_path, img_size):
        super().__init__()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
            transforms.Resize((img_size, img_size)),
        ])
        self.styleface = glob.glob(os.path.join(dataset_path, "*.*"))
        assert len(self.styleface) > 0, "Can't find style face data, please check dataset_path."

    def __getitem__(self, index):
        styleface = PIL.Image.open(self.styleface[index])
        styleface = self.transform(styleface)
        return styleface

    def __len__(self):
        return len(self.styleface)
class realFace(Dataset):
    """Real faces from the trainA/testA split of an unpaired dataset."""

    def __init__(self, dataset_path, img_size, mode, **kwargs):
        super().__init__()
        # Validation images do not need to be flipped.
        if mode == 'train':
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.Resize((img_size, img_size)),
            ])
            self.realface = glob.glob(os.path.join(dataset_path, mode, 'trainA', "*.*"))
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
                transforms.Resize((img_size, img_size)),
            ])
            self.realface = glob.glob(os.path.join(dataset_path, mode, 'testA', "*.*"))
        assert len(self.realface) > 0, "Can't find real face data, please check dataset_path."

    def __getitem__(self, index):
        face_A = PIL.Image.open(self.realface[index])
        face_A = self.transform(face_A)
        return face_A

    def __len__(self):
        return len(self.realface)
class aniFace(Dataset):
    """Anime faces from the trainB/testB split of an unpaired dataset."""

    def __init__(self, dataset_path, img_size, mode, **kwargs):
        super().__init__()
        # Anime faces are not flipped in either split; only the subdirectory
        # differs between train and test.
        if mode == 'train':
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
                transforms.Resize((img_size, img_size), interpolation=0),  # 0 = nearest neighbour
            ])
            self.aniface = glob.glob(os.path.join(dataset_path, mode, 'trainB', "*.*"))
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
                transforms.Resize((img_size, img_size), interpolation=0),  # 0 = nearest neighbour
            ])
            self.aniface = glob.glob(os.path.join(dataset_path, mode, 'testB', "*.*"))
        assert len(self.aniface) > 0, "Can't find anime face data, please check dataset_path."

    def __getitem__(self, index):
        face_B = PIL.Image.open(self.aniface[index])
        face_B = self.transform(face_B)
        return face_B

    def __len__(self):
        return len(self.aniface)
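
# Hedged usage sketch (not part of the original file): realFace and aniFace read
# an unpaired layout, dataset_path/train/trainA and dataset_path/train/trainB
# (testA/testB for evaluation). The path and size below are illustrative
# assumptions.
def _example_unpaired_faces():
    real = realFace(dataset_path="/path/to/face2anime", img_size=64, mode="train")
    anime = aniFace(dataset_path="/path/to/face2anime", img_size=64, mode="train")
    print(len(real), len(anime))
    print(real[0].shape, anime[0].shape)  # each: torch.Size([3, 64, 64]) for RGB inputs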
# Single-GPU data loading.
def get_dataset(name, batch_size, **kwargs):
    # Look up the dataset class defined in this module by name and pass the
    # remaining keyword arguments (e.g. the metadata dict) to its constructor.
    dataset = globals()[name](**kwargs)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=4,
    )
    return dataloader
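
# Hedged usage sketch (not part of the original file): building a single-GPU
# dataloader via get_dataset. The dataset name and keyword arguments
# (dataset_path, img_size_sr) follow the CelebA constructor above; the path and
# sizes are illustrative assumptions.
def _example_single_gpu_loader():
    dataloader = get_dataset(
        "CelebA",
        batch_size=8,
        dataset_path="/path/to/celeba",  # hypothetical path
        img_size_sr=128,
    )
    images = next(iter(dataloader))
    print(images.shape)  # expected: torch.Size([8, 3, 128, 128]) for RGB inputs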
def get_dataset_distributed(name, world_size, rank, batch_size, **kwargs):
    # By the time training calls get_dataset_distributed(), the dataset classes
    # (e.g. CelebA) are already defined in this module, so globals()[name] looks
    # the class up by name and **kwargs passes the metadata dict to its
    # constructor as keyword arguments.
    dataset = globals()[name](**kwargs)
    # Distributed sampler: each rank sees its own shard of the dataset.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
    )
    # Build the dataloader on top of the sampler.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=False,  # the sampler option is mutually exclusive with shuffle
        drop_last=True,
        pin_memory=True,
        num_workers=4,
    )
    return dataloader
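
# Hedged usage sketch (not part of the original file): how get_dataset_distributed
# might be driven from a per-rank training process. world_size and rank would
# normally come from the torch.distributed launcher; the dataset name, path, and
# sizes are illustrative assumptions.
def _example_distributed_loader(rank, world_size):
    dataloader = get_dataset_distributed(
        "CelebA",
        world_size=world_size,
        rank=rank,
        batch_size=8,
        dataset_path="/path/to/celeba",  # hypothetical path
        img_size_sr=128,
    )
    for epoch in range(2):
        # DistributedSampler only reshuffles across epochs if set_epoch is called.
        dataloader.sampler.set_epoch(epoch)
        for images in dataloader:
            pass  # training step would go here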