diff --git a/configs/textual_inversion/README.md b/configs/textual_inversion/README.md
new file mode 100644
index 0000000000..9fc993f286
--- /dev/null
+++ b/configs/textual_inversion/README.md
@@ -0,0 +1,109 @@
+# Textual Inversion (2022)
+
+> [An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion](https://arxiv.org/abs/2208.01618)
+
+> **Task**: Text2Image
+
+## Abstract
+
+Text-to-image models offer unprecedented freedom to guide creation through natural language. Yet, it is unclear how such freedom can be exercised to generate images of specific unique concepts, modify their appearance, or compose them in new roles and novel scenes. In other words, we ask: how can we use language-guided models to turn our cat into a painting, or imagine a new product based on our favorite toy? Here we present a simple approach that allows such creative freedom. Using only 3-5 images of a user-provided concept, like an object or a style, we learn to represent it through new "words" in the embedding space of a frozen text-to-image model. These "words" can be composed into natural language sentences, guiding personalized creation in an intuitive way. Notably, we find evidence that a single word embedding is sufficient for capturing unique and varied concepts. We compare our approach to a wide range of baselines, and demonstrate that it can more faithfully portray the concepts across a range of applications and tasks.
+
+## Configs
+
+|                    Model                    | Dataset | Download |
+| :-----------------------------------------: | :-----: | :------: |
+| [Textual Inversion](./textual_inversion.py) |    -    |    -     |
+
+## Quick Start
+
+1. Download [data](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save it to `data/cat_toy`.
+
+The file structure will look like this:
+
+```text
+data
+└── cat_toy
+    ├── 1.jpeg
+    ├── 2.jpeg
+    ├── 3.jpeg
+    ├── 4.jpeg
+    ├── 6.jpeg
+    └── 7.jpeg
+```
+
+2. Start training with the following command:
+
+```bash
+bash tools/dist_train.sh configs/textual_inversion/textual_inversion.py 1
+```
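Checkpoints are written to `work_dirs/textual_inversion/` by default. To confirm that the learned token embedding actually made it into a checkpoint before moving on to inference, a quick check along these lines should work (a minimal sketch — we assume the embedding parameters contain `trainable_embeddings` in their names, based on the `optim_wrapper` setting in the config):

```python
import torch

ckpt = torch.load('work_dirs/textual_inversion/iter_3000.pth', map_location='cpu')
# Parameters that belong to the newly learned placeholder embedding.
emb_keys = [k for k in ckpt['state_dict'] if 'trainable_embeddings' in k]
print(emb_keys)
```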
+3. Inference with the trained textual embedding:
+
+```python
+import torch
+from mmengine import Config
+
+from mmagic.registry import MODELS
+from mmagic.utils import register_all_modules
+
+register_all_modules()
+
+
+def process_state_dict(state_dict):
+    """Remove the 'module.' prefix that a distributed wrapper may add."""
+    new_state_dict = dict()
+    for k, v in state_dict.items():
+        new_k = k.replace('module.', '')
+        new_state_dict[new_k] = v
+
+    return new_state_dict
+
+
+cfg = Config.fromfile('configs/textual_inversion/textual_inversion.py')
+checkpoint = torch.load('work_dirs/textual_inversion/iter_3000.pth')
+state_dict = process_state_dict(checkpoint['state_dict'])
+model = MODELS.build(cfg.model)
+model.load_state_dict(state_dict)
+
+model = model.cuda()
+with torch.no_grad():
+    sample = model.infer('a <cat-toy> bag')['samples'][0]
+
+sample.save('cat-toy-bag.png')
+```
+
+## Comments
+
+Our codebase for the stable diffusion models builds heavily on the [diffusers codebase](https://github.com/huggingface/diffusers), and the model weights are from [stable-diffusion-1.5](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py).
+
+Thanks for the efforts of the community!
+
+## Citation
+
+```bibtex
+@misc{gal2022textual,
+      doi = {10.48550/ARXIV.2208.01618},
+      url = {https://arxiv.org/abs/2208.01618},
+      author = {Gal, Rinon and Alaluf, Yuval and Atzmon, Yuval and Patashnik, Or and Bermano, Amit H. and Chechik, Gal and Cohen-Or, Daniel},
+      title = {An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion},
+      publisher = {arXiv},
+      year = {2022},
+      primaryClass={cs.CV}
+}
+```
diff --git a/configs/textual_inversion/metafile.yml b/configs/textual_inversion/metafile.yml
new file mode 100644
index 0000000000..d56190d308
--- /dev/null
+++ b/configs/textual_inversion/metafile.yml
@@ -0,0 +1,18 @@
+Collections:
+- Name: Textual Inversion
+  Paper:
+    Title: 'An Image is Worth One Word: Personalizing Text-to-Image Generation using
+      Textual Inversion'
+    URL: https://arxiv.org/abs/2208.01618
+  README: configs/textual_inversion/README.md
+  Task:
+  - text2image
+  Year: 2022
+Models:
+- Config: configs/textual_inversion/textual_inversion.py
+  In Collection: Textual Inversion
+  Name: textual_inversion
+  Results:
+  - Dataset: '-'
+    Metrics: {}
+    Task: Text2Image
diff --git a/configs/textual_inversion/textual_inversion.py b/configs/textual_inversion/textual_inversion.py
new file mode 100644
index 0000000000..9a62bd55d8
--- /dev/null
+++ b/configs/textual_inversion/textual_inversion.py
@@ -0,0 +1,85 @@
+_base_ = '../_base_/gen_default_runtime.py'
+
+# config for model
+dtype = 'fp16'
+stable_diffusion_v15_url = 'runwayml/stable-diffusion-v1-5'
+
+placeholder_token = '<cat-toy>'
+initialize_token = 'toy'
+num_vectors_per_token = 1
+val_prompts = [
+    'a <cat-toy> on a backpack', 'a <cat-toy> on sofa',
+    'a <cat-toy> in swimming pool', 'a <cat-toy>'
+]
+
+model = dict(
+    type='TextualInversion',
+    placeholder_token=placeholder_token,
+    vae=dict(
+        type='AutoencoderKL',
+        from_pretrained=stable_diffusion_v15_url,
+        subfolder='vae'),
+    unet=dict(
+        type='UNet2DConditionModel',
+        from_pretrained=stable_diffusion_v15_url,
+        subfolder='unet'),
+    text_encoder=dict(
+        type='ClipWrapper',
+        clip_type='huggingface',
+        pretrained_model_name_or_path=stable_diffusion_v15_url,
+        subfolder='text_encoder'),
+    tokenizer=stable_diffusion_v15_url,
+    initialize_token=initialize_token,
+    num_vectors_per_token=num_vectors_per_token,
+    val_prompts=val_prompts,
+    scheduler=dict(
+        type='DDPMScheduler',
+        from_pretrained=stable_diffusion_v15_url,
+        subfolder='scheduler'),
+    test_scheduler=dict(
+        type='DDIMScheduler',
+        from_pretrained=stable_diffusion_v15_url,
+        subfolder='scheduler'),
+    data_preprocessor=dict(type='DataPreprocessor', data_keys=None))
+
+train_cfg = dict(max_iters=3000)
+
+optim_wrapper = dict(
+    modules='.*trainable_embeddings',
+    optimizer=dict(type='AdamW', lr=5e-4),
+    accumulative_counts=1)
+
+pipeline = [
+    dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
+    dict(type='Resize', scale=(512, 512)),
+    dict(type='PackInputs')
+]
+
+dataset = dict(
+    type='TextualInversionDataset',
+    data_root='./data/',
+    concept_dir='cat_toy',
+    placeholder=placeholder_token,
+    pipeline=pipeline)
+
+train_dataloader = dict(
+    dataset=dataset,
+    num_workers=16,
+    sampler=dict(type='InfiniteSampler', shuffle=True),
+    persistent_workers=True,
+    batch_size=1)
+val_cfg = val_evaluator = val_dataloader = None
+test_cfg = test_evaluator = test_dataloader = None
+
+default_hooks = dict(
+    logger=dict(interval=10),
+    checkpoint=dict(type='CheckpointHook', interval=10))
+custom_hooks = [
+    dict(
+        type='VisualizationHook',
+        interval=50,
+        fixed_input=True,
+        # visualize train dataset
+        vis_kwargs_list=dict(type='Data', name='fake_img'),
+        n_samples=1)
+]
diff --git a/mmagic/datasets/__init__.py b/mmagic/datasets/__init__.py
index 98c17e508b..240c68b2e5 100644
--- a/mmagic/datasets/__init__.py
+++ b/mmagic/datasets/__init__.py
@@ -11,6 +11,7 @@
 from .mscoco_dataset import MSCoCoDataset
 from .paired_image_dataset import PairedImageDataset
 from .singan_dataset import SinGANDataset
+from .textual_inversion_dataset import TextualInversionDataset
 from .unpaired_image_dataset import UnpairedImageDataset
 
 __all__ = [
@@ -18,5 +19,5 @@
     'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset',
     'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'SinGANDataset',
     'MSCoCoDataset', 'ControlNetDataset', 'DreamBoothDataset',
-    'ControlNetDataset', 'SDFinetuneDataset'
+    'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset'
 ]
diff --git a/mmagic/datasets/textual_inversion_dataset.py b/mmagic/datasets/textual_inversion_dataset.py
new file mode 100644
index 0000000000..665c06e27f
--- /dev/null
+++ b/mmagic/datasets/textual_inversion_dataset.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
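The dataset module that follows pairs every image in the concept folder with a caption built from a prompt template. When the concept is a style rather than an object, the same dataset can be switched to the style templates via `is_style=True`; a hypothetical config snippet (the folder name and token below are made-up placeholders, not something this patch ships) might look like:

```python
dataset = dict(
    type='TextualInversionDataset',
    data_root='./data/',
    concept_dir='my_style_images',  # hypothetical folder of style images
    placeholder='<my-style>',       # hypothetical placeholder token
    is_style=True,                  # use the style templates instead of the object ones
    pipeline=pipeline)
```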
+import os.path as osp
+from random import choice
+from typing import Callable, List, Union
+
+from mmengine import FileClient
+from mmengine.dataset import BaseDataset
+
+from mmagic.registry import DATASETS
+
+imagenet_templates_small = [
+    'a photo of a {}',
+    'a rendering of a {}',
+    'a cropped photo of the {}',
+    'the photo of a {}',
+    'a photo of a clean {}',
+    'a photo of a dirty {}',
+    'a dark photo of the {}',
+    'a photo of my {}',
+    'a photo of the cool {}',
+    'a close-up photo of a {}',
+    'a bright photo of the {}',
+    'a cropped photo of a {}',
+    'a photo of the {}',
+    'a good photo of the {}',
+    'a photo of one {}',
+    'a close-up photo of the {}',
+    'a rendition of the {}',
+    'a photo of the clean {}',
+    'a rendition of a {}',
+    'a photo of a nice {}',
+    'a good photo of a {}',
+    'a photo of the nice {}',
+    'a photo of the small {}',
+    'a photo of the weird {}',
+    'a photo of the large {}',
+    'a photo of a cool {}',
+    'a photo of a small {}',
+]
+
+imagenet_style_templates_small = [
+    'a painting in the style of {}',
+    'a rendering in the style of {}',
+    'a cropped painting in the style of {}',
+    'the painting in the style of {}',
+    'a clean painting in the style of {}',
+    'a dirty painting in the style of {}',
+    'a dark painting in the style of {}',
+    'a picture in the style of {}',
+    'a cool painting in the style of {}',
+    'a close-up painting in the style of {}',
+    'a bright painting in the style of {}',
+    'a cropped painting in the style of {}',
+    'a good painting in the style of {}',
+    'a close-up painting in the style of {}',
+    'a rendition in the style of {}',
+    'a nice painting in the style of {}',
+    'a small painting in the style of {}',
+    'a weird painting in the style of {}',
+    'a large painting in the style of {}',
+]
+
+
+@DATASETS.register_module()
+class TextualInversionDataset(BaseDataset):
+    """Dataset for Textual Inversion.
+
+    Args:
+        data_root (str): Path to the data root.
+        concept_dir (str): Path to the concept images.
+        placeholder (str): The placeholder string that denotes the concept.
+        is_style (bool): Whether the concept is a style. If True, the style
+            templates are used to build prompts. Defaults to False.
+        pipeline (list[dict | callable]): A sequence of data transforms.
+    """
+
+    def __init__(self,
+                 data_root: str,
+                 concept_dir: str,
+                 placeholder: str,
+                 is_style: bool = False,
+                 pipeline: List[Union[dict, Callable]] = []):
+
+        data_prefix = dict(img_path=concept_dir)
+
+        self.placeholder = placeholder
+        if is_style:
+            self.template = imagenet_style_templates_small
+        else:
+            self.template = imagenet_templates_small
+
+        super().__init__(
+            data_root=data_root, data_prefix=data_prefix, pipeline=pipeline)
+
+    def load_data_list(self) -> list:
+        """Load the data list from ``concept_dir``."""
+        data_list = []
+
+        img_dir = self.data_prefix['img_path']
+        file_client = FileClient.infer_client(uri=img_dir)
+        img_dir = osp.abspath(img_dir)
+
+        for data_name in file_client.list_dir_or_file(
+                img_dir, list_dir=False):
+            data_info = dict(
+                img_path=file_client.join_path(img_dir, data_name))
+            data_list.append(data_info)
+        return data_list
+
+    def prepare_data(self, idx):
+        """Get data processed by ``self.pipeline``.
+
+        Args:
+            idx (int): The index of ``data_info``.
+
+        Returns:
+            Any: Depends on ``self.pipeline``.
+ """ + data_info = self.get_data_info(idx) + # load random template + selected_template = choice(self.template) + prompt = selected_template.format(self.placeholder) + data_info['prompt'] = prompt + return self.pipeline(data_info) diff --git a/mmagic/models/editors/__init__.py b/mmagic/models/editors/__init__.py index e6cd7679b3..95499b9d53 100644 --- a/mmagic/models/editors/__init__.py +++ b/mmagic/models/editors/__init__.py @@ -58,6 +58,7 @@ from .stylegan3 import StyleGAN3, StyleGAN3Generator from .swinir import SwinIRNet from .tdan import TDAN, TDANNet +from .textual_inversion import TextualInversion from .tof import TOFlowVFINet, TOFlowVSRNet, ToFResBlock from .ttsr import LTE, TTSR, SearchTransformer, TTSRDiscriminator, TTSRNet from .wgan_gp import WGANGP @@ -88,5 +89,5 @@ 'StyleGAN3Generator', 'InstColorization', 'NAFBaseline', 'NAFBaselineLocal', 'NAFNet', 'NAFNetLocal', 'DenoisingUnet', 'ClipWrapper', 'EG3D', 'Restormer', 'SwinIRNet', 'StableDiffusion', - 'ControlStableDiffusion', 'DreamBooth' + 'ControlStableDiffusion', 'DreamBooth', 'TextualInversion' ] diff --git a/mmagic/models/editors/dreambooth/dreambooth.py b/mmagic/models/editors/dreambooth/dreambooth.py index a711f16d8b..9639e2cf86 100644 --- a/mmagic/models/editors/dreambooth/dreambooth.py +++ b/mmagic/models/editors/dreambooth/dreambooth.py @@ -29,8 +29,6 @@ class DreamBooth(StableDiffusion): encoder. tokenizer (str): The **name** for CLIP tokenizer. unet (Union[dict, nn.Module]): The config or module for Unet model. - controlnet (Union[dict, nn.Module]): The config or module for - ControlNet. schedule (Union[dict, nn.Module]): The config or module for diffusion scheduler. test_scheduler (Union[dict, nn.Module], optional): The config or @@ -54,6 +52,10 @@ class DreamBooth(StableDiffusion): noise_offset_weight (bool, optional): The weight of noise offset introduced in https://www.crosslabs.org/blog/diffusion-with-offset-noise # noqa Defaults to 0. + tomesd_cfg (dict, optional): The config for TOMESD. Please refers to + https://github.com/dbolya/tomesd and + https://github.com/open-mmlab/mmagic/blob/main/mmagic/models/utils/tome_utils.py for detail. # noqa + Defaults to None. data_preprocessor (dict, optional): The pre-process config of :class:`BaseDataPreprocessor`. Defaults to dict(type='DataPreprocessor'). diff --git a/mmagic/models/editors/textual_inversion/__init__.py b/mmagic/models/editors/textual_inversion/__init__.py new file mode 100644 index 0000000000..b660ed260b --- /dev/null +++ b/mmagic/models/editors/textual_inversion/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .textual_inversion import TextualInversion + +__all__ = ['TextualInversion'] diff --git a/mmagic/models/editors/textual_inversion/textual_inversion.py b/mmagic/models/editors/textual_inversion/textual_inversion.py new file mode 100644 index 0000000000..7057583f7f --- /dev/null +++ b/mmagic/models/editors/textual_inversion/textual_inversion.py @@ -0,0 +1,259 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.logging import MMLogger
+
+from mmagic.registry import MODELS
+from mmagic.structures import DataSample
+from mmagic.utils.typing import SampleList
+from ..stable_diffusion.stable_diffusion import StableDiffusion
+
+logger = MMLogger.get_current_instance()
+
+ModelType = Union[Dict, nn.Module]
+
+
+@MODELS.register_module()
+class TextualInversion(StableDiffusion):
+    """Implementation of `An Image is Worth One Word: Personalizing
+    Text-to-Image Generation using Textual Inversion
+    <https://arxiv.org/abs/2208.01618>`_ (Textual Inversion).
+
+    Args:
+        placeholder_token (str): The placeholder token to add to the
+            tokenizer, whose embedding will be learned.
+        vae (Union[dict, nn.Module]): The config or module for VAE model.
+        text_encoder (Union[dict, nn.Module]): The config or module for text
+            encoder.
+        tokenizer (str): The **name** for CLIP tokenizer.
+        unet (Union[dict, nn.Module]): The config or module for Unet model.
+        scheduler (Union[dict, nn.Module]): The config or module for diffusion
+            scheduler.
+        test_scheduler (Union[dict, nn.Module], optional): The config or
+            module for diffusion scheduler in test stage (`self.infer`). If
+            not passed, will use the same scheduler as `scheduler`.
+            Defaults to None.
+        dtype (str, optional): The dtype for the model. Defaults to 'fp16'.
+        enable_xformers (bool, optional): Whether to use xformers.
+            Defaults to True.
+        noise_offset_weight (float, optional): The weight of noise offset
+            introduced in https://www.crosslabs.org/blog/diffusion-with-offset-noise  # noqa
+            Defaults to 0.
+        tomesd_cfg (dict, optional): The config for TOMESD. Please refer to
+            https://github.com/dbolya/tomesd and
+            https://github.com/open-mmlab/mmagic/blob/main/mmagic/models/utils/tome_utils.py for detail.  # noqa
+            Defaults to None.
+        initialize_token (str, optional): The initialization token for the
+            textual embedding to train. Defaults to None.
+        num_vectors_per_token (int): The length of the learnable embedding.
+            Defaults to 1.
+        val_prompts (Union[str, List[str]], optional): The prompts for
+            validation. Defaults to None.
+        data_preprocessor (dict, optional): The pre-process config of
+            :class:`BaseDataPreprocessor`. Defaults to
+            dict(type='DataPreprocessor').
+        init_cfg (dict, optional): The weight initialized config for
+            :class:`BaseModule`. Defaults to None.
+    """
+
+    def __init__(self,
+                 placeholder_token: str,
+                 vae: ModelType,
+                 text_encoder: ModelType,
+                 tokenizer: str,
+                 unet: ModelType,
+                 scheduler: ModelType,
+                 test_scheduler: Optional[ModelType] = None,
+                 dtype: Optional[str] = None,
+                 enable_xformers: bool = True,
+                 noise_offset_weight: float = 0,
+                 tomesd_cfg: Optional[dict] = None,
+                 initialize_token: Optional[str] = None,
+                 num_vectors_per_token: int = 1,
+                 val_prompts=None,
+                 data_preprocessor: Optional[ModelType] = dict(
+                     type='DataPreprocessor'),
+                 init_cfg: Optional[dict] = None):
+
+        super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
+                         test_scheduler, dtype, enable_xformers,
+                         noise_offset_weight, tomesd_cfg, data_preprocessor,
+                         init_cfg)
+
+        self.val_prompts = val_prompts
+        self.placeholder_token = placeholder_token
+        self.add_tokens(placeholder_token, initialize_token,
+                        num_vectors_per_token)
+        self.prepare_models()
+
+    def prepare_models(self):
+        """Disable gradient for untrainable modules to save memory."""
+        self.vae.requires_grad_(False)
+        self.unet.requires_grad_(False)
+        self.text_encoder.set_only_embedding_trainable()
+
+    @torch.no_grad()
+    def val_step(self, data: dict) -> SampleList:
+        """Gets the generated image of given data. Calls
+        ``self.data_preprocessor`` and ``self.infer`` in order. Return the
+        generated results which will be passed to evaluator or visualizer.
+
+        Args:
+            data (dict or tuple or list): Data sampled from dataset.
+
+        Returns:
+            SampleList: Generated image or image dict.
+        """
+        data = self.data_preprocessor(data)
+        data_samples = data['data_samples']
+        if self.val_prompts is None:
+            prompt = data_samples.prompt
+        else:
+            prompt = self.val_prompts
+            # construct a fake data_sample for destruct
+            data_samples.split() * len(prompt)
+            data_samples = DataSample.stack(data_samples.split() * len(prompt))
+
+        unet_dtype = next(self.unet.parameters()).dtype
+        self.unet.to(self.dtype)
+
+        output = self.infer(prompt, return_type='tensor')
+        samples = output['samples']
+
+        self.unet.to(unet_dtype)
+
+        samples = self.data_preprocessor.destruct(samples, data_samples)
+
+        out_data_sample = DataSample(fake_img=samples, prompt=prompt)
+        data_sample_list = out_data_sample.split()
+        return data_sample_list
+
+    @torch.no_grad()
+    def test_step(self, data: dict) -> SampleList:
+        """Gets the generated image of given data. Calls
+        ``self.data_preprocessor`` and ``self.infer`` in order. Return the
+        generated results which will be passed to evaluator or visualizer.
+
+        Args:
+            data (dict or tuple or list): Data sampled from dataset.
+
+        Returns:
+            SampleList: Generated image or image dict.
+        """
+        if self.val_prompts is None:
+            data = self.data_preprocessor(data)
+            data_samples = data['data_samples']
+            prompt = data_samples.prompt
+        else:
+            prompt = self.val_prompts
+            # construct a fake data_sample for destruct
+            data_samples = DataSample.stack(data['data_samples'] * len(prompt))
+
+        unet_dtype = next(self.unet.parameters()).dtype
+        self.unet.to(self.dtype)
+
+        output = self.infer(prompt, return_type='tensor')
+        samples = output['samples']
+
+        self.unet.to(unet_dtype)
+
+        samples = self.data_preprocessor.destruct(samples, data_samples)
+
+        out_data_sample = DataSample(fake_img=samples, prompt=prompt)
+        data_sample_list = out_data_sample.split()
+        return data_sample_list
+
+    def add_tokens(self,
+                   placeholder_token: str,
+                   initialize_token: str = None,
+                   num_vectors_per_token: int = 1):
+        """Add the placeholder token for training.
+
+        # TODO: support add tokens as dict, then we can load pretrained tokens.
+        """
+        self.tokenizer.add_placeholder_token(
+            placeholder_token, num_vec_per_token=num_vectors_per_token)
+
+        self.text_encoder.set_embedding_layer()
+        embedding_layer = self.text_encoder.get_embedding_layer()
+        assert embedding_layer is not None, (
+            'Getting the embedding layer is not supported for the current '
+            'text encoder. Please check your configuration.')
+
+        if initialize_token:
+            init_id = self.tokenizer(initialize_token).input_ids[1]
+            initialize_embedding = embedding_layer.weight[init_id]
+            initialize_embedding = initialize_embedding[None, ...].repeat(
+                num_vectors_per_token, 1)
+        else:
+            emb_dim = embedding_layer.weight.shape[1]
+            initialize_embedding = torch.zeros(num_vectors_per_token, emb_dim)
+
+        token_info = self.tokenizer.get_token_info(placeholder_token)
+        token_info['embedding'] = initialize_embedding
+        token_info['trainable'] = True
+        self.token_info = token_info
+        embedding_layer.add_embeddings(token_info)
+
+    def train_step(self, data, optim_wrapper):
+        """Training step."""
+        data = self.data_preprocessor(data)
+        inputs, data_samples = data['inputs'], data['data_samples']
+
+        vae = self.vae.module if hasattr(self.vae, 'module') else self.vae
+        vae_dtype = next(vae.parameters()).dtype
+        unet_dtype = next(self.unet.parameters()).dtype
+
+        with optim_wrapper.optim_context(self.unet):
+            image = inputs  # image for new concept
+            prompt = data_samples.prompt
+            num_batches = image.shape[0]
+
+            image = image.to(vae_dtype)
+            latents = vae.encode(image).latent_dist.sample()
+            latents = latents * vae.config.scaling_factor
+
+            noise = torch.randn_like(latents)
+            timesteps = torch.randint(
+                0,
+                self.scheduler.num_train_timesteps, (num_batches, ),
+                device=self.device)
+            timesteps = timesteps.long()
+
+            noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
+
+            input_ids = self.tokenizer(
+                prompt,
+                max_length=self.tokenizer.model_max_length,
+                return_tensors='pt',
+                padding='max_length',
+                truncation=True)['input_ids'].to(self.device)
+
+            encoder_hidden_states = self.text_encoder(input_ids)[0]
+
+            if self.scheduler.config.prediction_type == 'epsilon':
+                gt = noise
+            elif self.scheduler.config.prediction_type == 'v_prediction':
+                gt = self.scheduler.get_velocity(latents, noise, timesteps)
+            else:
+                raise ValueError('Unknown prediction type '
+                                 f'{self.scheduler.config.prediction_type}')
+
+            model_output = self.unet(
+                noisy_latents.to(unet_dtype),
+                timesteps,
+                encoder_hidden_states=encoder_hidden_states.to(unet_dtype))
+            model_pred = model_output['sample']
+
+            loss_dict = dict()
+
+            # calculate loss in FP32
+            loss = F.mse_loss(model_pred.float(), gt.float())
+            loss_dict['loss'] = loss
+
+            parsed_loss, log_vars = self.parse_losses(loss_dict)
+            optim_wrapper.update_params(parsed_loss)
+
+        return log_vars
diff --git a/model-index.yml b/model-index.yml
index c2f648e40a..9f415004c6 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -49,6 +49,7 @@ Import:
 - configs/styleganv3/metafile.yml
 - configs/swinir/metafile.yml
 - configs/tdan/metafile.yml
+- configs/textual_inversion/metafile.yml
 - configs/tof/metafile.yml
 - configs/ttsr/metafile.yml
 - configs/wgan-gp/metafile.yml
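For a concrete sense of what `TextualInversionDataset.prepare_data` feeds the text encoder: each sample formats one randomly chosen template with the placeholder token. A minimal, runnable sketch (the two templates and the token below are just examples drawn from the lists defined in the dataset module):

```python
from random import choice

# Two of the object templates defined in textual_inversion_dataset.py.
templates = ['a photo of a {}', 'a rendering of a {}']
placeholder = '<cat-toy>'

# Each concept image is paired with a caption like this during training.
print(choice(templates).format(placeholder))  # e.g. "a photo of a <cat-toy>"
```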