Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

U-net implementation #247

Merged
merged 8 commits into from
Sep 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader
from pl_bolts.datamodules.dummy_dataset import DummyDetectionDataset
from pl_bolts.datamodules.dummy_dataset import DummyDataset, DummyDetectionDataset

try:
from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
Expand Down
1 change: 1 addition & 0 deletions pl_bolts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from pl_bolts.models.regression import LinearRegression, LogisticRegression
from pl_bolts.models.vision import PixelCNN
from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT
from pl_bolts.models.vision import UNet
1 change: 1 addition & 0 deletions pl_bolts/models/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from pl_bolts.models.vision.pixel_cnn import PixelCNN
from pl_bolts.models.vision.unet import UNet
129 changes: 129 additions & 0 deletions pl_bolts/models/vision/unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    """
    PyTorch implementation of `U-Net: Convolutional Networks for Biomedical Image Segmentation
    <https://arxiv.org/abs/1505.04597>`_

    Paper authors: Olaf Ronneberger, Philipp Fischer, Thomas Brox

    Model implemented by:
        - `Annika Brundyn <https://github.com/annikabrundyn>`_
        - `Akshay Kulkarni <https://github.com/akshaykvnit>`_

    .. warning:: Work in progress. This implementation is still being verified.

    The network is stored as one flat ``nn.ModuleList``:
    ``[DoubleConv, Down x (num_layers - 1), Up x (num_layers - 1), Conv2d]``.

    Args:
        num_classes: Number of output classes required
        num_layers: Number of layers in each side of U-net (default 5)
        features_start: Number of features in first layer (default 64)
        bilinear (bool): Whether to use bilinear interpolation or transposed convolutions (default) for upsampling.
        input_channels: Number of channels in the input images (default 3)

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(
        self,
        num_classes: int,
        num_layers: int = 5,
        features_start: int = 64,
        bilinear: bool = False,
        input_channels: int = 3,
    ):
        # Fail early: with fewer than one layer the flat layer list built
        # below no longer matches the slicing done in ``forward``.
        if num_layers < 1:
            raise ValueError(f"num_layers = {num_layers}, expected: num_layers > 0")

        super().__init__()
        self.num_layers = num_layers

        layers = [DoubleConv(input_channels, features_start)]

        # Contracting path: every Down block doubles the feature count.
        feats = features_start
        for _ in range(num_layers - 1):
            layers.append(Down(feats, feats * 2))
            feats *= 2

        # Expanding path: every Up block halves the feature count again,
        # so ``feats`` ends back at ``features_start``.
        for _ in range(num_layers - 1):
            layers.append(Up(feats, feats // 2, bilinear))
            feats //= 2

        # 1x1 convolution mapping the final features to per-class outputs.
        layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))

        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        """Run ``x`` down the contracting path and back up, returning the output of the final 1x1 conv."""
        # ``xi`` collects the output of every encoder stage; these double as
        # the skip connections consumed on the way up.
        xi = [self.layers[0](x)]
        # Down path
        for layer in self.layers[1:self.num_layers]:
            xi.append(layer(xi[-1]))
        # Up path: xi[-1] is overwritten in place, so the list length stays
        # fixed and ``-2 - i`` walks backwards through the encoder outputs.
        for i, layer in enumerate(self.layers[self.num_layers:-1]):
            xi[-1] = layer(xi[-1], xi[-2 - i])
        return self.layers[-1](xi[-1])


class DoubleConv(nn.Module):
    """
    Two back-to-back stages of [ Conv2d => BatchNorm2d => ReLU ].

    Both convolutions use a 3x3 kernel with padding 1. The first stage maps
    ``in_ch`` channels to ``out_ch``; the second keeps ``out_ch`` channels.

    Args:
        in_ch: Number of channels of the incoming feature map
        out_ch: Number of channels produced by both convolutions
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        # Only the first stage sees ``in_ch``; the second consumes ``out_ch``.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        out = self.net(x)
        return out


class Down(nn.Module):
    """
    One step of the contracting path: 2x2 max-pooling (stride 2) followed
    by a DoubleConv block.

    Args:
        in_ch: Number of channels of the incoming feature map
        out_ch: Number of channels produced by the DoubleConv block
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        convs = DoubleConv(in_ch, out_ch)
        self.net = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and run the result through the DoubleConv block."""
        out = self.net(x)
        return out


class Up(nn.Module):
    """
    One step of the expanding path: upsample the incoming feature map,
    concatenate it with the matching feature map from the contracting path,
    then apply a DoubleConv block.

    Args:
        in_ch: Number of channels of the feature map to upsample
        out_ch: Number of channels produced by the DoubleConv block
        bilinear: If True, upsample with bilinear interpolation followed by
            a 1x1 convolution; otherwise use a transposed convolution (default)
    """

    def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
        super().__init__()
        # Either upsampling variant halves the channel count.
        half_ch = in_ch // 2
        if not bilinear:
            self.upsample = nn.ConvTranspose2d(in_ch, half_ch, kernel_size=2, stride=2)
        else:
            self.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(in_ch, half_ch, kernel_size=1),
            )

        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the spatial size of ``x2``, concatenate and convolve."""
        upsampled = self.upsample(x1)

        # Spatial size gap between the skip connection and the upsampled map.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]

        # F.pad takes the last dimension first: (left, right, top, bottom).
        # Any odd remainder goes to the right / bottom side.
        left = gap_w // 2
        top = gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Concatenate along the channels axis, skip connection first.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
11 changes: 9 additions & 2 deletions tests/models/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import torch

from pl_bolts.datamodules import MNISTDataModule, FashionMNISTDataModule
from pl_bolts.models import GPT2, ImageGPT

from pl_bolts.models import GPT2, ImageGPT, UNet

def test_igpt(tmpdir):
pl.seed_everything(0)
Expand Down Expand Up @@ -47,3 +46,11 @@ def test_gpt2(tmpdir):
num_classes=10,
)
model(x)


def test_unet(tmpdir):
    """Smoke test: a default UNet keeps the batch and spatial size and emits one channel per class."""
    batch = torch.rand(10, 3, 28, 28)
    net = UNet(num_classes=2)
    prediction = net(batch)
    expected_shape = torch.Size([10, 2, 28, 28])
    assert prediction.shape == expected_shape