clipboard.py
import torch
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as TF
import random
import numpy as np
from PIL import Image
from tqdm import tqdm
import os
import os.path as ops
# Ignore warnings
import warnings
warnings.filterwarnings('ignore')


class CustomDataset(Dataset):
    """
    Dataset that produces (optionally augmented) image/label/logit samples
    for knowledge distillation.
    """

    def __init__(self, logits, dataset, data_aug=False, normalization=True):
        """
        Initialize the custom dataset.
        :param logits: path to a .npy file of logits produced by a teacher model for the [train/test] split
        :param dataset: [train/test] split of the CIFAR-10 dataset (a Dataset, not a DataLoader)
        :param data_aug: whether data augmentation is applied
        :param normalization: whether channel-wise normalization is applied
        """
        self._logits = logits
        self._dataset = dataset
        self._data_aug = data_aug
        self._normalization = normalization
        if not self._is_source_data_complete():
            raise ValueError("Input data is incomplete: "
                             "wrong file type or "
                             "file does not exist.")
        if normalization:
            # CIFAR-10 channel-wise mean and std
            self.normalize_fn = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                                     std=(0.2023, 0.1994, 0.2010))
        # create the list of (image, label, logit) samples
        self._KD_data_list = self.construct()

    def _is_source_data_complete(self):
        """
        Check that the logits file has the right extension and exists on disk
        """
        _, file_extension = os.path.splitext(self._logits)
        correct_file_type = file_extension == '.npy'
        file_exist = ops.exists(self._logits)
        return correct_file_type and file_exist

    def __len__(self):
        return len(self._KD_data_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = {"image": self._KD_data_list[idx][0],
                  "label": self._KD_data_list[idx][1],
                  "logit": self._KD_data_list[idx][2]}
        if self._data_aug:
            # augmentation also converts the PIL image to a tensor
            sample = self.transform(sample)
        else:
            # without augmentation, convert the stored PIL image to a tensor here
            sample["image"] = TF.to_tensor(sample["image"])
        if self._normalization:
            sample = self.normalize(sample)
        return sample

    def normalize(self, sample):
        """
        Perform channel-wise normalization
        """
        tensor_image = sample["image"]
        normalized_tensor_image = self.normalize_fn(tensor_image)
        # re-assign the image back to sample
        sample["image"] = normalized_tensor_image
        return sample

    def transform(self, sample):
        """
        Perform image augmentation functions on sample['image']
        """
        image = sample["image"]
        # Random crop to 32x32 (note: this is effectively a no-op on
        # 32x32 CIFAR-10 images unless the image is padded first)
        i, j, h, w = transforms.RandomCrop.get_params(
            image, output_size=(32, 32)
        )
        image = TF.crop(image, i, j, h, w)
        # Random horizontal flip
        if random.random() > 0.5:
            image = TF.hflip(image)
        # Transform the PIL image to a tensor
        image = TF.to_tensor(image)
        # re-assign the transformed image back to sample
        sample["image"] = image
        return sample

    def construct(self):
        """
        Load the logits generated by a teacher model and pair each one with the
        corresponding (image, label) from the dataset to build the
        knowledge-distillation samples.
        """
        # Load the logits (.npy) file and cast to float32 for data type consistency
        logits = np.load(self._logits).astype(np.float32)
        KD_data_list = []
        for data, logit in zip(self._dataset, logits):
            # `image` is a PIL.Image.Image, `label` is an int
            image, label = data
            # Convert label and logit to tensors.
            # The image is kept as a PIL image because the augmentation
            # functions only accept PIL images.
            label = torch.tensor([label])
            logit = torch.from_numpy(logit)
            # create a data sample: [PIL image, label tensor, logit tensor]
            data_sample = [image, label, logit]
            KD_data_list.append(data_sample)
        return KD_data_list
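

if __name__ == '__main__':
    # Minimal usage sketch (an illustrative assumption, not part of the original file):
    # it presumes a CIFAR-10 train split from torchvision and a hypothetical
    # 'train_logits.npy' file of teacher-model logits, one row per training image.
    from torch.utils.data import DataLoader
    import torchvision.datasets as datasets

    train_set = datasets.CIFAR10(root='./data', train=True, download=True)
    kd_dataset = CustomDataset(logits='train_logits.npy',
                               dataset=train_set,
                               data_aug=True,
                               normalization=True)
    loader = DataLoader(kd_dataset, batch_size=128, shuffle=True, num_workers=2)

    for batch in loader:
        # each batch is a dict of stacked tensors: "image", "label", "logit"
        images, labels, teacher_logits = batch["image"], batch["label"], batch["logit"]
        print(images.shape, labels.shape, teacher_logits.shape)
        break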