Commit 5f6e613 (initial commit, 0 parents), committed by Konstantin Sofiiuk on Jun 1, 2020. 87 changed files with 9,187 additions and 0 deletions.
Dockerfile
FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04

RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    git \
    curl \
    libglib2.0-0 \
    software-properties-common \
    python3.6-dev \
    python3-pip

WORKDIR /tmp

RUN pip3 install --upgrade pip
RUN pip3 install setuptools
RUN pip3 install matplotlib numpy pandas scipy tqdm pyyaml easydict scikit-image bridson Pillow ninja
RUN pip3 install imgaug mxboard graphviz
RUN pip3 install git+https://github.com/albu/albumentations --no-deps
RUN pip3 install opencv-python-headless
RUN pip3 install Cython
RUN pip3 install torch
RUN pip3 install torchvision
RUN pip3 install scikit-learn
RUN pip3 install tensorboard

RUN mkdir /work
WORKDIR /work
RUN chmod -R 777 /work && chmod -R 777 /root

ENV TINI_VERSION v0.18.0
ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini /usr/bin/tini
RUN chmod +x /usr/bin/tini
ENTRYPOINT [ "/usr/bin/tini", "--" ]
CMD [ "/bin/bash" ]
README.md
# Foreground-aware Semantic Representations for Image Harmonization

<p align="center">
  <img src="./images/ih_teaser.jpg" alt="drawing" width="800"/>
</p>

This repository contains the official PyTorch implementation of the following paper:
> **Foreground-aware Semantic Representations for Image Harmonization**<br>
> Konstantin Sofiiuk, Polina Popenova, Anton Konushin<br>
> Samsung AI Center Moscow<br>
>
> **Abstract:**
> Image harmonization is an important step in photo editing that achieves visual consistency in composite images by adjusting the appearance of the foreground to make it compatible with the background.
> Previous approaches to harmonizing composites are based on training encoder-decoder networks from scratch, which makes it challenging for a neural network to learn a high-level representation of objects.
> We propose a novel architecture that utilizes the space of high-level features learned by a pre-trained classification network.
> We create our models as a combination of existing encoder-decoder architectures and a pre-trained foreground-aware deep high-resolution network.
> We extensively evaluate the proposed method on the existing image harmonization benchmark and set a new state of the art in terms of MSE and PSNR metrics.
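
The sketch below illustrates the general pattern described in the abstract: features from a pre-trained, foreground-aware backbone are concatenated with the features of an encoder-decoder that is trained from scratch. This is a conceptual sketch only; the module interfaces (`backbone`, `encoder`, `decoder`) and the single-scale fusion are assumptions, not the actual model definition used in this repository.

```python
import torch
import torch.nn as nn


class FusedHarmonizationNet(nn.Module):
    """Conceptual sketch: fuse pre-trained backbone features with an encoder-decoder."""

    def __init__(self, backbone, encoder, decoder):
        super().__init__()
        self.backbone = backbone  # pre-trained, foreground-aware network (e.g. an HRNet fed with image + mask)
        self.encoder = encoder    # harmonization encoder, trained from scratch
        self.decoder = decoder    # decoder that predicts the harmonized image

    def forward(self, image, mask):
        # The backbone and the encoder both see the composite image together with
        # the foreground mask, so their features are "foreground-aware".
        backbone_feats = self.backbone(torch.cat([image, mask], dim=1))
        encoder_feats = self.encoder(torch.cat([image, mask], dim=1))
        # Concatenate high-level semantic features with the encoder features and decode.
        fused = torch.cat([encoder_feats, backbone_feats], dim=1)
        return self.decoder(fused)
```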
## Setting up an environment

This framework is built using Python 3.6 and relies on PyTorch 1.4.0+. The following command installs all necessary packages:

```.bash
pip3 install -r requirements.txt
```

You can also use our [Dockerfile](./Dockerfile) to build a container with a configured environment.

If you want to run training or testing, you must configure the paths to the datasets in [config.yml](./config.yml).

## Datasets
We train and evaluate all our models on the [iHarmony4 dataset](https://github.com/bcmi/Image_Harmonization_Datasets).
It contains 65,742 training and 7,404 test objects. Each object is a triple consisting of a real image, a composite image, and a foreground mask.

Before training, we resize the HAdobe5k subdataset so that each image side is smaller than 1024 pixels.
The resizing script is provided in [resize_hdataset.ipynb](./notebooks/resize_hdataset.ipynb).
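
The notebook is the reference implementation of the resizing; the snippet below is only a minimal sketch of the same idea, assuming Pillow and hypothetical input and output directories:

```python
from pathlib import Path
from PIL import Image

MAX_SIDE = 1024  # resize so that the longer side does not exceed this value

src_dir = Path('/hdd0/datasets/ImageHarmonization/HAdobe5k/real_images')  # assumed path
dst_dir = Path('/hdd0/datasets/ImageHarmonization/HAdobe5k_resized')      # assumed path
dst_dir.mkdir(parents=True, exist_ok=True)

for path in src_dir.glob('*.jpg'):
    image = Image.open(path)
    scale = MAX_SIDE / max(image.size)
    if scale < 1.0:
        new_size = (round(image.width * scale), round(image.height * scale))
        image = image.resize(new_size, Image.BICUBIC)
    image.save(dst_dir / path.name)
```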

Don't forget to change the paths to the datasets in [config.yml](./config.yml) after downloading and unpacking.

## Training

We provide scripts for training our models on images of size 256 and 512.
For each experiment, a separate folder is created in `./harmonization_exps` with TensorBoard logs, text logs, visualizations, and model checkpoints.
You can specify another path in [config.yml](./config.yml) (see the `EXPS_PATH` variable).

Start training with the following commands:
```.bash
python3 train.py <model-script-path> --gpus=0 --workers=4 --exp-name=first-try

# iDIH: fully convolutional encoder-decoder with output image blending and foreground-normalized MSE loss
python3 train.py models/fixed256/improved_dih.py --gpus=0 --workers=4 --exp-name=first-try

# HRNet18s-V2p + iDIH: a feature pyramid of 4 HRNet18-small-V2 outputs is concatenated to 4 outputs of the iDIH encoder
python3 train.py models/fixed256/hrnet18_idih.py --gpus=0 --workers=4 --exp-name=first-try

# HRNet18-V2 + iDIH: a single output of HRNet18-V2 is concatenated to a single output of the iDIH encoder
python3 train.py models/fixed256/hrnet18_idih.py --gpus=0 --workers=4 --exp-name=first-try

# iDIH trained on 512x512 images
python3 train.py models/crop512/improved_dih.py --gpus=0 --workers=4 --exp-name=first-try
```
To see all training parameters, run `python3 train.py --help`.
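
The foreground-normalized MSE loss mentioned in the iDIH command above normalizes the reconstruction error by the foreground area, so that errors on small composited regions are not averaged away over the whole image. A minimal sketch of this idea follows; the tensor shapes and the minimum-area clamp are assumptions, not the exact implementation from this codebase:

```python
import torch


def foreground_normalized_mse(pred, target, mask, min_area=100.0):
    """pred, target: (B, C, H, W); mask: (B, 1, H, W) with 1 inside the composited foreground."""
    squared_error = (pred - target) ** 2
    per_image_error = squared_error.sum(dim=(1, 2, 3))
    # Normalize by the foreground area instead of the full image area.
    foreground_area = mask.sum(dim=(1, 2, 3)).clamp(min=min_area)
    return (per_image_error / foreground_area).mean()
```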

We used pre-trained HRNetV2 models from the [official repository](https://github.com/HRNet/HRNet-Image-Classification).
To train one of our models with an HRNet backbone, download the HRNet weights and specify their path in [config.yml](./config.yml) (see the `IMAGENET_PRETRAINED_MODELS` variable).

## Evaluation
We provide scripts to both evaluate any model and get its predictions.
To do that, all model configs are specified in [mconfigs](./iharm/mconfigs).
To evaluate a model different from the provided ones, a new config entry should be added.

You can specify the checkpoints path in [config.yml](./config.yml) (see the `MODELS_PATH` variable) in advance
and provide the scripts with just a checkpoint name instead of an absolute checkpoint path.

### Evaluate a model
To get a metrics table on the iHarmony4 test set, run the following command:
```.bash
python3 scripts/evaluate_model.py <model-name> <checkpoint-path> --resize-strategy Fixed256

# iDIH
python3 scripts/evaluate_model.py improved_dih256 /hdd0/harmonization_exps/fixed256/improved_dih/checkpoints/last_checkpoint.pth --resize-strategy Fixed256
```
To see all evaluation parameters, run `python3 scripts/evaluate_model.py --help`.

### Get model predictions
To get predictions on a set of images, run the following command:
```.bash
python3 scripts/predict_for_dir.py <model-name> <checkpoint-path> --images <composite-images-path> --masks <masks-path> --resize 256

# iDIH
python3 scripts/predict_for_dir.py improved_dih256 /hdd0/harmonization_exps/fixed256/improved_dih/checkpoints/last_checkpoint.pth \
    --images /hdd0/datasets/ImageHarmonization/test/composite_images --masks /hdd0/datasets/ImageHarmonization/test/masks \
    --resize 256
```
To see all prediction parameters, run `python3 scripts/predict_for_dir.py --help`.

### Jupyter notebook
For interactive testing of the models with sample visualization, see [eval_and_vis_harmonization_model.ipynb](./notebooks/eval_and_vis_harmonization_model.ipynb).

## Results
We provide metrics and pre-trained weights for several models trained on 256x256 images augmented with horizontal flips and random resized crops.
Metric values may differ slightly from those in the paper since all models were retrained from scratch with the new codebase.
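
For reference, a hedged sketch of what such an augmentation pipeline could look like with albumentations; the exact transforms and parameter values below are assumptions, not the configuration used for the released models:

```python
import albumentations as A

# Hypothetical pipeline: horizontal flip + random resized crop to 256x256.
# additional_targets makes the same geometric transform apply to the real image
# and the object mask, matching how the dataset class feeds extra targets to its augmentator.
augmentator = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomResizedCrop(256, 256, scale=(0.5, 1.0)),
], additional_targets={'target_image': 'image', 'object_mask': 'mask'})
```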

Pre-trained models:
TODO

| Model | Link |
|-------|------|

Evaluation metrics (MSE: lower is better, PSNR: higher is better):

| Model | HCOCO MSE | HCOCO PSNR | HAdobe5k MSE | HAdobe5k PSNR | HFlickr MSE | HFlickr PSNR | Hday2night MSE | Hday2night PSNR | All MSE | All PSNR |
|---|---|---|---|---|---|---|---|---|---|---|
| **Base models** | | | | | | | | | | |
| iDIH256 | 19.58 | 38.34 | 30.84 | 36.00 | 84.74 | 32.58 | 50.05 | 37.10 | 30.70 | 36.99 |
| iSSAM256 | 16.48 | 39.16 | 22.60 | 37.24 | 69.67 | 33.56 | 40.59 | 37.72 | 24.65 | 37.95 |
| **iDIH256 with backbone** | | | | | | | | | | |
| DeepLab-ResNet34 | 17.68 | 38.97 | 28.13 | 36.33 | 70.89 | 33.25 | 56.17 | 37.25 | 27.37 | 37.53 |
| HRNet18s | 14.30 | 39.52 | 22.57 | 37.18 | 63.03 | 33.70 | 51.20 | 37.69 | 22.82 | 38.15 |
| HRNet18 | 13.79 | 39.62 | 25.44 | 36.91 | 60.63 | 33.88 | 44.94 | 37.74 | 22.99 | 38.16 |
| HRNet32 | 14.00 | 39.71 | 23.04 | 37.13 | 57.55 | 34.06 | 53.70 | 37.70 | 22.22 | 38.29 |

## License
The code is released under the MPL 2.0 License. MPL is a copyleft license that is easy to comply with. You must make the source code for any of your changes available under MPL, but you can combine the MPL software with proprietary code, as long as you keep the MPL code in separate files.

## Citation
If you find this work useful for your research, please cite our paper:

```
@article{sofiiuk2020harmonization,
  title={Foreground-aware Semantic Representations for Image Harmonization},
  author={Konstantin Sofiiuk and Polina Popenova and Anton Konushin},
  journal={arXiv preprint arXiv:20??.?????},
  year={2020}
}
```
config.yml
MODELS_PATH: "/hdd0/harmonization_exps/models"
EXPS_PATH: "/hdd0/harmonization_exps"

HFLICKR_PATH: "/hdd0/datasets/ImageHarmonization/HFlickr"
HDAY2NIGHT_PATH: "/hdd0/datasets/ImageHarmonization/Hday2night"
HCOCO_PATH: "/hdd0/datasets/ImageHarmonization/HCOCO"
HADOBE5K_PATH: "/hdd0/datasets/ImageHarmonization/HAdobe5k"

IMAGENET_PRETRAINED_MODELS:
  HRNETV2_W18_SMALL: "./pretrained_models/hrnet_w18_small_model_v2.pth"
  HRNETV2_W18: "./pretrained_models/hrnetv2_w18_imagenet_pretrained.pth"
  HRNETV2_W32: "./pretrained_models/hrnetv2_w32_imagenet_pretrained.pth"
  HRNETV2_W40: "./pretrained_models/hrnetv2_w40_imagenet_pretrained.pth"
  HRNETV2_W48: "./pretrained_models/hrnetv2_w48_imagenet_pretrained.pth"
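
These keys are consumed by the training and evaluation scripts; as a rough illustration, the paths can be read with PyYAML as shown below (the actual loader used by the codebase may differ):

```python
import yaml

with open('config.yml') as f:
    cfg = yaml.safe_load(f)

print(cfg['EXPS_PATH'])                                  # /hdd0/harmonization_exps
print(cfg['IMAGENET_PRETRAINED_MODELS']['HRNETV2_W18'])  # ./pretrained_models/hrnetv2_w18_imagenet_pretrained.pth
```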
base.py
import random
import numpy as np
import torch


class BaseHDataset(torch.utils.data.dataset.Dataset):
    def __init__(self,
                 augmentator=None,
                 input_transform=None,
                 keep_background_prob=0.0,
                 with_image_info=False,
                 epoch_len=-1):
        super(BaseHDataset, self).__init__()
        self.epoch_len = epoch_len
        self.augmentator = augmentator
        self.keep_background_prob = keep_background_prob
        self.with_image_info = with_image_info

        if input_transform is None:
            input_transform = lambda x: x

        self.input_transform = input_transform
        self.dataset_samples = None

    def __getitem__(self, index):
        if self.epoch_len > 0:
            index = random.randrange(0, len(self.dataset_samples))

        sample = self.get_sample(index)
        self.check_sample_types(sample)
        sample = self.augment_sample(sample)

        image = self.input_transform(sample['image'])
        target_image = self.input_transform(sample['target_image'])
        obj_mask = sample['object_mask'].astype(np.float32)

        output = {
            'images': image,
            'masks': obj_mask[np.newaxis, ...],
            'target_images': target_image
        }

        if self.with_image_info and 'image_id' in sample:
            output['image_info'] = sample['image_id']
        return output

    def check_sample_types(self, sample):
        assert sample['image'].dtype == 'uint8'
        if 'target_image' in sample:
            assert sample['target_image'].dtype == 'uint8'

    def augment_sample(self, sample):
        if self.augmentator is None:
            return sample

        additional_targets = {target_name: sample[target_name]
                              for target_name in self.augmentator.additional_targets.keys()}

        valid_augmentation = False
        while not valid_augmentation:
            aug_output = self.augmentator(image=sample['image'], **additional_targets)
            valid_augmentation = self.check_augmented_sample(sample, aug_output)

        for target_name, transformed_target in aug_output.items():
            sample[target_name] = transformed_target

        return sample

    def check_augmented_sample(self, sample, aug_output):
        if self.keep_background_prob < 0.0 or random.random() < self.keep_background_prob:
            return True

        return aug_output['object_mask'].sum() > 1.0

    def get_sample(self, index):
        raise NotImplementedError

    def __len__(self):
        if self.epoch_len > 0:
            return self.epoch_len
        else:
            return len(self.dataset_samples)
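
`get_sample` is the only method a concrete dataset must provide, and it should return `uint8` images plus an `object_mask`. A hypothetical subclass might look like the sketch below; the file layout, OpenCV-based loading, and the mask threshold are assumptions, and the real iHarmony4 dataset classes in this repository may differ:

```python
import cv2
import numpy as np


class FolderHDataset(BaseHDataset):
    """Hypothetical example subclass: loads composite/real/mask triples from parallel path lists."""

    def __init__(self, composite_paths, real_paths, mask_paths, **kwargs):
        super(FolderHDataset, self).__init__(**kwargs)
        self.dataset_samples = list(zip(composite_paths, real_paths, mask_paths))

    def get_sample(self, index):
        composite_path, real_path, mask_path = self.dataset_samples[index]
        image = cv2.cvtColor(cv2.imread(composite_path), cv2.COLOR_BGR2RGB)    # uint8 composite image
        target_image = cv2.cvtColor(cv2.imread(real_path), cv2.COLOR_BGR2RGB)  # uint8 ground-truth image
        object_mask = (cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 127).astype(np.float32)
        return {
            'image': image,
            'target_image': target_image,
            'object_mask': object_mask,
        }
```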
from .base import BaseHDataset


class ComposeDataset(BaseHDataset):
    def __init__(self, datasets, **kwargs):
        super(ComposeDataset, self).__init__(**kwargs)

        self._datasets = datasets
        self.dataset_samples = []
        for dataset_indx, dataset in enumerate(self._datasets):
            self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))])

    def get_sample(self, index):
        dataset_indx, sample_indx = self.dataset_samples[index]
        return self._datasets[dataset_indx].get_sample(sample_indx)
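
A small usage sketch, assuming the per-subdataset objects (`hcoco_dataset`, `hflickr_dataset`) have already been constructed as BaseHDataset subclasses with tensor-producing input transforms:

```python
from torch.utils.data import DataLoader

# Hypothetical: merge several subdatasets into one training set.
combined = ComposeDataset([hcoco_dataset, hflickr_dataset], keep_background_prob=0.05)
loader = DataLoader(combined, batch_size=16, shuffle=True, num_workers=4)

for batch in loader:
    images, masks, targets = batch['images'], batch['masks'], batch['target_images']
    # ... forward pass and loss computation would go here ...
    break
```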