[Feature] Support Kandinskyv3 #117

Merged 2 commits on Dec 26, 2023
1 change: 1 addition & 0 deletions README.md
@@ -276,6 +276,7 @@ For detailed user guides and advanced guides, please refer to our [Documentation
<td>
<ul>
<li><a href="configs/kandinsky_v22/README.md">Kandinsky 2.2 (2023)</a></li>
<li><a href="configs/kandinsky_v3/README.md">Kandinsky 3 (2023)</a></li>
</ul>
</td>
</tr>
28 changes: 28 additions & 0 deletions configs/_base_/datasets/pokemon_blip_kandinsky_v3.py
@@ -0,0 +1,28 @@
train_pipeline = [
    dict(type="torchvision/Resize", size=1024, interpolation="bicubic"),
    dict(type="RandomCrop", size=1024),
    dict(type="RandomHorizontalFlip", p=0.5),
    dict(type="torchvision/ToTensor"),
    dict(type="torchvision/Normalize", mean=[0.5], std=[0.5]),
    dict(type="PackInputs"),
]
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    dataset=dict(
        type="HFDataset",
        dataset="lambdalabs/pokemon-blip-captions",
        pipeline=train_pipeline),
    sampler=dict(type="DefaultSampler", shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
    dict(type="VisualizationHook", prompt=["yoda pokemon"] * 4,
         height=1024, width=1024),
    dict(type="SDCheckpointHook"),
]
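
For intuition, the preprocessing above amounts to a resize and crop to 1024x1024 followed by normalization to roughly [-1, 1]. Below is a rough torchvision-only sketch of the equivalent image transforms; it is illustrative only, since diffengine builds the real pipeline from the dict configs, and entries without the `torchvision/` prefix such as `RandomCrop` and `PackInputs` are diffengine's own transforms.

```py
# Illustrative sketch of the image preprocessing implied by train_pipeline above.
# Not how diffengine constructs the pipeline; it only mirrors the torchvision side.
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(1024, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop(1024),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),                        # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5]),  # rescale to roughly [-1, 1]
])
```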
3 changes: 3 additions & 0 deletions configs/_base_/models/kandinsky_v3.py
@@ -0,0 +1,3 @@
model = dict(
    type="KandinskyV3",
    model="kandinsky-community/kandinsky-3")
83 changes: 83 additions & 0 deletions configs/kandinsky_v3/README.md
@@ -0,0 +1,83 @@
# Kandinsky 3

[Kandinsky 3](https://ai-forever.github.io/Kandinsky-3/)

## Abstract

We present Kandinsky 3.0, a large-scale text-to-image generation model based on latent diffusion, continuing the series of text-to-image Kandinsky models and reflecting our progress toward higher quality and realism of image generation. Compared to previous versions of Kandinsky 2.x, Kandinsky 3.0 leverages a two times larger UNet backbone, a ten times larger text encoder, and removes the diffusion mapping. We describe the architecture of the model, the data collection procedure, the training technique, and the production system for user interaction. We focus on the key components that, as identified through a large number of experiments, had the most significant impact on improving the quality of our model compared with other models. According to our side-by-side comparisons, Kandinsky has become better at text understanding and works better on specific domains.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/2d670f44-9fa1-4095-be96-a82c91c9590b"/>
</div>

## Citation

```
@misc{arkhipkin2023kandinsky,
title={Kandinsky 3.0 Technical Report},
author={Vladimir Arkhipkin and Andrei Filatov and Viacheslav Vasilev and Anastasia Maltseva and Said Azizov and Igor Pavlov and Julia Agafonova and Andrey Kuznetsov and Denis Dimitrov},
year={2023},
eprint={2312.03511},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```

## Run Training

```bash
# single gpu
$ mim train diffengine ${CONFIG_FILE}
# multi gpus
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example.
$ mim train diffengine configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py
```

## Inference with diffusers

Once you have trained a model, specify the path to the saved model and utilize it for inference using the `diffusers.pipeline` module.

Before running inference, we need to convert the trained weights to the diffusers format:

```bash
$ mim run diffengine publish_model2diffusers ${CONFIG_FILE} ${INPUT_FILENAME} ${OUTPUT_DIR} --save-keys ${SAVE_KEYS}
# Example
# Note: when training with ColossalAI, add `--colossalai` and set `INPUT_FILENAME` to the index file.
$ mim run diffengine publish_model2diffusers configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py work_dirs/kandinsky_v3_pokemon_blip/epoch_50.pth/model/pytorch_model.bin.index.json work_dirs/kandinsky_v3_pokemon_blip --save-keys unet --colossalai
```

Then we can run inference.

```py
from diffusers import AutoPipelineForText2Image, Kandinsky3UNet

prompt = 'yoda pokemon'
checkpoint = 'work_dirs/kandinsky_v3_pokemon_blip'

unet = Kandinsky3UNet.from_pretrained(
    checkpoint, subfolder='unet')
pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-3",
    unet=unet,
    variant="fp16",
)
pipe.to('cuda')

image = pipe(
    prompt,
    num_inference_steps=50,
    width=1024,
    height=1024,
).images[0]
image.save('demo.png')
```
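
If GPU memory is a concern, the same pipeline can also be loaded in half precision. The following is a minimal variation of the snippet above, assuming a CUDA GPU with fp16 support; `torch_dtype` and `subfolder` are standard `from_pretrained` arguments in diffusers.

```py
import torch
from diffusers import AutoPipelineForText2Image, Kandinsky3UNet

checkpoint = 'work_dirs/kandinsky_v3_pokemon_blip'

# Load both the fine-tuned UNet and the rest of the pipeline in fp16.
unet = Kandinsky3UNet.from_pretrained(
    checkpoint, subfolder='unet', torch_dtype=torch.float16)
pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-3",
    unet=unet,
    variant="fp16",
    torch_dtype=torch.float16,
)
pipe.to('cuda')
```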

## Results Example

#### kandinsky_v3_pokemon_blip

![example1](https://github.com/okotaku/diffengine/assets/24734142/8f078fa8-9485-40d9-8174-5996257aed88)
24 changes: 24 additions & 0 deletions configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py
@@ -0,0 +1,24 @@
_base_ = [
    "../_base_/models/kandinsky_v3.py",
    "../_base_/datasets/pokemon_blip_kandinsky_v3.py",
    "../_base_/schedules/stable_diffusion_xl_50e.py",
    "../_base_/default_runtime.py",
]

optim_wrapper = dict(
    _delete_=True,
    optimizer=dict(
        type="HybridAdam",
        lr=1e-5,
        weight_decay=1e-2),
    accumulative_counts=4)

default_hooks = dict(
    checkpoint=dict(save_param_scheduler=False))  # no scheduler in this config

runner_type = "FlexibleRunner"
strategy = dict(type="ColossalAIStrategy",
                plugin=dict(type="LowLevelZeroPlugin",
                            stage=2,
                            precision="bf16",
                            max_norm=1.0))
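
To sanity-check how the `_base_` files are merged before launching a run, the composed config can be inspected with mmengine. This is a sketch under the assumption that mmengine is installed (diffengine builds on it) and that it is run from the repository root.

```py
# Print the merged config produced from the _base_ files above.
from mmengine.config import Config

cfg = Config.fromfile("configs/kandinsky_v3/kandinsky_v3_pokemon_blip.py")
print(cfg.model)          # dict(type="KandinskyV3", model="kandinsky-community/kandinsky-3")
print(cfg.optim_wrapper)  # HybridAdam with accumulative_counts=4
print(cfg.strategy)       # ColossalAIStrategy with LowLevelZeroPlugin, stage 2, bf16
```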
3 changes: 2 additions & 1 deletion diffengine/models/editors/kandinsky/__init__.py
@@ -1,6 +1,7 @@
from .kandinskyv3 import KandinskyV3
from .kandinskyv22_decoder import KandinskyV22Decoder
from .kandinskyv22_decoder_preprocessor import KandinskyV22DecoderDataPreprocessor
from .kandinskyv22_prior import KandinskyV22Prior

__all__ = ["KandinskyV22Prior", "KandinskyV22Decoder",
"KandinskyV22DecoderDataPreprocessor"]
"KandinskyV22DecoderDataPreprocessor", "KandinskyV3"]