- 2024.11.14 The RDIStitcher repository is created.
- 2024.11.18 The inference code and pre-trained models are released.
- 2024.11.28 The training code is released.
- 2024.11.28 The MLLM-based metrics code is released.
- Python >= 3.9
- GPU (NVIDIA CUDA compatible) with >= 24 GB VRAM. If you only have GPUs with 16 GB VRAM, we provide a low-memory plan below.
- Create a virtual environment (optional but recommended):
  conda create -n rdistitcher python==3.10
  conda activate rdistitcher
- Install the required dependencies:
pip install -r requirements.txt
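
After installation, you can quickly confirm that your GPUs meet the VRAM requirement above. The snippet below is only a small sketch, assuming PyTorch is available in the environment (e.g., installed via requirements.txt):

```python
# Quick sanity check for the GPU/VRAM requirement above (assumes PyTorch is installed).
import torch

if not torch.cuda.is_available():
    print("No CUDA-capable GPU detected.")
else:
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        vram_gb = props.total_memory / 1024**3
        plan = "full" if vram_gb >= 24 else "low-memory" if vram_gb >= 16 else "insufficient"
        print(f"GPU {i}: {props.name}, {vram_gb:.1f} GB VRAM -> {plan} training plan")
```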
The UDIS-D dataset, aligned images, and masks can be obtained from UDIS.
The datasets should be organized as follows:
Click for details...
train
├── input1
│ ├── 000001.jpg
│ ├── ...
├── input2
│ ├── 000001.jpg
│ ├── ...
├── mask1
│ ├── 000001.jpg
│ ├── ...
├── mask2
│ ├── 000001.jpg
│ ├── ...
test
├── warp1
│ ├── 000001.jpg
│ ├── ...
├── warp2
│ ├── 000001.jpg
│ ├── ...
├── mask1
│ ├── 000001.jpg
│ ├── ...
├── mask2
│ ├── 000001.jpg
│ ├── ...
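
To verify that your data is organized as above, you can run a quick check like the sketch below. The folder names match the layout shown; the script itself is only illustrative and not part of the released code:

```python
# Illustrative check that every image has a counterpart in each folder of a split.
from pathlib import Path

def check_split(root: str, folders: list[str]) -> None:
    names = [{p.name for p in (Path(root) / f).glob("*.jpg")} for f in folders]
    common = set.intersection(*names)
    for f, s in zip(folders, names):
        missing = s - common
        if missing:
            print(f"{root}/{f}: {len(missing)} files without a match in the other folders")
    print(f"{root}: {len(common)} complete samples")

check_split("train", ["input1", "input2", "mask1", "mask2"])
check_split("test", ["warp1", "warp2", "mask1", "mask2"])
```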
bash train.sh
The following train.sh training configuration requires ~24 GB of VRAM with 2 GPUs.
Click for details...
export MODEL_NAME="stabilityai/stable-diffusion-2-inpainting"
export TRAIN_DIR="path-to-train-set"
export TEST_DIR="path-to-test-set"
export OUTPUT_DIR="RDIStitcherModel"
export CUDA_VISIBLE_DEVICES=0,1
accelerate launch train.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--val_data_dir=$TEST_DIR \
--output_dir=$OUTPUT_DIR \
--resolution=512 \
--train_batch_size=4 \
--gradient_accumulation_steps=1 \
--unet_learning_rate=2e-4 \
--text_encoder_learning_rate=4e-5 \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=10001 \
--lora_rank=8 \
--lora_dropout=0.1 \
--lora_alpha=16 \
--seed 0 \
--mixed_precision "no" \
--identifier "<A>" \
--tempmodel_steps 10000 \
--validation_steps 500
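
The LoRA flags above (--lora_rank, --lora_dropout, --lora_alpha) describe a standard LoRA adapter setup. As a rough illustration, an equivalent configuration with the peft library might look like the sketch below; whether train.py uses peft internally is not guaranteed, and the target_modules listed are a common choice for Stable Diffusion attention layers rather than values taken from the script:

```python
# Hedged sketch: a LoRA configuration matching the flags above, expressed with the peft library.
# target_modules is an assumption (typical SD attention projections); train.py may differ.
from peft import LoraConfig

unet_lora_config = LoraConfig(
    r=8,               # --lora_rank
    lora_alpha=16,     # --lora_alpha
    lora_dropout=0.1,  # --lora_dropout
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
```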
The following low-memory train.sh configuration requires ~16 GB of VRAM with 2 GPUs.
Click for details...
export MODEL_NAME="stabilityai/stable-diffusion-2-inpainting"
export TRAIN_DIR="path-to-train-set"
export TEST_DIR="path-to-test-set"
export OUTPUT_DIR="RDIStitcherModel"
export CUDA_VISIBLE_DEVICES=0,1
accelerate launch train.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--val_data_dir=$TEST_DIR \
--output_dir=$OUTPUT_DIR \
--resolution=512 \
--train_batch_size=4 \
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--use_8bit_adam \
--enable_xformers_memory_efficient_attention \
--set_grads_to_none \
--unet_learning_rate=2e-4 \
--text_encoder_learning_rate=4e-5 \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=10001 \
--lora_rank=8 \
--lora_dropout=0.1 \
--lora_alpha=16 \
--seed 0 \
--mixed_precision "no" \
--identifier "<A>" \
--tempmodel_steps 10000 \
--validation_steps 500
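
The extra flags in this variant (--gradient_checkpointing, --use_8bit_adam, --enable_xformers_memory_efficient_attention, --set_grads_to_none) trade compute for memory. As a rough sketch, and assuming train.py follows the standard diffusers LoRA training recipe, they correspond to calls like the following:

```python
# Hedged sketch of what the low-memory flags roughly map to (assumes the standard diffusers recipe).
import bitsandbytes as bnb
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting", subfolder="unet"
)
unet.enable_gradient_checkpointing()               # --gradient_checkpointing
unet.enable_xformers_memory_efficient_attention()  # --enable_xformers_memory_efficient_attention
optimizer = bnb.optim.AdamW8bit(unet.parameters(), lr=2e-4)  # --use_8bit_adam
# --set_grads_to_none corresponds to calling optimizer.zero_grad(set_to_none=True).
```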
This inference process requires ~6 GB VRAM.
Click for details...
python inference.py \
--device "cuda" \
--pretrained_model_name_or_path "stabilityai/stable-diffusion-2-inpainting" \
--loramodel_path "loraweight" \
--data_root "path-to-test-set" \
--test_prompt "<A>" \
--num_seed 5
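
If you want to experiment with the trained LoRA weights outside of inference.py, a minimal diffusers sketch looks roughly like the following. It assumes the weights in loraweight are stored in a diffusers-compatible LoRA format, and it does not reproduce whatever input composition inference.py performs; it simply feeds one warped image and its mask as a plain inpainting call:

```python
# Hedged sketch: loading the LoRA weights into the base SD2 inpainting pipeline with diffusers.
# Assumes "loraweight" holds diffusers-compatible LoRA weights; see inference.py for the
# actual input preparation used by RDIStitcher.
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("loraweight")

image = Image.open("path-to-test-set/warp1/000001.jpg").convert("RGB").resize((512, 512))
mask = Image.open("path-to-test-set/mask1/000001.jpg").convert("L").resize((512, 512))

result = pipe(prompt="<A>", image=image, mask_image=mask).images[0]  # "<A>" is the training identifier
result.save("sample_output.jpg")
```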
Our pre-trained LoRA weights are very small (only ~12 MB), so you can use them directly from the loraweight
directory. Due to hardware limitations, we cannot provide the best LoRA configuration, but we believe the provided pre-trained LoRA weights are sufficient to demonstrate the effectiveness of our work.
Click for details...
To use Qwen:
pip install openai
To use GLM:
pip install zhipuai
python mllmmetrics.py \
--metric_type "qwen-siqs" or "glm-siqs" \
--image_path "path-to-stitched-images" \
--api_key "your-api-key" \
--base_url "your-base-url"
python mllmmetrics.py \
--metric_type "qwen-micqs" or "glm-micqs" \
--image_path "path-to-stitched-images" \
--image_path2 "path2-to-stitched-images" \
--api_key "your-api-key" \
--base_url "your-base-url"
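
Both Qwen and GLM are reached through OpenAI-compatible chat APIs, which is why only an API key and base URL are needed. As a rough illustration of the kind of request mllmmetrics.py sends for a single stitched image (the model name and scoring prompt below are placeholders, not the exact ones used in the script):

```python
# Hedged sketch of an OpenAI-compatible multimodal request, similar in spirit to mllmmetrics.py.
# The model name and prompt are placeholders; see mllmmetrics.py for the real ones.
import base64
from openai import OpenAI

client = OpenAI(api_key="your-api-key", base_url="your-base-url")

with open("path-to-stitched-images/000001.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

response = client.chat.completions.create(
    model="qwen-vl-max",  # placeholder model name
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Rate the quality of this stitched image from 1 to 10."},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
        ],
    }],
)
print(response.choices[0].message.content)
```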
@misc{xie2024modificationtakescourageseamless,
title={Modification Takes Courage: Seamless Image Stitching via Reference-Driven Inpainting},
author={Ziqi Xie and Xiao Lai and Weidong Zhao and Xianhui Liu and Wenlong Hou},
year={2024},
eprint={2411.10309},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2411.10309},
}