bugfix for the reconstruction part, the squash function, and the margin & reconstruction loss functions. #4
base: master
Changes from all commits
d08b556
1a9e38f
d0e8877
7b4aabd
```diff
@@ -43,16 +43,18 @@ def __init__(self,
                                     unit_size=primary_unit_size,
                                     use_routing=False)

-        self.digits = CapsuleLayer(in_units=num_primary_units,
-                                   in_channels=primary_unit_size,
+        self.digits = CapsuleLayer(in_units=primary_unit_size,
+                                   in_channels=num_primary_units,
                                    num_units=num_output_units,
                                    unit_size=output_unit_size,
                                    use_routing=True)

         reconstruction_size = image_width * image_height * image_channels
-        self.reconstruct0 = nn.Linear(num_output_units*output_unit_size, (reconstruction_size * 2) / 3)
-        self.reconstruct1 = nn.Linear((reconstruction_size * 2) / 3, (reconstruction_size * 3) / 2)
-        self.reconstruct2 = nn.Linear((reconstruction_size * 3) / 2, reconstruction_size)
+        # self.reconstruct0 = nn.Linear(num_output_units*output_unit_size, (reconstruction_size * 2) / 3)
+        # self.reconstruct1 = nn.Linear((reconstruction_size * 2) / 3, (reconstruction_size * 3) / 2)
+        self.reconstruct0 = nn.Linear(output_unit_size, 512)
+        self.reconstruct1 = nn.Linear(512, 1024)
+        self.reconstruct2 = nn.Linear(1024, reconstruction_size)

         self.relu = nn.ReLU(inplace=True)
         self.sigmoid = nn.Sigmoid()
```

Review comment (on the reconstruction layers):

> are you sure this is correct? the paper (and other implementations I've seen) seems to say that the reconstruction input is all capsules, but with the inactive capsules masked out.

Author reply:

> yes, you're right. I can now get 98.9% in the first epoch after fixing a few bugs. But I still wonder why softmaxing the wrong dim and squashing the wrong dim in the routing part can still give such a good result.
```diff
@@ -61,7 +63,7 @@ def forward(self, x):
         return self.digits(self.primary(self.conv1(x)))

     def loss(self, images, input, target, size_average=True):
-        return self.margin_loss(input, target, size_average) + self.reconstruction_loss(images, input, size_average)
+        return self.margin_loss(input, target, size_average) + self.reconstruction_loss(images, input, target, size_average)

     def margin_loss(self, input, target, size_average=True):
         batch_size = input.size(0)
```
@@ -73,8 +75,8 @@ def margin_loss(self, input, target, size_average=True): | |
zero = Variable(torch.zeros(1)).cuda() | ||
m_plus = 0.9 | ||
m_minus = 0.1 | ||
max_l = torch.max(m_plus - v_mag, zero).view(batch_size, -1) | ||
max_r = torch.max(v_mag - m_minus, zero).view(batch_size, -1) | ||
max_l = torch.max(m_plus - v_mag, zero).view(batch_size, -1) ** 2 | ||
max_r = torch.max(v_mag - m_minus, zero).view(batch_size, -1) ** 2 | ||
|
||
# This is equation 4 from the paper. | ||
loss_lambda = 0.5 | ||
|
```diff
@@ -87,29 +89,14 @@ def margin_loss(self, input, target, size_average=True):

         return L_c

-    def reconstruction_loss(self, images, input, size_average=True):
-        # Get the lengths of capsule outputs.
-        v_mag = torch.sqrt((input**2).sum(dim=2))
-
-        # Get index of longest capsule output.
-        _, v_max_index = v_mag.max(dim=1)
-        v_max_index = v_max_index.data
-
-        # Use just the winning capsule's representation (and zeros for other capsules) to reconstruct input image.
-        batch_size = input.size(0)
-        all_masked = [None] * batch_size
-        for batch_idx in range(batch_size):
-            # Get one sample from the batch.
-            input_batch = input[batch_idx]
-
-            # Copy only the maximum capsule index from this batch sample.
-            # This masks out (leaves as zero) the other capsules in this sample.
-            batch_masked = Variable(torch.zeros(input_batch.size())).cuda()
-            batch_masked[v_max_index[batch_idx]] = input_batch[v_max_index[batch_idx]]
-            all_masked[batch_idx] = batch_masked
-
-        # Stack masked capsules over the batch dimension.
-        masked = torch.stack(all_masked, dim=0)
+    def reconstruction_loss(self, images, input, target, size_average=True):
+        # Use the target to reconstruct input image.
+        # (batch_size, num_output_units, output_unit_size)
+        input = torch.squeeze(input, 3)
+        # (batch_size, num_output_units, 1)
+        target = torch.unsqueeze(target, 2)
+        # (batch_size, output_unit_size, 1)
+        masked = torch.matmul(input.transpose(2,1), target)
```
Review comment (on the matmul masking):

> hmm, I don't quite understand this part. what does `masked` look like for a typical sample after this matmul? also see my earlier comment; I thought the input was all capsules with all but one masked out, otherwise the reconstruction would be confused by different classes of digits.

Author reply:

> First, there is no doubt that we should use the target for the reconstruction. But my implementation may be wrong, now that I rethink this. I thought the reconstruction network only needed 16 values to reconstruct a digit, but each digit capsule represents one entity (digit), so we should use all 160 values for the reconstruction. This also explains how the original paper reconstructs two overlapped MNIST digits at a time (although that could also be done by finding the two longest digit capsules, reconstructing twice, and summing, which is not as neat). So yes, I think you're right. :)
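To make the discussion concrete, here is a small shape experiment (hypothetical dimensions; `caps` stands in for the digit-capsule output). Multiplying the transposed capsule tensor by a one-hot target selects only the target capsule's 16-vector, whereas the paper-style masking keeps all capsules and zeroes the non-target ones:

```python
import torch

batch, num_caps, caps_dim = 1, 10, 16
caps = torch.arange(batch * num_caps * caps_dim, dtype=torch.float32).view(batch, num_caps, caps_dim)
target = torch.zeros(batch, num_caps)
target[0, 3] = 1.0  # hypothetical label: digit 3

# PR's approach: matmul with the one-hot target -> (batch, caps_dim, 1).
# Only the 16 values of the target capsule survive; the other capsules are dropped.
selected = torch.matmul(caps.transpose(2, 1), target.unsqueeze(2))

# Paper-style masking: keep all capsules, zero the non-target rows -> (batch, num_caps, caps_dim).
# Flattened, this gives the full 160-value reconstruction input the reviewer describes.
masked = caps * target.unsqueeze(2)
```

The matmul version therefore feeds a 16-value vector into `reconstruct0` (hence `nn.Linear(output_unit_size, 512)` in this PR), while the masked version would feed 160 values with 144 of them zero.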
```diff
         # Reconstruct input image.
         masked = masked.view(input.size(0), -1)
```
```diff
@@ -130,11 +117,11 @@ def reconstruction_loss(self, images, input, size_average=True):
             vutils.save_image(output_image, "reconstruction.png")
             self.reconstructed_image_count += 1

-        # The reconstruction loss is the mean squared difference between the input image and reconstructed image.
+        # The reconstruction loss is the sum squared difference between the input image and reconstructed image.
         # Multiplied by a small number so it doesn't dominate the margin (class) loss.
         error = (output - images).view(output.size(0), -1)
         error = error**2
-        error = torch.mean(error, dim=1) * 0.0005
+        error = torch.sum(error, dim=1) * 0.0005

         # Average over batch
         if size_average:
```

Review comment:

> good find!

Author reply:

> Thanks. Your code is very neat and was helpful to me.
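For intuition on the mean-to-sum change: on 28x28 MNIST the sum over pixels is 784 times the mean, so with the same 0.0005 scale the reconstruction term now carries 784 times more weight against the margin loss. A toy check with a hypothetical constant per-pixel error:

```python
import torch

# Two hypothetical flattened 28x28 images with a uniform reconstruction error of 0.1 per pixel.
images = torch.zeros(2, 784)
output = torch.full((2, 784), 0.1)

error = (output - images) ** 2
mean_loss = torch.mean(error, dim=1) * 0.0005  # old scaling
sum_loss = torch.sum(error, dim=1) * 0.0005    # PR's scaling, as in the paper
```

The 0.0005 factor in the paper is stated relative to the summed squared error, which is why the `torch.sum` variant matches the intended loss balance.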
The diff also touches a second file:
```diff
@@ -46,8 +46,8 @@

 conv_inputs = 1
 conv_outputs = 256
-num_primary_units = 8
-primary_unit_size = 32 * 6 * 6  # fixme get from conv2d
+num_primary_units = 32 * 6 * 6
+primary_unit_size = 8  # fixme get from conv2d

 output_unit_size = 16

 network = CapsuleNetwork(image_width=28,
```

Review comment:

> I think you're right, I did this part wrong. let me re-read the paper and consider your change carefully.
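One way to read the swapped constants: the PrimaryCaps conv output (assumed here to be 256 channels on a 6x6 grid, as in the paper) can be viewed as 32 capsule types of 8 dimensions each, giving 32*6*6 = 1152 capsules of size 8 rather than 8 capsules of size 1152. A sketch of that interpretation (the repo's actual view/permute may differ):

```python
import torch

conv_out = torch.zeros(1, 256, 6, 6)   # assumed PrimaryCaps conv output: (batch, channels, H, W)

# 256 channels = 32 capsule types x 8 dimensions per capsule.
caps = conv_out.view(1, 32, 8, 6, 6)   # split channels into (types, dims)
caps = caps.permute(0, 1, 3, 4, 2)     # move the 8-D capsule axis last
caps = caps.reshape(1, 32 * 6 * 6, 8)  # -> 1152 capsules, each an 8-D vector

num_primary_units = caps.size(1)       # 1152, matching the corrected config
primary_unit_size = caps.size(2)       # 8
```

Under this reading, routing then maps 1152 eight-dimensional primary capsules to 10 sixteen-dimensional digit capsules, which is what the paper describes.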
Review comment:

> good find! I was experimenting with this myself locally but didn't want to commit either way. did you measure an accuracy improvement on MNIST with this?

Author reply:

> Oh, I just think that if there is no BN layer in the network, the bias should be added. This strategy works in traditional convnets, but I don't know whether it is also useful in CapsNet; I will test it once the model reaches better accuracy. It always gets stuck at about 99.4% accuracy in my local version.