-
Notifications
You must be signed in to change notification settings - Fork 17
/
cnn.py
119 lines (102 loc) · 3.91 KB
/
cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import sys
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import omniglot
import memory
class Net(nn.Module):
    """Convolutional encoder that maps an image batch to 128-d embeddings.

    Architecture: two stages of (3x3 conv -> ReLU -> 3x3 conv -> ReLU ->
    2x2 max-pool), with 32 and then 64 channels, followed by one linear
    projection to 128 features and optional dropout.

    Args:
        input_shape: ``(channels, rows, cols)`` of a single input image. The
            input width of ``fc1`` is derived from it, so tensors passed to
            :meth:`forward` must have this channel count and spatial size.
    """

    def __init__(self, input_shape):
        super(Net, self).__init__()
        # Constants
        kernel = 3
        # "Same" padding for an odd kernel: with stride 1 each convolution
        # keeps the spatial size, so only the pooling layers shrink it.
        pad = (kernel - 1) // 2
        p = 0.3  # dropout probability
        ch, row, col = input_shape
        self.conv1 = nn.Conv2d(ch, 32, kernel, padding=(pad, pad))
        self.conv2 = nn.Conv2d(32, 32, kernel, padding=(pad, pad))
        self.conv3 = nn.Conv2d(32, 64, kernel, padding=(pad, pad))
        self.conv4 = nn.Conv2d(64, 64, kernel, padding=(pad, pad))
        self.pool = nn.MaxPool2d(2, 2)
        # Two 2x2 poolings floor-halve each spatial dimension twice, leaving
        # (row // 4) x (col // 4) cells of 64 channels. The parentheses
        # matter: ``row // 4 * col // 4`` would parse as
        # ``((row // 4) * col) // 4`` and mis-size the layer whenever col is
        # not a multiple of 4.
        self.fc1 = nn.Linear((row // 4) * (col // 4) * 64, 128)
        self.dropout = nn.Dropout(p)

    def forward(self, x, predict=False):
        """Embed a batch of images.

        Args:
            x: image batch, expected as ``(batch, channels, rows, cols)`` with
                channels/rows/cols equal to ``input_shape``.
            predict: when True, dropout is skipped (inference); when False,
                dropout is applied to the projected features. (``nn.Dropout``
                is additionally a no-op if the module is in eval mode.)

        Returns:
            Tensor of shape ``(batch, 128)``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        if not predict:
            x = self.dropout(x)
        return x
# --- Experiment configuration -----------------------------------------------
# Episode sampling parameters (presumably: steps per episode, classes per
# episode, and episodes per training batch -- see omniglot.sample_episode_batch).
episode_length = 30
episode_width = 5
batch_size = 16
# Memory-module sizes. key_dim presumably has to match the 128-wide output of
# Net.fc1, since Net's embeddings are what get passed to mem.query.
memory_size = 8192
key_dim = 128
# Validate once every `validation_frequency` training episode batches.
validation_frequency = 20

# --- Data -----------------------------------------------------------------
# The pickled splits are looked up in the current working directory.
DATA_FILE_FORMAT = os.path.join(os.getcwd(), '%s_omni.pkl')
train_filepath, test_filepath = DATA_FILE_FORMAT % 'train', DATA_FILE_FORMAT % 'test'
trainset = omniglot.OmniglotDataset(train_filepath)
# N presumably bounds the number of sampled episode batches -- TODO confirm.
trainloader = trainset.sample_episode_batch(episode_length, episode_width, batch_size, N=100000)
testset = omniglot.OmniglotDataset(test_filepath)

# --- Model, memory and optimiser -----------------------------------------
net = Net(input_shape=(1, 28, 28))
mem = memory.Memory(memory_size, key_dim)
# Registering the memory as a child module makes net.cuda() and
# net.parameters() cover it too.
net.add_module("memory", mem)
net.cuda()  # requires a CUDA device
optimizer = optim.Adam(net.parameters(), lr=1e-4, eps=1e-4)

# Running loss totals, reset after each validation pass.
cummulative_loss, counter = 0, 0
for i, data in enumerate(trainloader, 0):
    # Training: start every episode batch from an empty memory, then take one
    # optimizer step per element yielded by zip(x, y).
    mem.build()
    x, y = data
    for xx, yy in zip(x, y):
        optimizer.zero_grad()
        xx_cuda, yy_cuda = Variable(xx.cuda()), Variable(yy.cuda())
        embed = net(xx_cuda, False)  # predict=False -> dropout applied
        _, _, loss = mem.query(embed, yy_cuda, False)
        loss.backward()
        optimizer.step()
        # ``.item()`` works for any one-element tensor; ``loss.data[0]``
        # raises IndexError on the 0-dim loss tensors of current PyTorch.
        cummulative_loss += loss.item()
        counter += 1

    if i % validation_frequency == 0:
        # Validation on freshly sampled single-episode batches of the test split.
        correct = []
        # correct_by_k_shot[k] collects hit/miss (1.0/0.0) for queries whose
        # label slot had already been seen k times earlier in the episode.
        correct_by_k_shot = {k: [] for k in range(episode_width + 1)}
        testloader = testset.sample_episode_batch(episode_length, episode_width, batch_size=1, N=50)
        for data in testloader:
            # erase memory before validation episode
            mem.build()
            x, y = data
            y_hat = []
            for xx, yy in zip(x, y):
                xx_cuda, yy_cuda = Variable(xx.cuda()), Variable(yy.cuda())
                query = net(xx_cuda, True)  # predict=True -> no dropout
                yy_hat, _, _ = mem.query(query, yy_cuda, True)
                y_hat.append(yy_hat)
                correct.append(float(torch.equal(yy_hat.cpu(), torch.unsqueeze(yy, dim=1))))
            # Per-shot accuracy: walk the episode again, counting how often each
            # label slot has appeared so far. Only yy[0] is inspected, which
            # matches batch_size=1 above.
            seen_count = [0] * episode_width
            for yy, yy_hat in zip(y, y_hat):
                slot = yy[0] % episode_width
                count = seen_count[slot]
                if count < (episode_width + 1):
                    correct_by_k_shot[count].append(float(torch.equal(yy_hat.cpu(), torch.unsqueeze(yy, dim=1))))
                seen_count[slot] += 1
        print("episode batch: {0:d} average loss: {1:.6f}".format(i, (cummulative_loss / (counter))))
        # np.mean([]) warns and returns NaN; report NaN explicitly instead so
        # the printed output is identical but warning-free.
        overall = np.mean(correct) if correct else float('nan')
        print("validation overall accuracy {0:f}".format(overall))
        for idx in range(episode_width + 1):
            shot_results = correct_by_k_shot[idx]
            shot_accuracy = np.mean(shot_results) if shot_results else float('nan')
            print("{0:d}-shot: {1:.3f}".format(idx, shot_accuracy))
        cummulative_loss = 0
        counter = 0