Bugfix for the reconstruction part, the squash function, and the margin & reconstruction loss functions #4

Open · wants to merge 4 commits into master
capsule_conv_layer.py (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@ def __init__(self, in_channels, out_channels):
                               out_channels=out_channels,
                               kernel_size=9,  # fixme constant
                               stride=1,
-                              bias=False)
+                              bias=True)
Owner:
good find! i was experimenting with this myself locally but didn't want to commit either way. did you measure an accuracy improvement on MNIST with this?

Author:
Oh, I just think that if there is no BN layer in the network, the bias should be added. This strategy works in traditional convnets. But I don't know whether it is also useful in a capsnet; I will test it once this model reaches better accuracy. It always gets stuck at about 99.4% acc in my local version.


        self.relu = nn.ReLU(inplace=True)

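(On the bias question above: a minimal sketch of the convnet convention the author refers to. The channel sizes follow main.py; the BatchNorm variant is hypothetical and not part of this repo.)

```python
import torch.nn as nn

# When a BatchNorm layer follows the conv, BN's learned shift (beta)
# absorbs any constant offset, so the conv bias is redundant:
conv_bn = nn.Sequential(
    nn.Conv2d(1, 256, kernel_size=9, stride=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(inplace=True),
)

# Without BatchNorm (as in this network), nothing else provides a
# learnable per-channel offset, so bias=True is the usual choice:
conv_plain = nn.Sequential(
    nn.Conv2d(1, 256, kernel_size=9, stride=1, bias=True),
    nn.ReLU(inplace=True),
)
```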
capsule_layer.py (23 changes: 13 additions & 10 deletions)
@@ -19,7 +19,7 @@ def __init__(self, in_channels):
                               out_channels=32,  # fixme constant
                               kernel_size=9,    # fixme constant
                               stride=2,
-                              bias=False)  # fixme constant
+                              bias=True)   # fixme constant

    def forward(self, x):
        return self.conv0(x)
@@ -31,6 +31,7 @@ def __init__(self, in_units, in_channels, num_units, unit_size, use_routing):
        self.in_units = in_units
        self.in_channels = in_channels
        self.num_units = num_units
+       self.unit_size = unit_size
        self.use_routing = use_routing

        if self.use_routing:
@@ -45,12 +46,12 @@ def create_conv_unit(unit_idx):
                unit = ConvUnit(in_channels=in_channels)
                self.add_module("unit_" + str(unit_idx), unit)
                return unit
-           self.units = [create_conv_unit(i) for i in range(self.num_units)]
+           self.units = [create_conv_unit(i) for i in range(self.unit_size)]

    @staticmethod
-   def squash(s):
+   def squash(s, dim=2):
        # This is equation 1 from the paper.
-       mag_sq = torch.sum(s**2, dim=2, keepdim=True)
+       mag_sq = torch.sum(s**2, dim, keepdim=True)
        mag = torch.sqrt(mag_sq)
        s = (mag_sq / (1.0 + mag_sq)) * (s / mag)
        return s
@@ -64,16 +65,16 @@ def forward(self, x):
    def no_routing(self, x):
        # Get output for each unit.
        # Each will be (batch, channels, height, width).
-       u = [self.units[i](x) for i in range(self.num_units)]
+       u = [self.units[i](x) for i in range(self.unit_size)]

-       # Stack all unit outputs (batch, unit, channels, height, width).
+       # Stack all unit outputs (batch, unit_size, channels, height, width).
        u = torch.stack(u, dim=1)

-       # Flatten to (batch, unit, output).
-       u = u.view(x.size(0), self.num_units, -1)
+       # Flatten to (batch, unit_size, output).
+       u = u.view(x.size(0), self.unit_size, -1)

        # Return squashed outputs.
-       return CapsuleLayer.squash(u)
+       return CapsuleLayer.squash(u, dim=1)

    def routing(self, x):
        batch_size = x.size(0)
@@ -84,7 +85,7 @@ def routing(self, x):
        # (batch, features, in_units) -> (batch, features, num_units, in_units, 1)
        x = torch.stack([x] * self.num_units, dim=2).unsqueeze(4)

-       # (batch, features, in_units, unit_size, num_units)
+       # (batch, features, num_units, unit_size, in_units)
        W = torch.cat([self.W] * batch_size, dim=0)

        # Transform inputs by weight matrix.
@@ -99,6 +100,7 @@ def routing(self, x):
        for iteration in range(num_iterations):
            # Convert routing logits to softmax.
            # (batch, features, num_units, 1, 1)
+           # fixme: softmax seems to be applied over the wrong dimension here,
+           # but the network fails to train if it is changed. weird.
            c_ij = F.softmax(b_ij)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

@@ -107,6 +109,7 @@
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)

            # (batch_size, 1, num_units, unit_size, 1)
+           # fixme: squash also seems to use the wrong dimension here,
+           # but the network fails to train if it is changed. weird.
            v_j = CapsuleLayer.squash(s_j)

            # (batch_size, features, num_units, unit_size, 1)
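(A standalone sketch of squash as it stands after this diff, with a quick sanity check; the shapes are illustrative. Note that a small epsilon is often added to the squared magnitude to avoid a divide-by-zero on all-zero vectors, which this version does not do.)

```python
import torch

def squash(s, dim=2):
    # Equation 1 from the paper: v = (|s|^2 / (1 + |s|^2)) * (s / |s|)
    mag_sq = torch.sum(s ** 2, dim, keepdim=True)
    mag = torch.sqrt(mag_sq)
    return (mag_sq / (1.0 + mag_sq)) * (s / mag)

s = torch.randn(4, 8, 16)   # (batch, num_capsules, capsule_dim)
v = squash(s, dim=2)        # squash each 16-dim capsule vector

print(v.norm(dim=2).max().item() < 1.0)  # True: lengths land in [0, 1)
print(torch.allclose(v / v.norm(dim=2, keepdim=True),
                     s / s.norm(dim=2, keepdim=True)))  # direction preserved
```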
capsule_network.py (53 changes: 20 additions & 33 deletions)
@@ -43,16 +43,18 @@ def __init__(self,
                                    unit_size=primary_unit_size,
                                    use_routing=False)

-       self.digits = CapsuleLayer(in_units=num_primary_units,
-                                  in_channels=primary_unit_size,
+       self.digits = CapsuleLayer(in_units=primary_unit_size,
+                                  in_channels=num_primary_units,
                                   num_units=num_output_units,
                                   unit_size=output_unit_size,
                                   use_routing=True)

        reconstruction_size = image_width * image_height * image_channels
-       self.reconstruct0 = nn.Linear(num_output_units*output_unit_size, (reconstruction_size * 2) / 3)
-       self.reconstruct1 = nn.Linear((reconstruction_size * 2) / 3, (reconstruction_size * 3) / 2)
-       self.reconstruct2 = nn.Linear((reconstruction_size * 3) / 2, reconstruction_size)
+       # self.reconstruct0 = nn.Linear(num_output_units*output_unit_size, (reconstruction_size * 2) / 3)
Owner (timomernick), Nov 8, 2017:
are you sure this is correct? the paper (and other implementations i've seen) seem to say that the reconstruction input is all capsules, but with the inactive capsules masked out.

Author:
Yes, you're right. Now I can get 98.9% in the first epoch after fixing a few bugs. But I still wonder why softmaxing the wrong dim and squashing the wrong dim in the routing part can still get such a great result.
And the weirdest thing is that if I make the dims correct, this model sometimes crashes and gets bad accuracy... Hope you can fix it. Thx.
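(A sketch to make the dimension question concrete, using the shapes from the comments in this diff. The dim=1 case is an assumption about what the legacy no-dim F.softmax call resolves to on a 4-D tensor; dim=2 is what equation 3 in the paper describes.)

```python
import torch
import torch.nn.functional as F

# Routing logits, shaped as in the comments above:
# (1, features, num_units, 1) = (1, 1152, 10, 1)
b_ij = torch.randn(1, 1152, 10, 1)

# What the code above effectively computes: legacy F.softmax without an
# explicit dim normalized 4-D input over dim=1, i.e. across the 1152
# input capsules.
c_legacy = F.softmax(b_ij, dim=1)

# What equation 3 in the paper describes: each input capsule's coupling
# coefficients sum to 1 across the 10 output capsules (dim=2).
c_paper = F.softmax(b_ij, dim=2)

print(c_legacy.sum(dim=1)[0, 0])  # 1.0: normalized over input capsules
print(c_paper.sum(dim=2)[0, 0])   # 1.0: normalized over output capsules
```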

+       # self.reconstruct1 = nn.Linear((reconstruction_size * 2) / 3, (reconstruction_size * 3) / 2)
+       self.reconstruct0 = nn.Linear(output_unit_size, 512)
+       self.reconstruct1 = nn.Linear(512, 1024)
+       self.reconstruct2 = nn.Linear(1024, reconstruction_size)

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
@@ -61,7 +63,7 @@ def forward(self, x):
        return self.digits(self.primary(self.conv1(x)))

    def loss(self, images, input, target, size_average=True):
-       return self.margin_loss(input, target, size_average) + self.reconstruction_loss(images, input, size_average)
+       return self.margin_loss(input, target, size_average) + self.reconstruction_loss(images, input, target, size_average)

    def margin_loss(self, input, target, size_average=True):
        batch_size = input.size(0)
@@ -73,8 +75,8 @@ def margin_loss(self, input, target, size_average=True):
        zero = Variable(torch.zeros(1)).cuda()
        m_plus = 0.9
        m_minus = 0.1
-       max_l = torch.max(m_plus - v_mag, zero).view(batch_size, -1)
-       max_r = torch.max(v_mag - m_minus, zero).view(batch_size, -1)
+       max_l = torch.max(m_plus - v_mag, zero).view(batch_size, -1) ** 2
+       max_r = torch.max(v_mag - m_minus, zero).view(batch_size, -1) ** 2

        # This is equation 4 from the paper.
        loss_lambda = 0.5
@@ -87,29 +89,14 @@

        return L_c
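(For reference, equation 4 from the paper, which the squaring above restores; T_c is the one-hot target and loss_lambda the 0.5 down-weighting:)

L_c = T_c \max(0, m^+ - \|v_c\|)^2 + \lambda (1 - T_c) \max(0, \|v_c\| - m^-)^2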

-   def reconstruction_loss(self, images, input, size_average=True):
-       # Get the lengths of capsule outputs.
-       v_mag = torch.sqrt((input**2).sum(dim=2))
-
-       # Get index of longest capsule output.
-       _, v_max_index = v_mag.max(dim=1)
-       v_max_index = v_max_index.data
-
-       # Use just the winning capsule's representation (and zeros for other capsules) to reconstruct input image.
-       batch_size = input.size(0)
-       all_masked = [None] * batch_size
-       for batch_idx in range(batch_size):
-           # Get one sample from the batch.
-           input_batch = input[batch_idx]
-
-           # Copy only the maximum capsule index from this batch sample.
-           # This masks out (leaves as zero) the other capsules in this sample.
-           batch_masked = Variable(torch.zeros(input_batch.size())).cuda()
-           batch_masked[v_max_index[batch_idx]] = input_batch[v_max_index[batch_idx]]
-           all_masked[batch_idx] = batch_masked
-
-       # Stack masked capsules over the batch dimension.
-       masked = torch.stack(all_masked, dim=0)
+   def reconstruction_loss(self, images, input, target, size_average=True):
+       # Use the target to reconstruct input image.
+       # (batch_size, num_output_units, output_unit_size)
+       input = torch.squeeze(input, 3)
+       # (batch_size, num_output_units, 1)
+       target = torch.unsqueeze(target, 2)
+       # (batch_size, output_unit_size, 1)
+       masked = torch.matmul(input.transpose(2,1), target)
Owner:
hmm, i don't quite understand this part. what does 'masked' look like for a typical sample after this matmul? also see earlier comment, i thought the input was all capsules with all but one masked out -- otherwise the reconstruction would be confused by different classes of digits

Author:
First, there is no doubt that we should use the target for the reconstruction. But my implementation may be wrong now that I rethink this question. I thought the reconstruction network only needed 16 values to reconstruct the 10 digits, but since every digit capsule represents one entity (digit), we should use all 160 values for the reconstruction. This also explains how the original paper can reconstruct two overlapped MNIST digits at a time (although that could also be done by finding the two longest digit capsules, reconstructing twice, and summing, which is not as neat). So yes, I think you're right. :)
Just change the masked part to `masked = input * target` and change the number of first-layer units to `num_output_units*output_unit_size`, and then it will work.
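(A minimal sketch contrasting the two masking variants discussed in this thread, using the shapes from this PR. The matmul keeps only the target capsule's 16 values; the elementwise product the author suggests keeps the full 160-value layout with the non-target capsules zeroed.)

```python
import torch

batch_size, num_output_units, output_unit_size = 4, 10, 16

caps = torch.randn(batch_size, num_output_units, output_unit_size)
target = torch.zeros(batch_size, num_output_units)
target[:, 3] = 1.0  # pretend every sample's label is class 3 (one-hot)

# Variant in this diff: matmul collapses to the target capsule only.
collapsed = torch.matmul(caps.transpose(2, 1), target.unsqueeze(2))
print(collapsed.shape)  # torch.Size([4, 16, 1]): only 16 values survive

# Variant the author settles on: zero out the 9 non-target capsules but
# keep the full layout the paper's reconstruction network expects.
masked = caps * target.unsqueeze(2)   # (4, 10, 16)
masked = masked.view(batch_size, -1)  # (4, 160)
print(masked.shape)
```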


        # Reconstruct input image.
        masked = masked.view(input.size(0), -1)
@@ -130,11 +117,11 @@ def reconstruction_loss(self, images, input, size_average=True):
            vutils.save_image(output_image, "reconstruction.png")
            self.reconstructed_image_count += 1

-       # The reconstruction loss is the mean squared difference between the input image and reconstructed image.
+       # The reconstruction loss is the sum squared difference between the input image and reconstructed image.
        # Multiplied by a small number so it doesn't dominate the margin (class) loss.
        error = (output - images).view(output.size(0), -1)
        error = error**2
-       error = torch.mean(error, dim=1) * 0.0005
+       error = torch.sum(error, dim=1) * 0.0005
Owner:
good find!

Author:
Thx. Your code is very neat and has been helpful to me.


        # Average over batch
        if size_average:
main.py (4 changes: 2 additions & 2 deletions)
@@ -46,8 +46,8 @@

conv_inputs = 1
conv_outputs = 256
-num_primary_units = 8
-primary_unit_size = 32 * 6 * 6  # fixme get from conv2d
+num_primary_units = 32 * 6 * 6
+primary_unit_size = 8  # fixme get from conv2d
Owner:
i think you're right, i did this part wrong. let me re-read the paper and consider your change carefully.
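(A sketch of the arithmetic behind this swap, following the paper's PrimaryCaps layout; the 6x6 spatial size follows from a 28x28 input through two 9x9 convs with strides 1 and 2.)

```python
# conv1:        28 - 9 + 1       = 20  -> 20x20 feature maps
# primary conv: (20 - 9) // 2 + 1 = 6  -> 6x6 feature maps
channels, height, width = 32, 6, 6

num_primary_units = channels * height * width  # 1152 primary capsules...
primary_unit_size = 8                          # ...each an 8-dim vector

print(num_primary_units)  # 1152
```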

output_unit_size = 16

network = CapsuleNetwork(image_width=28,