-
Notifications
You must be signed in to change notification settings - Fork 33
/
cifar10_zca.py
53 lines (46 loc) · 1.92 KB
/
cifar10_zca.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
'''
Created on May 5, 2018
@author: vermavik
'''
import torch
from torch.autograd import Variable
import os, errno
import numpy as np
from scipy import linalg
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
def ZCA(data, reg=1e-6):
    """Compute a ZCA whitening transform for ``data``.

    The rows of ``data`` are treated as samples and the columns as features
    (the mean is taken over axis 0 and the covariance is ``X^T X / n``).

    Args:
        data: 2-D array of shape ``(n_samples, n_features)``.
        reg: Small non-negative constant added to every eigenvalue of the
            covariance before taking the inverse square root, so that
            (near-)zero-variance directions do not produce infinite or
            NaN scale factors.

    Returns:
        A tuple ``(components, mean, whiten)`` where ``components`` is the
        symmetric ``(n_features, n_features)`` whitening matrix, ``mean`` is
        the per-feature mean, and ``whiten`` is the centred data multiplied
        by ``components.T``.
    """
    mean = np.mean(data, axis=0)
    mdata = data - mean
    sigma = np.dot(mdata.T, mdata) / mdata.shape[0]
    # sigma is symmetric positive semi-definite, so its SVD coincides with
    # its eigendecomposition; only U and the singular values are needed.
    U, S, _ = linalg.svd(sigma)
    # The regulariser must sit inside the square root: adding it afterwards
    # (1 / sqrt(S) + reg) still divides by zero when an eigenvalue is zero.
    scale = 1.0 / np.sqrt(S + reg)
    # U * scale scales column j of U by scale[j], i.e. U @ diag(scale).
    components = np.dot(U * scale, U.T)
    whiten = np.dot(mdata, components.T)
    return components, mean, whiten
def compute_zca(data_aug, data_target_dir):
    """Compute ZCA whitening statistics for the CIFAR-10 training set and save them.

    The training images are loaded (and downloaded into ``data_target_dir``
    if necessary), normalised per channel with fixed mean/std constants,
    flattened to one row per image in channel-first order, and passed to
    ``ZCA``. The resulting whitening matrix and mean vector are written with
    ``np.save`` to ``data/cifar10/zca_components`` and
    ``data/cifar10/zca_mean``.

    Args:
        data_aug: If equal to 1, the dataset is constructed with random
            horizontal flip and random crop transforms; otherwise only
            ``ToTensor`` is used. NOTE(review): the statistics are computed
            from the dataset's stored image array rather than from
            transformed samples, so this flag does not appear to influence
            the saved result -- confirm whether that is intended.
        data_target_dir: Root directory handed to ``datasets.CIFAR10``.

    Returns:
        None. The results are written to disk.
    """
    if data_aug == 1:
        train_transform = transforms.Compose(
            [transforms.RandomHorizontalFlip(),
             transforms.RandomCrop(32, padding=2),
             transforms.ToTensor()])
    else:
        train_transform = transforms.Compose([transforms.ToTensor()])
    train_data = datasets.CIFAR10(
        data_target_dir, train=True, transform=train_transform, download=True)

    # Newer torchvision releases expose the raw image array as ``data``;
    # older releases used ``train_data``. Support both.
    raw_images = getattr(train_data, 'data', None)
    if raw_images is None:
        raw_images = train_data.train_data
    # astype returns a copy, so the dataset's own array is left untouched.
    temp_data = np.asarray(raw_images).astype(float)

    # Per-channel normalisation; the channel is the last axis at this point.
    temp_data[:, :, :, 0] = (temp_data[:, :, :, 0] - 125.3) / 63.0
    temp_data[:, :, :, 1] = (temp_data[:, :, :, 1] - 123.0) / 62.1
    temp_data[:, :, :, 2] = (temp_data[:, :, :, 2] - 113.9) / 66.7

    # Move the channel axis to position 1, then flatten each image to a row.
    temp_data = np.transpose(temp_data, (0, 3, 1, 2))
    temp_data = temp_data.reshape(temp_data.shape[0], -1)

    components, mean, whiten = ZCA(temp_data)

    # The output location is fixed; make sure it exists so the (expensive)
    # computation above is not lost to a missing-directory error.
    out_dir = 'data/cifar10'
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    np.save('data/cifar10/zca_components', components)
    np.save('data/cifar10/zca_mean', mean)
if __name__ == '__main__':
    # Script entry point: compute and store the CIFAR-10 ZCA statistics
    # using the non-augmented transform pipeline.
    cifar10_root = "data/cifar10/"
    compute_zca(data_aug=0, data_target_dir=cifar10_root)