[Enhancement] Support Prompt-to-prompt, ddim inversion and null-text inversion #1908

Merged · 21 commits · Jul 12, 2023
129 changes: 129 additions & 0 deletions projects/prompt_to_prompt/README.md
# Inversions & Editing (DDIM Inversion & Null-Text Inversion & Prompt-to-Prompt Editing)

```
Author: @FerryHuang
```

This is an implementation of the following papers:

> [Prompt-to-Prompt Image Editing with Cross-Attention Control](https://prompt-to-prompt.github.io/ptp_files/Prompt-to-Prompt_preprint.pdf)

> [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/pdf/2211.09794.pdf)

> **Task**: Text2Image, diffusion, inversion, editing

<!-- [ALGORITHM] -->

## Abstract

<!-- [ABSTRACT] -->

Inversion of a diffusion model means feeding an image (with or without a prompt) into a method that returns a latent code, which can later be turned back into an image highly similar to the original. We want this latent code for editing purposes, which is why inversion methods are always implemented together with editing methods.

This project contains **two inversion methods** and **one editing method**.
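
Both inversion methods build on the same deterministic DDIM update, run in the forward (noising) direction: at each step the model's noise prediction is used to estimate `x_0`, which is then re-noised to the next timestep. A minimal sketch of one such step (the function name here is illustrative; the actual implementation is `next_step` in `inversions/ddim_inversion.py`):

```python
import torch


def ddim_inversion_step(noise_pred: torch.Tensor, sample: torch.Tensor,
                        alpha_prod_t: float, alpha_prod_t_next: float) -> torch.Tensor:
    """One forward DDIM step: estimate x_0, then re-noise it one step further."""
    pred_x0 = (sample - (1 - alpha_prod_t)**0.5 * noise_pred) / alpha_prod_t**0.5
    return alpha_prod_t_next**0.5 * pred_x0 + (1 - alpha_prod_t_next)**0.5 * noise_pred
```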

## From right to left: original image, DDIM inversion, null-text inversion

<center class="half">
<img src="https://github.com/FerryHuang/mmagic/assets/71176040/34d8a467-5378-41fb-83c6-b23c9dee8f0a" width="200"/><img src="https://github.com/FerryHuang/mmagic/assets/71176040/3d3814b4-7fb5-4232-a56f-fd7fef0ba28e" width="200"/><img src="https://github.com/FerryHuang/mmagic/assets/71176040/43008ed4-a5a3-4f81-ba9f-95d9e79e6a08" width="200"/>
</center>

## Prompt-to-prompt Editing

<div align="center">
<b>cat -> dog</b>
<br/>
<img src="https://github.com/FerryHuang/mmagic/assets/71176040/f5d3fc0c-aa7b-4525-9364-365b254d51ca" width="500"/>
</div>

<div align="center">
<b>spider man -> iron man (attention replace)</b>
<br/>
<img src="https://github.com/FerryHuang/mmagic/assets/71176040/074adbc6-bd48-4c82-99aa-f322cf937f5a" width="500"/>
</div>

<div align="center">
<b>Eiffel tower -> Eiffel tower at night (attention refine)</b>
<br/>
<img src="https://github.com/FerryHuang/mmagic/assets/71176040/f815dab3-b20c-4936-90e3-a060d3717e22" width="500"/>
</div>

<div align="center">
<b>blossom sakura tree -> blossom(-3) sakura tree (attention reweight)</b>
<br/>
<img src="https://github.com/FerryHuang/mmagic/assets/71176040/5ef770b9-4f28-4ae7-84b0-6c15ea7450e9" width="500"/>
</div>

## Quick Start

A walkthrough of the project is provided in [visualize.ipynb](visualize.ipynb), or you can run the following scripts to reproduce the results:

```python
# load the mmagic SD1.5
from mmengine import MODELS, Config
from mmengine.registry import init_default_scope

init_default_scope('mmagic')

config = 'configs/stable_diffusion/stable-diffusion_ddim_denoisingunet.py'
config = Config.fromfile(config).copy()

StableDiffuser = MODELS.build(config.model)
StableDiffuser = StableDiffuser.to('cuda')
```

```python
# null-text inversion
from inversions.null_text_inversion import NullTextInversion
from models import ptp_utils
from models.ptp import EmptyControl

image_path = 'projects/prompt_to_prompt/assets/gnochi_mirror.jpeg'
prompt = "a cat sitting next to a mirror"
image_tensor = ptp_utils.load_512(image_path).to('cuda')

null_inverter = NullTextInversion(StableDiffuser)
null_inverter.init_prompt(prompt)
ddim_latents = null_inverter.ddim_inversion(image_tensor)
x_t = ddim_latents[-1]
uncond_embeddings = null_inverter.null_optimization(ddim_latents, num_inner_steps=10, epsilon=1e-5)
null_text_rec, _ = ptp_utils.text2image_ldm_stable(StableDiffuser, [prompt], EmptyControl(), latent=x_t, uncond_embeddings=uncond_embeddings)
ptp_utils.view_images(null_text_rec)
```
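
Plain DDIM inversion (without the null-text optimization) can be run on its own and the image regenerated with the default unconditional embeddings; a minimal sketch reusing `StableDiffuser`, `prompt`, `image_tensor`, `EmptyControl` and `ptp_utils` from the snippets above (the rougher reconstruction it gives is the DDIM-inversion image in the comparison at the top):

```python
# DDIM inversion only: invert to the final latent, then regenerate without
# optimized null-text embeddings.
from inversions.ddim_inversion import DDIMInversion

ddim_inverter = DDIMInversion(StableDiffuser)
ddim_inverter.init_prompt(prompt)
x_t_ddim = ddim_inverter.ddim_inversion(image_tensor)[-1]
ddim_rec, _ = ptp_utils.text2image_ldm_stable(
    StableDiffuser, [prompt], EmptyControl(), latent=x_t_ddim, uncond_embeddings=None)
ptp_utils.view_images(ddim_rec)
```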

```python
# prompt-to-prompt editing
import torch

from models.ptp import LocalBlend, AttentionReplace
from models.ptp_utils import text2image_ldm_stable

prompts = ["A cartoon of spiderman",
           "A cartoon of ironman"]
g = torch.Generator().manual_seed(2023616)
lb = LocalBlend(prompts, ("spiderman", "ironman"), model=StableDiffuser)
controller = AttentionReplace(prompts, 50,
                              cross_replace_steps={"default_": 1., "ironman": .2},
                              self_replace_steps=0.4,
                              local_blend=lb, model=StableDiffuser)
images, x_t = text2image_ldm_stable(StableDiffuser, prompts, controller, latent=None,
                                    num_inference_steps=50, guidance_scale=7.5,
                                    uncond_embeddings=None, generator=g)
```
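
The two pieces compose into real-image editing: start from the `x_t` and `uncond_embeddings` produced by the null-text inversion snippet and pass them to the same editing call (note that the sampling example above reuses the name `x_t` for its own output, so run the inversion snippet first). A sketch with an illustrative edit prompt and word pair:

```python
# Edit the inverted real image: pair the source prompt with an edited prompt
# and keep the per-step null-text embeddings so the unedited content is
# reconstructed faithfully.
edit_prompts = ["a cat sitting next to a mirror",
                "a tiger sitting next to a mirror"]  # hypothetical edit
lb = LocalBlend(edit_prompts, ("cat", "tiger"), model=StableDiffuser)
controller = AttentionReplace(edit_prompts, 50,
                              cross_replace_steps={"default_": 1.},
                              self_replace_steps=0.4,
                              local_blend=lb, model=StableDiffuser)
edited, _ = text2image_ldm_stable(StableDiffuser, edit_prompts, controller, latent=x_t,
                                  num_inference_steps=50, guidance_scale=7.5,
                                  uncond_embeddings=uncond_embeddings)
ptp_utils.view_images(edited)
```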

## Citation

```bibtex
@article{hertz2022prompt,
  title   = {Prompt-to-Prompt Image Editing with Cross Attention Control},
  author  = {Hertz, Amir and Mokady, Ron and Tenenbaum, Jay and Aberman, Kfir and Pritch, Yael and Cohen-Or, Daniel},
  journal = {arXiv preprint arXiv:2208.01626},
  year    = {2022},
}

@article{mokady2022null,
  title   = {Null-text Inversion for Editing Real Images using Guided Diffusion Models},
  author  = {Mokady, Ron and Hertz, Amir and Aberman, Kfir and Pritch, Yael and Cohen-Or, Daniel},
  journal = {arXiv preprint arXiv:2211.09794},
  year    = {2022},
}
```
168 changes: 168 additions & 0 deletions projects/prompt_to_prompt/inversions/ddim_inversion.py
"""This code was originally taken from https://github.com/google/prompt-to-
prompt and modified by Ferry."""

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import numpy as np
import torch
from PIL import Image

from mmagic.models.diffusion_schedulers import EditDDIMScheduler


class DDIMInversion:

    def __init__(self, model, **kwargs):
        scheduler = EditDDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule='scaled_linear',
            clip_sample=False,
            set_alpha_to_one=False)
        self.model = model
        self.num_ddim_steps = kwargs.pop('num_ddim_steps', 50)
        self.guidance_scale = kwargs.pop('guidance_scale', 7.5)
        self.tokenizer = self.model.tokenizer
        self.model.scheduler = scheduler
        self.model.scheduler.set_timesteps(self.num_ddim_steps)
        self.prompt = None
        self.context = None

    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray],
                  timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        # One reverse (denoising) DDIM step: estimate x_0 from the noise
        # prediction, then move deterministically to the previous timestep.
        prev_timestep = timestep - (
            self.scheduler.num_train_timesteps //
            self.scheduler.num_inference_steps)
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = (
            self.scheduler.alphas_cumprod[prev_timestep]
            if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod)
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (
            sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_sample_direction = (1 - alpha_prod_t_prev)**0.5 * model_output
        prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample \
            + pred_sample_direction
        return prev_sample

    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray],
                  timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        # One forward (inversion) DDIM step: estimate x_0 from the noise
        # prediction, then re-noise it to the next timestep.
        timestep, next_timestep = min(
            timestep - self.scheduler.num_train_timesteps //
            self.scheduler.num_inference_steps, 999), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[
            timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (
            sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        next_sample_direction = (1 - alpha_prod_t_next)**0.5 * model_output
        next_sample = alpha_prod_t_next**0.5 * next_original_sample \
            + next_sample_direction
        return next_sample

    def get_noise_pred_single(self, latents, t, context):
        noise_pred = self.model.unet(
            latents, t, encoder_hidden_states=context)['sample']
        return noise_pred

    def get_noise_pred(self, latents, t, is_forward=True, context=None):
        # Classifier-free guidance: run the UNet on the unconditional and
        # conditional embeddings in one batch, then take a DDIM step in the
        # requested direction.
        latents_input = torch.cat([latents] * 2)
        if context is None:
            context = self.context
        guidance_scale = 1 if is_forward else self.guidance_scale
        noise_pred = self.model.unet(
            latents_input, t, encoder_hidden_states=context)['sample']
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_prediction_text - noise_pred_uncond)
        if is_forward:
            latents = self.next_step(noise_pred, t, latents)
        else:
            latents = self.prev_step(noise_pred, t, latents)
        return latents

    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        latents = 1 / 0.18215 * latents.detach()
        image = self.model.vae.decode(latents, return_dict=False)[0]
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        return image

    @torch.no_grad()
    def image2latent(self, image):
        with torch.no_grad():
            if isinstance(image, Image.Image):
                image = np.array(image)
            if isinstance(image, torch.Tensor) and image.dim() == 4:
                latents = image
            else:
                image = torch.from_numpy(image).float() / 127.5 - 1
                image = image.permute(2, 0,
                                      1).unsqueeze(0).to(self.model.device)
                latents = self.model.vae.encode(image)['latent_dist'].mean
                latents = latents * 0.18215
        return latents

    @torch.no_grad()
    def init_prompt(self, prompt: str):
        # Pre-compute the unconditional ('') and conditional text embeddings
        # used throughout the inversion.
        uncond_input = self.model.tokenizer(
            [''],
            padding='max_length',
            max_length=self.model.tokenizer.model_max_length,
            return_tensors='pt')
        uncond_embeddings = self.model.text_encoder(
            uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            [prompt],
            padding='max_length',
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors='pt',
        )
        text_embeddings = self.model.text_encoder(
            text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings, text_embeddings])
        self.prompt = prompt

    @torch.no_grad()
    def ddim_loop(self, latent):
        # Run the DDIM steps in the forward (noising) direction, collecting
        # the latent at every step.
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in range(self.num_ddim_steps):
            t = self.model.scheduler.timesteps[
                len(self.model.scheduler.timesteps) - i - 1]
            noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent

    @property
    def scheduler(self):
        return self.model.scheduler

    @torch.no_grad()
    def ddim_inversion(self, image):
        # Encode the image with the VAE, then invert the latent with DDIM.
        latent = self.model.vae.encode(image)['latent_dist'].mean
        latent = latent * 0.18215
        # image_rec = self.latent2image(latent)
        ddim_latents = self.ddim_loop(latent)
        return ddim_latents
48 changes: 48 additions & 0 deletions projects/prompt_to_prompt/inversions/null_text_inversion.py
import torch
import torch.nn.functional as F
from inversions.ddim_inversion import DDIMInversion
from torch.optim import Adam
from tqdm import tqdm


class NullTextInversion(DDIMInversion):

    # The only thing NullTextInversion adds on top of DDIMInversion is the
    # null_optimization method below.

    def null_optimization(self, latents, num_inner_steps, epsilon):
        # For every DDIM step, optimize the unconditional (null-text)
        # embedding so that the guided reverse step reproduces the DDIM
        # inversion trajectory.
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        uncond_embeddings_list = []
        latent_cur = latents[-1]
        bar = tqdm(total=num_inner_steps * self.num_ddim_steps)
        for i in range(self.num_ddim_steps):
            uncond_embeddings = uncond_embeddings.clone().detach()
            uncond_embeddings.requires_grad = True
            optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
            latent_prev = latents[len(latents) - i - 2]
            t = self.model.scheduler.timesteps[i]
            with torch.no_grad():
                noise_pred_cond = self.get_noise_pred_single(
                    latent_cur, t, cond_embeddings)
            for j in range(num_inner_steps):
                noise_pred_uncond = self.get_noise_pred_single(
                    latent_cur, t, uncond_embeddings)
                noise_pred = noise_pred_uncond + self.guidance_scale * (
                    noise_pred_cond - noise_pred_uncond)
                latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                loss = F.mse_loss(latents_prev_rec, latent_prev)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_item = loss.item()
                bar.update()
                if loss_item < epsilon + i * 2e-5:
                    break
            # Keep the progress bar consistent when the inner loop exits early.
            for j in range(j + 1, num_inner_steps):
                bar.update()
            uncond_embeddings_list.append(uncond_embeddings[:1].detach())
            with torch.no_grad():
                context = torch.cat([uncond_embeddings, cond_embeddings])
                latent_cur = self.get_noise_pred(latent_cur, t, False, context)
        bar.close()
        return uncond_embeddings_list