-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
155 lines (120 loc) · 4.82 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import argparse
import time
import neptune
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
from torchvision.models import resnet50
from dataset import OmniglotReactionTimeDataset
from psychloss import RtPsychCrossEntropyLoss
from psychloss import AccPsychCrossEntropyLoss
# args
def _str2bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (so ``--use_neptune False`` used to *enable* Neptune).
    This accepts the usual textual spellings instead and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='Training Psych Loss.')
parser.add_argument('--num_epochs', type=int, default=20,
                    help='number of epochs to use')
parser.add_argument('--batch_size', type=int, default=64,
                    help='batch size')
parser.add_argument('--num_classes', type=int, default=100,
                    help='number of classes')
parser.add_argument('--learning_rate', type=float, default=0.0001,
                    help='learning rate')
parser.add_argument('--loss_fn', type=str, default='psych-rt',
                    help='loss function to use. select: cross-entropy, psych-rt, psych-acc')
parser.add_argument('--dataset_file', type=str, default='small_dataset.csv',
                    help='dataset file to use. out.csv is the full set')
# Default stays the bool False (argparse only applies `type` to string defaults).
parser.add_argument('--use_neptune', type=_str2bool, default=False,
                    help='log metrics via neptune')
args = parser.parse_args()
# Optional Neptune experiment tracking.
if args.use_neptune:
    # choose within your local path setup
    # e.g. neptune_path = 'alice/psyphy-loss' ...
    neptune_path = ''
    if neptune_path:
        neptune.init(neptune_path)
        neptune.create_experiment(name='sandbox-{}'.format(args.loss_fn), params={'lr': args.learning_rate}, tags=[args.loss_fn])
    else:
        print('Please enter a correct neptune path aligned with an existing neptune project.')
        # No experiment was created, so later neptune.log_metric calls cannot
        # succeed; continue the run without Neptune logging.
        args.use_neptune = False
# seed for test replication
random_seed = 5 ** 3
torch.manual_seed(random_seed)

# configs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device is', device)
num_epochs = args.num_epochs
batch_size = args.batch_size

# Every parameter is overwritten by the strict load_state_dict below, so the
# ImageNet weights do not need to be downloaded first.
model = resnet50(pretrained=False).to(device)
model.fc = nn.Linear(2048, args.num_classes).to(device)

# LOAD PATH
# should be just in the project directory
# you can change to whatever you prefer
load_path = 'rt-mod.pth'
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
model.load_state_dict(torch.load(load_path, map_location=device))

# Not used for evaluation (weights are never updated here); the learning rate
# follows --learning_rate instead of a hard-coded value.
optim = torch.optim.Adam(model.parameters(), args.learning_rate)
# Only defined for 'cross-entropy'; no loss is computed in this script.
if args.loss_fn == 'cross-entropy':
    loss_fn = nn.CrossEntropyLoss()
# transforms and data loader
# NOTE(review): RandomCrop and RandomHorizontalFlip are random augmentations
# applied at evaluation time; runs repeat only because torch is seeded above.
# Consider deterministic transforms for testing - confirm against training.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
dataset = OmniglotReactionTimeDataset(args.dataset_file,
                                      transforms=train_transform)

validation_split = .2
shuffle_dataset = True
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    # Seeded global NumPy shuffle so the held-out split is reproducible;
    # presumably this must match the split used during training - verify.
    np.random.seed(random_seed)
    np.random.shuffle(indices)
# The first `split` shuffled indices are held out for validation; the rest
# (train_indices) are not loaded by this evaluation script.
train_indices, val_indices = indices[split:], indices[:split]
valid_sampler = SubsetRandomSampler(val_indices)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=valid_sampler)
# Evaluate classification accuracy on the held-out validation split.
model.eval()
accuracies = []
losses = []  # never populated: no loss is computed at test time
exp_time = time.time()
correct = 0.0
total = 0.0
with torch.no_grad():
    for sample in validation_loader:
        image1 = sample['image1']
        image2 = sample['image2']
        label1 = sample['label1']
        label2 = sample['label2']
        if args.loss_fn == 'psych-acc':
            psych = sample['acc']
        else:
            psych = sample['rt']

        # Both images of every pair are scored in one batch laid out as
        # [image1_0 .. image1_{B-1}, image2_0 .. image2_{B-1}].
        inputs = torch.cat([image1, image2], dim=0).to(device)
        labels = torch.cat([label1, label2], dim=0).to(device)

        # One psych value per row, aligned with the concatenation above
        # (pair k lands at rows k and B + k). Not used by the accuracy metric.
        psych_tensor = torch.as_tensor(psych, dtype=torch.float32).repeat(2).to(device)

        # we don't update weights at test time; only predictions are needed
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    # An empty validation split would otherwise raise ZeroDivisionError.
    print('validation set is empty; no accuracy to report')
else:
    accuracy = 100 * correct / total
    print(f'accuracy: {accuracy:.2f}%')
    if args.use_neptune:
        neptune.log_metric('test accuracy', accuracy)
    accuracies.append(accuracy)
print(f'{time.time() - exp_time:.2f} seconds')