Merge pull request #2 from WillSuen/master
update LostGAN-v2
iVMCL authored Jun 5, 2020
2 parents 69e5ad3 + 70f5d4f commit e8daaa9
Showing 17 changed files with 1,685 additions and 1,140 deletions.
81 changes: 39 additions & 42 deletions README.md
@@ -1,42 +1,39 @@
# LostGANs: Image Synthesis From Reconfigurable Layout and Style
This is the implementation of our ICCV19 paper [**Image Synthesis From Reconfigurable Layout and Style**](https://arxiv.org/abs/1908.07500)

## Network Structure
![network_structure](./figures/network_structure.png)

## Installation
Check [INSTALL.md](INSTALL.md) for installation instructions.
#### 1. Download pretrained model
Download pretrained models to `pretrained_model/`
* Pretrained model on [COCO](https://drive.google.com/open?id=1WO6fLZqJeTUnmJTmieUopKLj9KGBhGd6)
* Pretrained model on [VG](https://drive.google.com/open?id=1A_gP_WwZWonlXJhwcdBDgHuVaGSiVjMO)

#### 2. Train models
```
python train.py --dataset coco --out_path outputs/
```

#### 3. Run pretrained model
```
python test.py --dataset coco --model_path pretrained_model/G_coco.pth --sample_path samples/coco/
```


## Results
###### Multiple samples generated from same layout
![various_out](./figures/various_outs.png)
###### Generation results from adding new objects or changing the spatial position of objects
![add_obj](./figures/add_obj.png)
###### Linear interpolation of instance style
![style_morph](./figures/style_morph.png)
###### Synthesized images and learned masks for given layout
![mask](./figures/mask.png)

## Contact
Please feel free to report issues and any related problems to Wei Sun (wsun12 at ncsu dot edu) and Tianfu Wu (tianfu_wu at ncsu dot edu).


## Reference
* Synchronized-BatchNorm-PyTorch: [https://github.com/vacancy/Synchronized-BatchNorm-PyTorch](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch)
* Image Generation from Scene Graphs: [https://github.com/google/sg2im](https://github.com/google/sg2im)
* Faster R-CNN and Mask R-CNN in PyTorch 1.0: [https://github.com/facebookresearch/maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark)
# LostGANs: Image Synthesis From Reconfigurable Layout and Style
This is the implementation of our papers [**Image Synthesis From Reconfigurable Layout and Style**](https://arxiv.org/abs/1908.07500) and [**Learning Layout and Style Reconfigurable GANs for Controllable Image Synthesis**](https://arxiv.org/abs/2003.11571).


## Network Structure
![network_structure](./figures/network_structure.png)

## Installation
Check [INSTALL.md](INSTALL.md) for installation instructions.
#### 1. Download pretrained model
Download pretrained [models](https://drive.google.com/drive/folders/1peI9d4PI7jJZJzFTcr-5mwZqnrNsX_3p?usp=sharing) to `pretrained_model/`

#### 2. Train models
```
python train.py --dataset coco --out_path outputs/
```

#### 3. Run pretrained model
```
python test.py --dataset coco --model_path pretrained_model/G_coco.pth --sample_path samples/coco/
```


## Results
###### Compare different models
![compare](./figures/generated_images.png)
###### Multiple samples generated from same layout
![various_out](./figures/various_outs.png)
###### Synthesized images and learned masks for given layout
![mask](./figures/mask.png)

## Contact
Please feel free to report issues and any related problems to Wei Sun (wsun12 at ncsu dot edu) and Tianfu Wu (tianfu_wu at ncsu dot edu).


## Reference
* Synchronized-BatchNorm-PyTorch: [https://github.com/vacancy/Synchronized-BatchNorm-PyTorch](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch)
* Image Generation from Scene Graphs: [https://github.com/google/sg2im](https://github.com/google/sg2im)
* Faster R-CNN and Mask R-CNN in PyTorch 1.0: [https://github.com/facebookresearch/maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark)
1 change: 1 addition & 0 deletions data/vg_splits.json

Large diffs are not rendered by default.

Binary file removed figures/add_obj.png
Binary file not shown.
Binary file added figures/generated_images.png
Binary file modified figures/mask.png
Binary file modified figures/network_structure.png
Binary file removed figures/style_morph.png
Binary file not shown.
Binary file modified figures/various_outs.png
158 changes: 102 additions & 56 deletions model/mask_regression.py
@@ -1,56 +1,102 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.bilinear import *
from .sync_batchnorm import SynchronizedBatchNorm2d
from .norm_module import AdaptiveBatchNorm2
import pickle
import numpy as np
from torch.nn import Parameter


class MaskRegressNet(nn.Module):
    def __init__(self, obj_feat=128, mask_size=16, map_size=64):
        super(MaskRegressNet, self).__init__()
        self.mask_size = mask_size
        self.map_size = map_size

        self.fc = nn.utils.spectral_norm(nn.Linear(obj_feat, 128 * 4 * 4))
        conv1 = list()
        conv1.append(nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1)))
        conv1.append(SynchronizedBatchNorm2d(128))
        conv1.append(nn.ReLU())
        self.conv1 = nn.Sequential(*conv1)

        conv2 = list()
        conv2.append(nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1)))
        conv2.append(SynchronizedBatchNorm2d(128))
        conv2.append(nn.ReLU())
        self.conv2 = nn.Sequential(*conv2)

        conv3 = list()
        conv3.append(nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1)))
        conv3.append(SynchronizedBatchNorm2d(128))
        conv3.append(nn.ReLU())
        conv3.append(nn.utils.spectral_norm(nn.Conv2d(128, 1, 1, 1)))
        conv3.append(nn.Sigmoid())
        self.conv3 = nn.Sequential(*conv3)

    def forward(self, obj_feat, bbox):
        """
        :param obj_feat: (b*num_o, feat_dim)
        :param bbox: (b, num_o, 4)
        :return: bbmap: (b, num_o, map_size, map_size)
        """
        b, num_o, _ = bbox.size()
        obj_feat = obj_feat.view(b * num_o, -1)
        x = self.fc(obj_feat)
        x = self.conv1(x.view(b * num_o, 128, 4, 4))
        x = F.interpolate(x, size=8, mode='bilinear')
        x = self.conv2(x)
        x = F.interpolate(x, size=16, mode='bilinear')
        x = self.conv3(x)
        x = x.view(b, num_o, 16, 16)

        bbmap = masks_to_layout(bbox, x, self.map_size).view(b, num_o, self.map_size, self.map_size)
        return bbmap
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.bilinear import *
from .sync_batchnorm import SynchronizedBatchNorm2d
import pickle
import numpy as np
from torch.nn import Parameter


class MaskRegressNet(nn.Module):
    def __init__(self, obj_feat=128, mask_size=16, map_size=64):
        super(MaskRegressNet, self).__init__()
        self.mask_size = mask_size
        self.map_size = map_size

        self.fc = nn.utils.spectral_norm(nn.Linear(obj_feat, 128 * 4 * 4))
        conv1 = list()
        conv1.append(nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1)))
        conv1.append(SynchronizedBatchNorm2d(128))
        conv1.append(nn.ReLU())
        self.conv1 = nn.Sequential(*conv1)

        conv2 = list()
        conv2.append(nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1)))
        conv2.append(SynchronizedBatchNorm2d(128))
        conv2.append(nn.ReLU())
        self.conv2 = nn.Sequential(*conv2)

        conv3 = list()
        conv3.append(nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1)))
        conv3.append(SynchronizedBatchNorm2d(128))
        conv3.append(nn.ReLU())
        conv3.append(nn.utils.spectral_norm(nn.Conv2d(128, 1, 1, 1)))
        conv3.append(nn.Sigmoid())
        self.conv3 = nn.Sequential(*conv3)

    def forward(self, obj_feat, bbox):
        """
        :param obj_feat: (b*num_o, feat_dim)
        :param bbox: (b, num_o, 4)
        :return: bbmap: (b, num_o, map_size, map_size)
        """
        b, num_o, _ = bbox.size()
        obj_feat = obj_feat.view(b * num_o, -1)
        x = self.fc(obj_feat)
        x = self.conv1(x.view(b * num_o, 128, 4, 4))
        x = F.interpolate(x, size=8, mode='bilinear')
        x = self.conv2(x)
        x = F.interpolate(x, size=16, mode='bilinear')
        x = self.conv3(x)
        x = x.view(b, num_o, 16, 16)

        bbmap = masks_to_layout(bbox, x, self.map_size).view(b, num_o, self.map_size, self.map_size)
        return bbmap


class MaskRegressNetv2(nn.Module):
    def __init__(self, obj_feat=128, mask_size=16, map_size=64):
        super(MaskRegressNetv2, self).__init__()
        self.mask_size = mask_size
        self.map_size = map_size

        self.fc = nn.utils.spectral_norm(nn.Linear(obj_feat, 256 * 4 * 4))
        conv1 = list()
        conv1.append(nn.utils.spectral_norm(nn.Conv2d(256, 256, 3, 1, 1)))
        conv1.append(nn.InstanceNorm2d(256))
        conv1.append(nn.ReLU())
        self.conv1 = nn.Sequential(*conv1)

        conv2 = list()
        conv2.append(nn.utils.spectral_norm(nn.Conv2d(256, 256, 3, 1, 1)))
        conv2.append(nn.InstanceNorm2d(256))
        conv2.append(nn.ReLU())
        self.conv2 = nn.Sequential(*conv2)

        conv3 = list()
        conv3.append(nn.utils.spectral_norm(nn.Conv2d(256, 256, 3, 1, 1)))
        conv3.append(nn.InstanceNorm2d(256))
        conv3.append(nn.ReLU())
        conv3.append(nn.utils.spectral_norm(nn.Conv2d(256, 1, 1, 1)))
        conv3.append(nn.Sigmoid())
        self.conv3 = nn.Sequential(*conv3)

    def forward(self, obj_feat, bbox):
        """
        :param obj_feat: (b*num_o, feat_dim)
        :param bbox: (b, num_o, 4)
        :return: bbmap: (b, num_o, map_size, map_size)
        """
        b, num_o, _ = bbox.size()
        obj_feat = obj_feat.view(b * num_o, -1)
        x = self.fc(obj_feat)
        x = self.conv1(x.view(b * num_o, 256, 4, 4))
        x = F.interpolate(x, size=8, mode='bilinear')
        x = self.conv2(x)
        x = F.interpolate(x, size=16, mode='bilinear')
        x = self.conv3(x)
        x = x.view(b, num_o, 16, 16)

        bbmap = masks_to_layout(bbox, x, self.map_size).view(b, num_o, self.map_size, self.map_size)
        return bbmap
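
As a rough illustration of the shape flow in `MaskRegressNetv2.forward`, the sketch below follows the same steps: a linear projection to a 4x4 feature map, two bilinear upsamplings (to 8x8 and 16x16), and a sigmoid mask per object. It is a simplified stand-in, not the repository code. `TinyMaskHead` and `mask_head_demo` are hypothetical names, spectral normalization is omitted for brevity, `b` and `num_o` are passed in directly instead of being read from `bbox`, and the final `masks_to_layout` step (from `utils.bilinear`, not shown in this diff) is left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyMaskHead(nn.Module):
    """Hypothetical stand-in for the v2 mask head (assumption: no spectral norm,
    no pasting of masks into the layout)."""

    def __init__(self, obj_feat=128, ch=256):
        super().__init__()
        self.ch = ch
        self.fc = nn.Linear(obj_feat, ch * 4 * 4)

        def block():
            # conv -> instance norm -> ReLU, as in the v2 class above
            return nn.Sequential(nn.Conv2d(ch, ch, 3, 1, 1),
                                 nn.InstanceNorm2d(ch),
                                 nn.ReLU())

        self.conv1 = block()
        self.conv2 = block()
        self.conv3 = nn.Sequential(block(), nn.Conv2d(ch, 1, 1, 1), nn.Sigmoid())

    def forward(self, obj_feat, b, num_o):
        x = self.fc(obj_feat.view(b * num_o, -1))          # (b*num_o, ch*16)
        x = self.conv1(x.view(b * num_o, self.ch, 4, 4))   # (b*num_o, ch, 4, 4)
        x = F.interpolate(x, size=8, mode='bilinear', align_corners=False)
        x = self.conv2(x)                                  # (b*num_o, ch, 8, 8)
        x = F.interpolate(x, size=16, mode='bilinear', align_corners=False)
        x = self.conv3(x)                                  # (b*num_o, 1, 16, 16)
        return x.view(b, num_o, 16, 16)


def mask_head_demo(b=2, num_o=3):
    torch.manual_seed(0)
    head = TinyMaskHead()
    return head(torch.randn(b * num_o, 128), b, num_o)
```

Calling `mask_head_demo()` returns one 16x16 mask per object, with values bounded by the sigmoid; in the repository these masks are then placed into the `map_size` x `map_size` layout according to `bbox`.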
21 changes: 10 additions & 11 deletions model/norm_module.py
@@ -149,18 +149,16 @@ def __repr__(self):
 from .sync_batchnorm import SynchronizedBatchNorm2d


-class SpatialAdaptiveSynBatchNorm2d(nn.BatchNorm2d):
-    def __init__(self, num_features, num_w=512, eps=1e-5, momentum=0.1, affine=False,
+class SpatialAdaptiveSynBatchNorm2d(nn.Module):
+    def __init__(self, num_features, num_w=512, batchnorm_func=SynchronizedBatchNorm2d, eps=1e-5, momentum=0.1, affine=False,
                  track_running_stats=True):
-        super(SpatialAdaptiveSynBatchNorm2d, self).__init__(
-            num_features, eps, momentum, affine, track_running_stats
-        )
+        super(SpatialAdaptiveSynBatchNorm2d, self).__init__()
         # projection layer
+        self.num_features = num_features
         self.weight_proj = nn.utils.spectral_norm(nn.Linear(num_w, num_features))
         self.bias_proj = nn.utils.spectral_norm(nn.Linear(num_w, num_features))
-        # self.weight_proj = nn.Linear(num_w, num_features)
-        # self.bias_proj = nn.Linear(num_w, num_features)
-        self.batch_norm2d = SynchronizedBatchNorm2d(num_features, eps=self.eps, affine=False)
+        self.batch_norm2d = batchnorm_func(num_features, eps=eps, momentum=momentum,
+                                           affine=affine)

     def forward(self, x, vector, bbox):
         """
@@ -169,12 +167,13 @@ def forward(self, x, vector, bbox):
         :param bbox: bbox map (b, o, h, w)
         :return:
         """
-        self._check_input_dim(x)
+        # self._check_input_dim(x)
         output = self.batch_norm2d(x)

-        b, o, _, _ = bbox.size()
+        b, o, bh, bw = bbox.size()
         _, _, h, w = x.size()
-        bbox = F.interpolate(bbox, size=(h, w), mode='bilinear')
+        if bh != h or bw != w:
+            bbox = F.interpolate(bbox, size=(h, w), mode='bilinear')
         # calculate weight and bias
         weight, bias = self.weight_proj(vector), self.bias_proj(vector)
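
The two changes visible in this hunk, injecting the normalization class through `batchnorm_func` and resizing `bbox` only when its spatial size differs from `x`, can be sketched in isolation. This is a minimal sketch under stated assumptions, not the repository module: `PluggableNorm`, `match_bbox_to_features`, and `pluggable_norm_demo` are hypothetical names, `nn.BatchNorm2d` stands in for `SynchronizedBatchNorm2d`, and the per-object weight/bias modulation is omitted because the rest of `forward` is not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def match_bbox_to_features(bbox, x):
    """Resize bbox maps (b, o, bh, bw) to x's spatial size only when they differ."""
    _, _, bh, bw = bbox.size()
    _, _, h, w = x.size()
    if bh != h or bw != w:
        bbox = F.interpolate(bbox, size=(h, w), mode='bilinear', align_corners=False)
    return bbox


class PluggableNorm(nn.Module):
    """Hypothetical wrapper: the normalization class is passed in, mirroring
    the batchnorm_func argument in the diff (assumption: nn.BatchNorm2d default)."""

    def __init__(self, num_features, batchnorm_func=nn.BatchNorm2d,
                 eps=1e-5, momentum=0.1, affine=False):
        super().__init__()
        self.batch_norm2d = batchnorm_func(num_features, eps=eps,
                                           momentum=momentum, affine=affine)

    def forward(self, x, bbox):
        # Normalize features and return the bbox maps at the matching resolution.
        return self.batch_norm2d(x), match_bbox_to_features(bbox, x)


def pluggable_norm_demo():
    torch.manual_seed(0)
    x = torch.randn(2, 8, 32, 32)
    same = torch.rand(2, 3, 32, 32)    # already matches x, should be passed through
    larger = torch.rand(2, 3, 64, 64)  # needs resizing to 32x32
    norm = PluggableNorm(8)
    out, kept = norm(x, same)
    _, resized = norm(x, larger)
    return out, kept is same, tuple(resized.shape)
```

Skipping the interpolation when sizes already match avoids a redundant resampling pass, and composing the norm layer instead of subclassing `nn.BatchNorm2d` lets a caller swap in a different normalization class without touching this module.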
