Skip to content

Commit

Permalink
add attention unet config
Browse files Browse the repository at this point in the history
  • Loading branch information
aigcliu committed Dec 14, 2020
1 parent 363c0d1 commit 26b8799
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 23 deletions.
5 changes: 5 additions & 0 deletions configs/attention_unet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Attention U-Net: Learning Where to Look for the Pancreas

## Reference

> Oktay, Ozan, Jo Schlemper, Loic Le Folgoc, Matthew Lee, Mattias Heinrich, Kazunari Misawa, Kensaku Mori et al. "Attention u-net: Learning where to look for the pancreas." arXiv preprint arXiv:1804.03999 (2018).
15 changes: 15 additions & 0 deletions configs/attention_unet/attention_unet_cityscapes_1024x512_80k.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
_base_: '../_base_/cityscapes.yml'

batch_size: 2
iters: 80000

learning_rate:
value: 0.05
decay:
type: poly
power: 0.9
end_lr: 0.0

model:
type: AttentionUNet
pretrained: Null
46 changes: 23 additions & 23 deletions paddleseg/models/attention_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ class AttentionUNet(nn.Layer):
As mentioned in the original paper, author proposes a novel attention gate (AG)
that automatically learns to focus on target structures of varying shapes and sizes.
Models trained with AGs implicitly learn to suppress irrelevant regions in an input image while
highlighting salient features useful for a specific task
highlighting salient features useful for a specific task.
The original article refers to
Oktay, O, et, al. "Attention u-net: Learning where to look for the pancreas."
(https://arxiv.org/pdf/1804.03999.pdf).
Args:
num_classes (int): The unique number of target classes.
pretrained (str, optional): The path or url of pretrained model. Default: None.
Expand All @@ -41,22 +44,27 @@ def __init__(self, num_classes, pretrained=None):
self.encoder = Encoder(n_channels, [64, 128, 256, 512])
filters = np.array([64, 128, 256, 512, 1024])
self.up5 = UpConv(ch_in=filters[4], ch_out=filters[3])
self.att5 = AttentionBlock(F_g=filters[3], F_l=filters[3], F_out=filters[2])
self.att5 = AttentionBlock(
F_g=filters[3], F_l=filters[3], F_out=filters[2])
self.up_conv5 = ConvBlock(ch_in=filters[4], ch_out=filters[3])

self.up4 = UpConv(ch_in=filters[3], ch_out=filters[2])
self.att4 = AttentionBlock(F_g=filters[2], F_l=filters[2], F_out=filters[1])
self.att4 = AttentionBlock(
F_g=filters[2], F_l=filters[2], F_out=filters[1])
self.up_conv4 = ConvBlock(ch_in=filters[3], ch_out=filters[2])

self.up3 = UpConv(ch_in=filters[2], ch_out=filters[1])
self.att3 = AttentionBlock(F_g=filters[1], F_l=filters[1], F_out=filters[0])
self.att3 = AttentionBlock(
F_g=filters[1], F_l=filters[1], F_out=filters[0])
self.up_conv3 = ConvBlock(ch_in=filters[2], ch_out=filters[1])

self.up2 = UpConv(ch_in=filters[1], ch_out=filters[0])
self.att2 = AttentionBlock(F_g=filters[0], F_l=filters[0], F_out=filters[0] // 2)
self.att2 = AttentionBlock(
F_g=filters[0], F_l=filters[0], F_out=filters[0] // 2)
self.up_conv2 = ConvBlock(ch_in=filters[1], ch_out=filters[0])

self.conv_1x1 = nn.Conv2D(filters[0], num_classes, kernel_size=1, stride=1, padding=0)
self.conv_1x1 = nn.Conv2D(
filters[0], num_classes, kernel_size=1, stride=1, padding=0)
self.pretrained = pretrained
self.init_weight()

Expand Down Expand Up @@ -96,19 +104,15 @@ def __init__(self, F_g, F_l, F_out):
super().__init__()
self.W_g = nn.Sequential(
nn.Conv2D(F_g, F_out, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2D(F_out)
)
nn.BatchNorm2D(F_out))

self.W_x = nn.Sequential(
nn.Conv2D(F_l, F_out, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2D(F_out)
)
nn.BatchNorm2D(F_out))

self.psi = nn.Sequential(
nn.Conv2D(F_out, 1, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2D(1),
nn.Sigmoid()
)
nn.BatchNorm2D(1), nn.Sigmoid())

self.relu = nn.ReLU()

Expand All @@ -125,11 +129,9 @@ class UpConv(nn.Layer):
def __init__(self, ch_in, ch_out):
super().__init__()
self.up = nn.Sequential(
nn.Upsample(scale_factor=2),
nn.Upsample(scale_factor=2, mode="bilinear"),
nn.Conv2D(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2D(ch_out),
nn.ReLU()
)
nn.BatchNorm2D(ch_out), nn.ReLU())

def forward(self, x):
    """Apply the upsample-conv-BN-ReLU stack to ``x`` and return the result."""
    out = self.up(x)
    return out
Expand All @@ -139,7 +141,8 @@ class Encoder(nn.Layer):
def __init__(self, input_channels, filters):
super().__init__()
self.double_conv = nn.Sequential(
layers.ConvBNReLU(input_channels, 64, 3), layers.ConvBNReLU(64, 64, 3))
layers.ConvBNReLU(input_channels, 64, 3),
layers.ConvBNReLU(64, 64, 3))
down_channels = filters
self.down_sample_list = nn.LayerList([
self.down_sampling(channel, channel * 2)
Expand Down Expand Up @@ -167,12 +170,9 @@ def __init__(self, ch_in, ch_out):
super(ConvBlock, self).__init__()
self.conv = nn.Sequential(
nn.Conv2D(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2D(ch_out),
nn.ReLU(),
nn.BatchNorm2D(ch_out), nn.ReLU(),
nn.Conv2D(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2D(ch_out),
nn.ReLU()
)
nn.BatchNorm2D(ch_out), nn.ReLU())

def forward(self, x):
    """Run the double conv-BN-ReLU block on ``x`` and return its output."""
    out = self.conv(x)
    return out

0 comments on commit 26b8799

Please sign in to comment.