PyTorch implementation of our WACV 2023 paper "Adaptively-Realistic Image Generation from Stroke and Sketch with Diffusion Model". You can visit our project website here.
In this paper, we propose a unified framework, based on diffusion models, that supports three-dimensional control over image synthesis from sketches and strokes.
Adaptively-Realistic Image Generation from Stroke and Sketch with Diffusion Model
Shin-I Cheng*, Yu-Jie Chen*, Wei-Chen Chiu, Hung-Yu Tseng, Hsin-Ying Lee
Winter Conference on Applications of Computer Vision (WACV), 2023 (* equal contribution)
Please cite our paper if you find it useful for your research.
@inproceedings{cheng2023wacv,
title = {Adaptively-Realistic Image Generation from Stroke and Sketch with Diffusion Model},
author = {Shin-I Cheng and Yu-Jie Chen and Wei-Chen Chiu and Hung-Yu Tseng and Hsin-Ying Lee},
booktitle = {IEEE Winter Conference on Applications of Computer Vision (WACV)},
year = {2023}
}
- Prerequisites: Python 3.6 and PyTorch (1.4.0 or later)
- Clone this repo
git clone https://github.com/cyj407/DiSS.git
cd DiSS
- We provide a conda environment file. After cloning the repo, run the following command to create the environment.
conda env create -f diss_env.yaml
- AFHQ (cat, dog, wildlife): follow the instructions on the StarGAN v2 website to download the AFHQ (cat, dog, wildlife) dataset.
- Flowers: follow the instructions on the website to download the flower dataset.
- Landscapes (LHQ): follow the instructions on the ALIS website to download the LHQ1024 dataset.
- To draw sketch/stroke images on the interface or upload a reference input image, run the following command and open http://127.0.0.1:5000. The first load takes longer than later ones.
python app.py
- Computing the corresponding output image takes about 55 seconds.
- We use Photo-Sketching: Inferring Contour Drawings from Images for sketch generation, and Stylized Neural Painting and Paint Transformer: Feed Forward Neural Painting with Stroke Prediction for stroke generation. Please refer to our paper for details.
- Please organize the whole dataset as follows:
Root Dir/
  -<dataset_path>/
    -Image1
    -Image2
    -...
  -"<dataset_path>_sketch"/
    -Image1
    -Image2
    -...
  -"<dataset_path>_stroke"/
    -Image1
    -Image2
    -...
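Before training, it can help to confirm that every image has a sketch and a stroke counterpart. The following is a minimal sketch, not part of this repo: the helper name `find_unpaired_images` is hypothetical, and it assumes the sketch and stroke files reuse the file names of the images, as the layout above suggests.

```python
from pathlib import Path


def find_unpaired_images(dataset_path):
    """Map each image file name to the sibling folders that lack it.

    Assumption: "<dataset_path>_sketch" and "<dataset_path>_stroke" sit next
    to <dataset_path> and reuse the image file names.
    """
    image_dir = Path(dataset_path)
    sketch_dir = image_dir.with_name(image_dir.name + "_sketch")
    stroke_dir = image_dir.with_name(image_dir.name + "_stroke")

    missing = {}
    for image in sorted(p for p in image_dir.iterdir() if p.is_file()):
        # Collect the condition folders that have no file with this name.
        absent = [d.name for d in (sketch_dir, stroke_dir)
                  if not (d / image.name).is_file()]
        if absent:
            missing[image.name] = absent
    return missing
```

An empty dictionary means every image has both conditions; otherwise each key is an image file name and each value lists the folders where its counterpart is missing.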
- We use the following hyperparameters for all three datasets in our paper. You can also choose your own.
MODEL_FLAGS="--attention_resolutions 32,16,8 --image_size 512 --num_channels 128 --num_res_blocks 3 --use_scale_shift_norm True"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 2"
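To illustrate what `--diffusion_steps 1000 --noise_schedule linear` describes, here is a minimal sketch of a linear beta schedule in plain Python. It is not the code used by this repo: the function names are hypothetical, and the endpoints 1e-4 and 0.02 (rescaled by 1000 / num_steps) are assumed from the commonly used DDPM-style linear schedule.

```python
def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Return num_steps betas spaced linearly between two endpoints.

    Assumption: the endpoints are defined for 1000 steps and rescaled by
    1000 / num_steps so that other step counts add a similar amount of noise.
    """
    scale = 1000 / num_steps
    start, end = scale * beta_start, scale * beta_end
    if num_steps == 1:
        return [start]
    step = (end - start) / (num_steps - 1)
    return [start + i * step for i in range(num_steps)]


def alphas_cumprod(betas):
    """Cumulative product of (1 - beta): the fraction of signal left at each step."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


betas = linear_beta_schedule(1000)
signal_left = alphas_cumprod(betas)
```

The cumulative product shrinks toward zero over the schedule, which is why the final timestep is close to pure noise.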
- Train your own sketch- and stroke-conditioned diffusion models with the desired data stored in <dataset_path>.
python train.py --data_dir <dataset_path> --path <path_for_saving_the_models> $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
- Finetune the trained models (stored at <pretrained_model_path>) for sketch and stroke classifier-free guidance (randomly replacing 30% of each condition with an image filled with gray pixels).
python finetune.py --data_dir <dataset_path> --path <path_for_saving_the_models> --trained_model <pretrained_model_path> $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
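The condition replacement used for classifier-free guidance can be sketched as follows. This is an illustration only, not the logic in `finetune.py`: the function names are hypothetical, images are represented as nested lists of floats, the gray value 0.5 is an assumption, and the two conditions are assumed to be dropped independently.

```python
import random


def gray_like(image, gray_value=0.5):
    """Return an image of the same shape filled with one gray value (assumed 0.5)."""
    return [[gray_value for _ in row] for row in image]


def drop_conditions(sketch, stroke, drop_prob=0.3, rng=random):
    """Independently replace each condition with a gray image with probability drop_prob.

    Assumption: the sketch and the stroke are dropped independently of each other.
    """
    if rng.random() < drop_prob:
        sketch = gray_like(sketch)
    if rng.random() < drop_prob:
        stroke = gray_like(stroke)
    return sketch, stroke
```

Seeing gray placeholders during finetuning lets the model learn to predict without a given condition, which is what makes per-condition guidance possible at sampling time.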
- Download the pre-trained diffusion models. We provide pre-trained models for the AFHQ-cat, Flowers, LHQ, and AFHQ (cats, dogs, and wildlife) datasets.
- Unzip the downloaded file and save all the models under ./checkpoints/.
- Flowers
MODEL_FLAGS="--attention_resolutions 32,16,8 --image_size 512 --num_channels 128 --num_res_blocks 3 --use_scale_shift_norm True"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear"
python sample.py $MODEL_FLAGS $DIFFUSION_FLAGS --timestep_respacing 250 --model_path "./checkpoints/flower512.pt" --lhq False --input_image "./test-examples/realism-example.png"
- Landscapes
MODEL_FLAGS="--attention_resolutions 32,16,8 --image_size 512 --num_channels 128 --num_res_blocks 3 --use_scale_shift_norm True"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear"
python sample.py $MODEL_FLAGS $DIFFUSION_FLAGS --timestep_respacing 250 --model_path "./checkpoints/lhq512.pt" --lhq True --input_image "./test-examples/realism-example2.png"
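The sampling commands above combine `--diffusion_steps 1000` with `--timestep_respacing 250`, meaning only 250 of the 1000 training timesteps are visited at sampling time. The snippet below is a simplified sketch of evenly spaced respacing; the function name is hypothetical, and the respacing code this repo inherits from guided-diffusion may differ in details.

```python
def respace_timesteps(num_steps, num_kept):
    """Pick num_kept roughly evenly spaced timesteps from range(num_steps).

    The first and last original timesteps are always kept when num_kept > 1.
    """
    if not 1 <= num_kept <= num_steps:
        raise ValueError("num_kept must be between 1 and num_steps")
    if num_kept == 1:
        return [0]
    stride = (num_steps - 1) / (num_kept - 1)
    return [round(i * stride) for i in range(num_kept)]


kept = respace_timesteps(1000, 250)
```

Sampling then iterates over the kept timesteps in reverse order, trading a small amount of fidelity for roughly a fourfold reduction in model evaluations.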
MODEL_FLAGS="--attention_resolutions 32,16,8 --image_size 512 --num_channels 128 --num_res_blocks 3 --use_scale_shift_norm True"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear"
CUDA_VISIBLE_DEVICES=1, python local-editing.py $MODEL_FLAGS $DIFFUSION_FLAGS --timestep_respacing 250
MODEL_FLAGS="--attention_resolutions 32,16,8 --image_size 512 --num_channels 128 --num_res_blocks 3 --use_scale_shift_norm True"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear"
CUDA_VISIBLE_DEVICES=0, python partial-editing.py $MODEL_FLAGS $DIFFUSION_FLAGS --timestep_respacing 250
MODEL_FLAGS="--attention_resolutions 32,16,8 --image_size 512 --num_channels 128 --num_res_blocks 3 --use_scale_shift_norm True"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear"
python multimodal.py $MODEL_FLAGS $DIFFUSION_FLAGS --timestep_respacing 250
Our code is based on guided-diffusion and pytorch-resizer. The inference code of the user interface is borrowed from blended-diffusion.