forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
61 lines (51 loc) · 2.21 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from torch import nn
from DGLDigitCapsule import DGLDigitCapsuleLayer
from DGLRoutingLayer import squash
class Net(nn.Module):
    """Capsule network: a convolution stem, a primary capsule layer and a
    digit capsule layer applied in sequence.

    Args:
        device: Device identifier stored on the module and forwarded to the
            primary and digit capsule layers.
    """

    def __init__(self, device='cpu'):
        super(Net, self).__init__()
        self.device = device
        # 9x9, stride-1 convolution from 1 input channel to 256 channels,
        # followed by an in-place ReLU.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1),
            nn.ReLU(inplace=True),
        )
        self.primary = PrimaryCapsuleLayer(device=device)
        self.digits = DGLDigitCapsuleLayer(device=device)

    def forward(self, x):
        """Feed ``x`` through ``conv1``, ``primary`` and ``digits`` in order.

        Returns:
            Whatever the digit capsule layer produces for the batch.
        """
        out_conv1 = self.conv1(x)
        out_primary_caps = self.primary(out_conv1)
        out_digit_caps = self.digits(out_primary_caps)
        return out_digit_caps

    def margin_loss(self, input, target):
        """Compute the margin loss of capsule outputs against class labels.

        For every class ``c`` with capsule length ``v_c`` and indicator
        ``t_c`` (1 for the labelled class, 0 otherwise) the per-class term is
        ``t_c * max(0, 0.9 - v_c)**2 + 0.5 * (1 - t_c) * max(0, v_c - 0.1)**2``.
        Terms are summed over classes and averaged over the batch.

        Args:
            input: Capsule outputs; the Euclidean norm is taken over dim 2
                and the remaining non-batch dims are flattened into the
                class axis. Assumed to be (batch, num_classes, capsule_dim)
                -- confirm against the digit capsule layer.
            target: Integer class indices, one per sample.

        Returns:
            Scalar tensor holding the mean loss over the batch.
        """
        m_plus = 0.9
        m_minus = 0.1
        loss_lambda = 0.5

        batch_size = input.size(0)
        # Length of each capsule vector, flattened to (batch, classes).
        lengths = torch.sqrt((input ** 2).sum(dim=2, keepdim=True))
        lengths = lengths.view(batch_size, -1)

        # Derive the class count from the data instead of hard-coding 10, and
        # allocate on the same device/dtype as the capsule lengths so the loss
        # keeps working after the module has been moved with .to()/.cuda().
        num_classes = lengths.size(1)
        t_c = lengths.new_zeros(target.size(0), num_classes)
        # One vectorised scatter replaces the per-sample Python loop.
        index = target.to(lengths.device).long().reshape(-1, 1)
        t_c.scatter_(1, index, 1.0)

        max_left = torch.clamp(m_plus - lengths, min=0) ** 2
        max_right = torch.clamp(lengths - m_minus, min=0) ** 2

        l_c = t_c * max_left + loss_lambda * (1.0 - t_c) * max_right
        return l_c.sum(dim=1).mean()
class PrimaryCapsuleLayer(nn.Module):
    """Primary capsule layer built from ``num_unit`` parallel convolutions.

    Each unit is a 9x9, stride-2 convolution mapping ``in_channel`` channels
    to 32 channels. The unit outputs are stacked, flattened per unit and
    passed through ``squash``.

    Args:
        in_channel: Number of channels in the incoming feature map.
        num_unit: Number of parallel convolution units.
        device: Device identifier stored on the module.
    """

    def __init__(self, in_channel=256, num_unit=8, device='cpu'):
        super(PrimaryCapsuleLayer, self).__init__()
        self.in_channel = in_channel
        self.num_unit = num_unit
        self.device = device
        # Misspelled attribute kept as an alias so any existing code that
        # reads `deivce` keeps working.
        self.deivce = device
        self.conv_units = nn.ModuleList([
            nn.Conv2d(self.in_channel, 32, 9, 2) for _ in range(self.num_unit)
        ])

    def forward(self, x):
        """Apply every convolution unit to ``x`` and squash the result.

        Args:
            x: Input batch accepted by the ``nn.Conv2d`` units.

        Returns:
            ``squash`` applied along dim 2 of a tensor viewed as
            (batch, num_unit, -1).
        """
        unit_outputs = [conv(x) for conv in self.conv_units]
        # Stack along a new dim 1 so each unit owns one slice.
        stacked = torch.stack(unit_outputs, dim=1)
        batch_size = x.size(0)
        # Use the configured unit count rather than a literal 8 so the
        # reshape stays correct for any num_unit.
        stacked = stacked.view(batch_size, self.num_unit, -1)
        return squash(stacked, dim=2)