-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from WillSuen/master
update LostGAN-v2
- Loading branch information
Showing
17 changed files
with
1,685 additions
and
1,140 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,39 @@ | ||
# LostGANs: Image Synthesis From Reconfigurable Layout and Style | ||
This is implementation of our ICCV19 paper [**Image Synthesis From Reconfigurable Layout and Style**](https://arxiv.org/abs/1908.07500) | ||
|
||
## Network Structure | ||
![network_structure](./figures/network_structure.png) | ||
|
||
## Installation | ||
Check [INSTALL.md](INSTALL.md) for installation instructions. | ||
#### 1. Download pretrained model | ||
Download pretrained models to `pretrained_model/` | ||
* Pretrained model on [COCO](https://drive.google.com/open?id=1WO6fLZqJeTUnmJTmieUopKLj9KGBhGd6) | ||
* Pretrained model on [VG](https://drive.google.com/open?id=1A_gP_WwZWonlXJhwcdBDgHuVaGSiVjMO) | ||
|
||
#### 2. Train models | ||
``` | ||
python train.py --dataset coco --out_path outputs/ | ||
``` | ||
|
||
#### 3. Run pretrained model | ||
``` | ||
python test.py --dataset coco --model_path pretrained_model/G_coco.pth --sample_path samples/coco/ | ||
``` | ||
|
||
|
||
## Results | ||
###### Multiple samples generated from same layout | ||
![various_out](./figures/various_outs.png) | ||
###### Generation results by adding new objects or change spatial position of object | ||
![add_obj](./figures/add_obj.png) | ||
###### Linear interpolation of instance style | ||
![style_morph](./figures/style_morph.png) | ||
###### Synthesized images and learned masks for given layout | ||
![mask](./figures/mask.png) | ||
|
||
## Contact | ||
Please feel free to report issues and any related problems to Wei Sun (wsun12 at ncsu dot edu) and Tianfu Wu (tianfu_wu at ncsu dot edu). | ||
|
||
|
||
## Reference | ||
* Synchronized-BatchNorm-PyTorch: [https://github.com/vacancy/Synchronized-BatchNorm-PyTorch](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) | ||
* Image Generation from Scene Graphs: [https://github.com/google/sg2im](https://github.com/google/sg2im) | ||
* Faster R-CNN and Mask R-CNN in PyTorch 1.0: [https://github.com/facebookresearch/maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark) | ||
# LostGANs: Image Synthesis From Reconfigurable Layout and Style | ||
This is implementation of our paper [**Image Synthesis From Reconfigurable Layout and Style**](https://arxiv.org/abs/1908.07500) and [**Learning Layout and Style Reconfigurable GANs for Controllable Image Synthesis**](https://arxiv.org/abs/2003.11571) | ||
|
||
|
||
## Network Structure | ||
![network_structure](./figures/network_structure.png) | ||
|
||
## Installation | ||
Check [INSTALL.md](INSTALL.md) for installation instructions. | ||
#### 1. Download pretrained model | ||
Download pretrained [models](https://drive.google.com/drive/folders/1peI9d4PI7jJZJzFTcr-5mwZqnrNsX_3p?usp=sharing) to `pretrained_model/` | ||
|
||
#### 2. Train models | ||
``` | ||
python train.py --dataset coco --out_path outputs/ | ||
``` | ||
|
||
#### 3. Run pretrained model | ||
``` | ||
python test.py --dataset coco --model_path pretrained_model/G_coco.pth --sample_path samples/coco/ | ||
``` | ||
|
||
|
||
## Results | ||
###### Compare different models | ||
![compare](./figures/generated_images.png) | ||
###### Multiple samples generated from same layout | ||
![various_out](./figures/various_outs.png) | ||
###### Synthesized images and learned masks for given layout | ||
![mask](./figures/mask.png) | ||
|
||
## Contact | ||
Please feel free to report issues and any related problems to Wei Sun (wsun12 at ncsu dot edu) and Tianfu Wu (tianfu_wu at ncsu dot edu). | ||
|
||
|
||
## Reference | ||
* Synchronized-BatchNorm-PyTorch: [https://github.com/vacancy/Synchronized-BatchNorm-PyTorch](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) | ||
* Image Generation from Scene Graphs: [https://github.com/google/sg2im](https://github.com/google/sg2im) | ||
* Faster R-CNN and Mask R-CNN in PyTorch 1.0: [https://github.com/facebookresearch/maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark) |
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,102 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from utils.bilinear import * | ||
from .sync_batchnorm import SynchronizedBatchNorm2d | ||
from .norm_module import AdaptiveBatchNorm2 | ||
import pickle | ||
import numpy as np | ||
from torch.nn import Parameter | ||
|
||
|
||
class MaskRegressNet(nn.Module): | ||
def __init__(self, obj_feat=128, mask_size=16, map_size=64): | ||
super(MaskRegressNet, self).__init__() | ||
self.mask_size = mask_size | ||
self.map_size = map_size | ||
|
||
self.fc = nn.utils.spectral_norm(nn.Linear(obj_feat, 128 * 4 * 4)) | ||
conv1 = list() | ||
conv1.append(nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1))) | ||
conv1.append(SynchronizedBatchNorm2d(128)) | ||
conv1.append(nn.ReLU()) | ||
self.conv1 = nn.Sequential(*conv1) | ||
|
||
conv2 = list() | ||
conv2.append(nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1))) | ||
conv2.append(SynchronizedBatchNorm2d(128)) | ||
conv2.append(nn.ReLU()) | ||
self.conv2 = nn.Sequential(*conv2) | ||
|
||
conv3 = list() | ||
conv3.append(nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1))) | ||
conv3.append(SynchronizedBatchNorm2d(128)) | ||
conv3.append(nn.ReLU()) | ||
conv3.append(nn.utils.spectral_norm(nn.Conv2d(128, 1, 1, 1))) | ||
conv3.append(nn.Sigmoid()) | ||
self.conv3 = nn.Sequential(*conv3) | ||
|
||
def forward(self, obj_feat, bbox): | ||
""" | ||
:param obj_feat: (b*num_o, feat_dim) | ||
:param bbox: (b, num_o, 4) | ||
:return: bbmap: (b, num_o, map_size, map_size) | ||
""" | ||
b, num_o, _ = bbox.size() | ||
obj_feat = obj_feat.view(b * num_o, -1) | ||
x = self.fc(obj_feat) | ||
x = self.conv1(x.view(b * num_o, 128, 4, 4)) | ||
x = F.interpolate(x, size=8, mode='bilinear') | ||
x = self.conv2(x) | ||
x = F.interpolate(x, size=16, mode='bilinear') | ||
x = self.conv3(x) | ||
x = x.view(b, num_o, 16, 16) | ||
|
||
bbmap = masks_to_layout(bbox, x, self.map_size).view(b, num_o, self.map_size, self.map_size) | ||
return bbmap | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from utils.bilinear import * | ||
from .sync_batchnorm import SynchronizedBatchNorm2d | ||
import pickle | ||
import numpy as np | ||
from torch.nn import Parameter | ||
|
||
|
||
class MaskRegressNet(nn.Module): | ||
def __init__(self, obj_feat=128, mask_size=16, map_size=64): | ||
super(MaskRegressNet, self).__init__() | ||
self.mask_size = mask_size | ||
self.map_size = map_size | ||
|
||
self.fc = nn.utils.spectral_norm(nn.Linear(obj_feat, 128 * 4 * 4)) | ||
conv1 = list() | ||
conv1.append(nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1))) | ||
conv1.append(SynchronizedBatchNorm2d(128)) | ||
conv1.append(nn.ReLU()) | ||
self.conv1 = nn.Sequential(*conv1) | ||
|
||
conv2 = list() | ||
conv2.append(nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1))) | ||
conv2.append(SynchronizedBatchNorm2d(128)) | ||
conv2.append(nn.ReLU()) | ||
self.conv2 = nn.Sequential(*conv2) | ||
|
||
conv3 = list() | ||
conv3.append(nn.utils.spectral_norm(nn.Conv2d(128, 128, 3, 1, 1))) | ||
conv3.append(SynchronizedBatchNorm2d(128)) | ||
conv3.append(nn.ReLU()) | ||
conv3.append(nn.utils.spectral_norm(nn.Conv2d(128, 1, 1, 1))) | ||
conv3.append(nn.Sigmoid()) | ||
self.conv3 = nn.Sequential(*conv3) | ||
|
||
def forward(self, obj_feat, bbox): | ||
""" | ||
:param obj_feat: (b*num_o, feat_dim) | ||
:param bbox: (b, num_o, 4) | ||
:return: bbmap: (b, num_o, map_size, map_size) | ||
""" | ||
b, num_o, _ = bbox.size() | ||
obj_feat = obj_feat.view(b * num_o, -1) | ||
x = self.fc(obj_feat) | ||
x = self.conv1(x.view(b * num_o, 128, 4, 4)) | ||
x = F.interpolate(x, size=8, mode='bilinear') | ||
x = self.conv2(x) | ||
x = F.interpolate(x, size=16, mode='bilinear') | ||
x = self.conv3(x) | ||
x = x.view(b, num_o, 16, 16) | ||
|
||
bbmap = masks_to_layout(bbox, x, self.map_size).view(b, num_o, self.map_size, self.map_size) | ||
return bbmap | ||
|
||
|
||
class MaskRegressNetv2(nn.Module): | ||
def __init__(self, obj_feat=128, mask_size=16, map_size=64): | ||
super(MaskRegressNetv2, self).__init__() | ||
self.mask_size = mask_size | ||
self.map_size = map_size | ||
|
||
self.fc = nn.utils.spectral_norm(nn.Linear(obj_feat, 256 * 4 * 4)) | ||
conv1 = list() | ||
conv1.append(nn.utils.spectral_norm(nn.Conv2d(256, 256, 3, 1, 1))) | ||
conv1.append(nn.InstanceNorm2d(256)) | ||
conv1.append(nn.ReLU()) | ||
self.conv1 = nn.Sequential(*conv1) | ||
|
||
conv2 = list() | ||
conv2.append(nn.utils.spectral_norm(nn.Conv2d(256, 256, 3, 1, 1))) | ||
conv2.append(nn.InstanceNorm2d(256)) | ||
conv2.append(nn.ReLU()) | ||
self.conv2 = nn.Sequential(*conv2) | ||
|
||
conv3 = list() | ||
conv3.append(nn.utils.spectral_norm(nn.Conv2d(256, 256, 3, 1, 1))) | ||
conv3.append(nn.InstanceNorm2d(256)) | ||
conv3.append(nn.ReLU()) | ||
conv3.append(nn.utils.spectral_norm(nn.Conv2d(256, 1, 1, 1))) | ||
conv3.append(nn.Sigmoid()) | ||
self.conv3 = nn.Sequential(*conv3) | ||
|
||
def forward(self, obj_feat, bbox): | ||
""" | ||
:param obj_feat: (b*num_o, feat_dim) | ||
:param bbox: (b, num_o, 4) | ||
:return: bbmap: (b, num_o, map_size, map_size) | ||
""" | ||
b, num_o, _ = bbox.size() | ||
obj_feat = obj_feat.view(b * num_o, -1) | ||
x = self.fc(obj_feat) | ||
x = self.conv1(x.view(b * num_o, 256, 4, 4)) | ||
x = F.interpolate(x, size=8, mode='bilinear') | ||
x = self.conv2(x) | ||
x = F.interpolate(x, size=16, mode='bilinear') | ||
x = self.conv3(x) | ||
x = x.view(b, num_o, 16, 16) | ||
|
||
bbmap = masks_to_layout(bbox, x, self.map_size).view(b, num_o, self.map_size, self.map_size) | ||
return bbmap |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.