From d08b556777a2621ad41bcea08920a10a93c2078c Mon Sep 17 00:00:00 2001 From: jiawei wang Date: Tue, 7 Nov 2017 12:01:56 +0800 Subject: [PATCH 1/3] # bugfix: Use correct digit capsule to reconstruct input image rather than the longest digit capsule. --- capsule_network.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/capsule_network.py b/capsule_network.py index 20a1c32..1e69ba5 100644 --- a/capsule_network.py +++ b/capsule_network.py @@ -50,9 +50,11 @@ def __init__(self, use_routing=True) reconstruction_size = image_width * image_height * image_channels - self.reconstruct0 = nn.Linear(num_output_units*output_unit_size, (reconstruction_size * 2) / 3) - self.reconstruct1 = nn.Linear((reconstruction_size * 2) / 3, (reconstruction_size * 3) / 2) - self.reconstruct2 = nn.Linear((reconstruction_size * 3) / 2, reconstruction_size) + # self.reconstruct0 = nn.Linear(num_output_units*output_unit_size, (reconstruction_size * 2) / 3) + # self.reconstruct1 = nn.Linear((reconstruction_size * 2) / 3, (reconstruction_size * 3) / 2) + self.reconstruct0 = nn.Linear(output_unit_size, 512) + self.reconstruct1 = nn.Linear(512, 1024) + self.reconstruct2 = nn.Linear(1024, reconstruction_size) self.relu = nn.ReLU(inplace=True) self.sigmoid = nn.Sigmoid() @@ -61,7 +63,7 @@ def forward(self, x): return self.digits(self.primary(self.conv1(x))) def loss(self, images, input, target, size_average=True): - return self.margin_loss(input, target, size_average) + self.reconstruction_loss(images, input, size_average) + return self.margin_loss(input, target, size_average) + self.reconstruction_loss(images, input, target, size_average) def margin_loss(self, input, target, size_average=True): batch_size = input.size(0) @@ -87,17 +89,15 @@ def margin_loss(self, input, target, size_average=True): return L_c - def reconstruction_loss(self, images, input, size_average=True): - # Get the lengths of capsule outputs. 
- v_mag = torch.sqrt((input**2).sum(dim=2)) + def reconstruction_loss(self, images, input, target, size_average=True): + # Use the target to reconstruct input image. - # Get index of longest capsule output. - _, v_max_index = v_mag.max(dim=1) - v_max_index = v_max_index.data - - # Use just the winning capsule's representation (and zeros for other capsules) to reconstruct input image. - masked = Variable(torch.zeros(input.size())).cuda() - masked[:,v_max_index] = input[:,v_max_index] + # (batch_size, num_output_units, output_unit_size) + input = torch.squeeze(input, 3) + # (batch_size, num_output_units, 1) + target = torch.unsqueeze(target, 2) + # (batch_size, output_unit_size, 1) + masked = torch.matmul(input.transpose(2,1), target) # Reconstruct input image. masked = masked.view(input.size(0), -1) From 1a9e38fc9fa8be31b70d83747f1e464b54d4f5ef Mon Sep 17 00:00:00 2001 From: jiawei wang Date: Tue, 7 Nov 2017 21:42:58 +0800 Subject: [PATCH 2/3] # rename the no_routing part to make it more readable. --- capsule_layer.py | 11 ++++++----- capsule_network.py | 4 ++-- main.py | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/capsule_layer.py b/capsule_layer.py index 40ac382..55c30eb 100644 --- a/capsule_layer.py +++ b/capsule_layer.py @@ -31,6 +31,7 @@ def __init__(self, in_units, in_channels, num_units, unit_size, use_routing): self.in_units = in_units self.in_channels = in_channels self.num_units = num_units + self.unit_size = unit_size self.use_routing = use_routing if self.use_routing: @@ -45,7 +46,7 @@ def create_conv_unit(unit_idx): unit = ConvUnit(in_channels=in_channels) self.add_module("unit_" + str(unit_idx), unit) return unit - self.units = [create_conv_unit(i) for i in range(self.num_units)] + self.units = [create_conv_unit(i) for i in range(self.unit_size)] @staticmethod def squash(s): @@ -64,13 +65,13 @@ def forward(self, x): def no_routing(self, x): # Get output for each unit. # Each will be (batch, channels, height, width). 
- u = [self.units[i](x) for i in range(self.num_units)] + u = [self.units[i](x) for i in range(self.unit_size)] - # Stack all unit outputs (batch, unit, channels, height, width). + # Stack all unit outputs (batch, unit_size, channels, height, width). u = torch.stack(u, dim=1) - # Flatten to (batch, unit, output). - u = u.view(x.size(0), self.num_units, -1) + # Flatten to (batch, unit_size, output). + u = u.view(x.size(0), self.unit_size, -1) # Return squashed outputs. return CapsuleLayer.squash(u) diff --git a/capsule_network.py b/capsule_network.py index 1e69ba5..8bfa574 100644 --- a/capsule_network.py +++ b/capsule_network.py @@ -43,8 +43,8 @@ def __init__(self, unit_size=primary_unit_size, use_routing=False) - self.digits = CapsuleLayer(in_units=num_primary_units, - in_channels=primary_unit_size, + self.digits = CapsuleLayer(in_units=primary_unit_size, + in_channels=num_primary_units, num_units=num_output_units, unit_size=output_unit_size, use_routing=True) diff --git a/main.py b/main.py index ea43cbd..b3637d3 100644 --- a/main.py +++ b/main.py @@ -46,8 +46,8 @@ conv_inputs = 1 conv_outputs = 256 -num_primary_units = 8 -primary_unit_size = 32 * 6 * 6 # fixme get from conv2d +num_primary_units = 32 * 6 * 6 +primary_unit_size = 8 # fixme get from conv2d output_unit_size = 16 network = CapsuleNetwork(image_width=28, From d0e88771806b74dbf48067b8173727ccd97d3bb2 Mon Sep 17 00:00:00 2001 From: jiawei wang Date: Wed, 8 Nov 2017 18:19:34 +0800 Subject: [PATCH 3/3] #fixbug: squashing over the wrong dimension gave rise to the bad result. 
--- capsule_conv_layer.py | 2 +- capsule_layer.py | 12 +++++++----- capsule_network.py | 8 ++++---- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/capsule_conv_layer.py b/capsule_conv_layer.py index 7e94a16..53e810d 100644 --- a/capsule_conv_layer.py +++ b/capsule_conv_layer.py @@ -19,7 +19,7 @@ def __init__(self, in_channels, out_channels): out_channels=out_channels, kernel_size=9, # fixme constant stride=1, - bias=False) + bias=True) self.relu = nn.ReLU(inplace=True) diff --git a/capsule_layer.py b/capsule_layer.py index 55c30eb..34ceb97 100644 --- a/capsule_layer.py +++ b/capsule_layer.py @@ -19,7 +19,7 @@ def __init__(self, in_channels): out_channels=32, # fixme constant kernel_size=9, # fixme constant stride=2, - bias=False) # fixme constant + bias=True) # fixme constant def forward(self, x): return self.conv0(x) @@ -49,9 +49,9 @@ def create_conv_unit(unit_idx): self.units = [create_conv_unit(i) for i in range(self.unit_size)] @staticmethod - def squash(s): + def squash(s, dim=2): # This is equation 1 from the paper. - mag_sq = torch.sum(s**2, dim=2, keepdim=True) + mag_sq = torch.sum(s**2, dim, keepdim=True) mag = torch.sqrt(mag_sq) s = (mag_sq / (1.0 + mag_sq)) * (s / mag) return s @@ -74,7 +74,7 @@ def no_routing(self, x): u = u.view(x.size(0), self.unit_size, -1) # Return squashed outputs. - return CapsuleLayer.squash(u) + return CapsuleLayer.squash(u, dim=1) def routing(self, x): batch_size = x.size(0) @@ -85,7 +85,7 @@ def routing(self, x): # (batch, features, in_units) -> (batch, features, num_units, in_units, 1) x = torch.stack([x] * self.num_units, dim=2).unsqueeze(4) - # (batch, features, in_units, unit_size, num_units) + # (batch, features, num_units, unit_size, in_units) W = torch.cat([self.W] * batch_size, dim=0) # Transform inputs by weight matrix. @@ -100,6 +100,7 @@ def routing(self, x): for iteration in range(num_iterations): # Convert routing logits to softmax. 
# (batch, features, num_units, 1, 1) + # fixme: this seems to be applied over the wrong dimension, but the network can't be trained if it is changed. weird. c_ij = F.softmax(b_ij) c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4) @@ -108,6 +109,7 @@ def routing(self, x): s_j = (c_ij * u_hat).sum(dim=1, keepdim=True) # (batch_size, 1, num_units, unit_size, 1) + # fixme: this seems to be applied over the wrong dimension, but the network can't be trained if it is changed. weird. v_j = CapsuleLayer.squash(s_j) # (batch_size, features, num_units, unit_size, 1) diff --git a/capsule_network.py b/capsule_network.py index 8bfa574..930e587 100644 --- a/capsule_network.py +++ b/capsule_network.py @@ -75,8 +75,8 @@ def margin_loss(self, input, target, size_average=True): zero = Variable(torch.zeros(1)).cuda() m_plus = 0.9 m_minus = 0.1 - max_l = torch.max(m_plus - v_mag, zero).view(batch_size, -1) - max_r = torch.max(v_mag - m_minus, zero).view(batch_size, -1) + max_l = torch.max(m_plus - v_mag, zero).view(batch_size, -1) ** 2 + max_r = torch.max(v_mag - m_minus, zero).view(batch_size, -1) ** 2 # This is equation 4 from the paper. loss_lambda = 0.5 @@ -118,11 +118,11 @@ def reconstruction_loss(self, images, input, target, size_average=True): vutils.save_image(output_image, "reconstruction.png") self.reconstructed_image_count += 1 - # The reconstruction loss is the mean squared difference between the input image and reconstructed image. + # The reconstruction loss is the sum squared difference between the input image and reconstructed image. # Multiplied by a small number so it doesn't dominate the margin (class) loss. error = (output - images).view(output.size(0), -1) error = error**2 - error = torch.mean(error, dim=1) * 0.0005 + error = torch.sum(error, dim=1) * 0.0005 # Average over batch if size_average: