kakaobrain unCLIP (open-mmlab#1428)
* [wip] attention block updates

* [wip] unCLIP unet decoder and super res

* [wip] unCLIP prior transformer

* [wip] scheduler changes

* [wip] text proj utility class

* [wip] UnCLIPPipeline

* [wip] kakaobrain unCLIP convert script

* [unCLIP pipeline] fixes re: @patrickvonplaten

remove callbacks

move denoising loops into call function

* UNCLIPScheduler re: @patrickvonplaten

Revert changes to DDPMScheduler. Make UNCLIPScheduler, a modified
DDPM scheduler with changes to support Karlo

* mask -> attention_mask re: @patrickvonplaten

* [DDPMScheduler] remove leftover change

* [docs] PriorTransformer

* [docs] UNet2DConditionModel and UNet2DModel

* [nit] UNCLIPScheduler -> UnCLIPScheduler

matches existing unclip naming better

* [docs] SchedulingUnCLIP

* [docs] UnCLIPTextProjModel

* refactor

* finish licenses

* rename all to attention_mask and prep in models

* more renaming

* don't expose unused configs

* final renaming fixes

* remove x attn mask when not necessary

* configure kakao script to use new class embedding config

* fix copies

* [tests] UnCLIPScheduler

* finish x attn

* finish

* remove more

* rename condition blocks

* clean more

* Apply suggestions from code review

* up

* fix

* [tests] UnCLIPPipelineFastTests

* remove unused imports

* [tests] UnCLIPPipelineIntegrationTests

* correct

* make style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
williamberman and patrickvonplaten committed Dec 18, 2022
1 parent 402b956 commit 2dcf64b
Showing 21 changed files with 3,594 additions and 118 deletions.
1,159 changes: 1,159 additions & 0 deletions scripts/convert_kakao_brain_unclip_to_diffusers.py

Large diffs are not rendered by default.
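
The conversion script above translates the original kakaobrain Karlo unCLIP checkpoints into diffusers format; the diffs below wire the new model, scheduler, and pipeline classes into the library. A minimal usage sketch of the end result (the `kakaobrain/karlo-v1-alpha` Hub id, the fp16 cast, and the prompt are illustrative assumptions, not taken from this diff):

# Hedged sketch: assumes a converted checkpoint is published on the Hub under
# "kakaobrain/karlo-v1-alpha" and that a CUDA device is available.
import torch
from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# The pipeline runs the prior (text -> CLIP image embedding), the decoder UNet,
# and the super-resolution UNets internally.
image = pipe("a photograph of a corgi wearing sunglasses").images[0]
image.save("corgi.png")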

12 changes: 11 additions & 1 deletion src/diffusers/__init__.py
@@ -25,7 +25,15 @@
from .utils.dummy_pt_objects import * # noqa F403
else:
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .models import (
AutoencoderKL,
PriorTransformer,
Transformer2DModel,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
VQModel,
)
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
@@ -63,6 +71,7 @@
RePaintScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
UnCLIPScheduler,
VQDiffusionScheduler,
)
from .training_utils import EMAModel
@@ -96,6 +105,7 @@
StableDiffusionPipeline,
StableDiffusionPipelineSafe,
StableDiffusionUpscalePipeline,
UnCLIPPipeline,
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
@@ -17,6 +17,7 @@

if is_torch_available():
from .attention import Transformer2DModel
from .prior_transformer import PriorTransformer
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
124 changes: 94 additions & 30 deletions src/diffusers/models/attention.py

Large diffs are not rendered by default.

194 changes: 194 additions & 0 deletions src/diffusers/models/prior_transformer.py
@@ -0,0 +1,194 @@
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps


@dataclass
class PriorTransformerOutput(BaseOutput):
"""
Args:
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
"""

predicted_image_embedding: torch.FloatTensor


class PriorTransformer(ModelMixin, ConfigMixin):
"""
The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
transformer predicts the image embeddings through a denoising diffusion process.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
For more details, see the original paper: https://arxiv.org/abs/2204.06125
Parameters:
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
image embeddings and text embeddings are both the same dimension.
num_embeddings (`int`, *optional*, defaults to 77): The maximum number of CLIP embeddings allowed, i.e. the
length of the prompt after it has been tokenized.
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
additional_embeddings`.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""

@register_to_config
def __init__(
self,
num_attention_heads: int = 32,
attention_head_dim: int = 64,
num_layers: int = 20,
embedding_dim: int = 768,
num_embeddings=77,
additional_embeddings=4,
dropout: float = 0.0,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.additional_embeddings = additional_embeddings

self.time_proj = Timesteps(inner_dim, True, 0)
self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)

self.proj_in = nn.Linear(embedding_dim, inner_dim)

self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)

self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))

self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
activation_fn="gelu",
attention_bias=True,
)
for d in range(num_layers)
]
)

self.norm_out = nn.LayerNorm(inner_dim)
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)

causal_attention_mask = torch.full(
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
)
causal_attention_mask.triu_(1)
causal_attention_mask = causal_attention_mask[None, ...]
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))

def forward(
self,
hidden_states,
timestep: Union[torch.Tensor, float, int],
proj_embedding: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.BoolTensor] = None,
return_dict: bool = True,
):
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
x_t, the currently predicted image embeddings.
timestep (`torch.long`):
Current denoising step.
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
Projected embedding vector the denoising process is conditioned on.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
Hidden states of the text embeddings the denoising process is conditioned on.
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
Text mask for the text embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
tuple.
Returns:
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
[`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
batch_size = hidden_states.shape[0]

timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(hidden_states.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)

timesteps_projected = self.time_proj(timesteps)

# `self.time_proj` has no learned weights and always returns f32 tensors,
# but `self.time_embedding` can be fp16, so we cast here.
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
time_embeddings = self.time_embedding(timesteps_projected)

proj_embeddings = self.embedding_proj(proj_embedding)
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
hidden_states = self.proj_in(hidden_states)
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

hidden_states = torch.cat(
[
encoder_hidden_states,
proj_embeddings[:, None, :],
time_embeddings[:, None, :],
hidden_states[:, None, :],
prd_embedding,
],
dim=1,
)

hidden_states = hidden_states + positional_embeddings

if attention_mask is not None:
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)

for block in self.transformer_blocks:
hidden_states = block(hidden_states, attention_mask=attention_mask)

hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states[:, -1]
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

if not return_dict:
return (predicted_image_embedding,)

return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)

def post_process_latents(self, prior_latents):
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
return prior_latents
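
The forward pass above concatenates the projected text hidden states, the projected conditioning embedding, the time embedding, the noised image embedding, and a learned `prd` token into one sequence, then reads the prediction off the final token. A shape-level sketch of a single call (toy hyperparameters and random tensors; the real denoising loop lives in `UnCLIPPipeline`):

# Hedged sketch: exercises PriorTransformer.forward with tiny sizes to show the
# expected tensor shapes. Values are illustrative, not the Karlo config.
import torch
from diffusers import PriorTransformer

prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=4,
    num_layers=2,
    embedding_dim=8,
    num_embeddings=7,
    additional_embeddings=4,
)

batch = 2
hidden_states = torch.randn(batch, 8)             # x_t: the noised CLIP image embedding
timestep = torch.tensor(10)                       # current denoising step
proj_embedding = torch.randn(batch, 8)            # projected conditioning embedding
encoder_hidden_states = torch.randn(batch, 7, 8)  # per-token text hidden states
attention_mask = torch.ones(batch, 7, dtype=torch.bool)

out = prior(hidden_states, timestep, proj_embedding, encoder_hidden_states, attention_mask)
print(out.predicted_image_embedding.shape)        # torch.Size([2, 8])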
16 changes: 15 additions & 1 deletion src/diffusers/models/resnet.py
@@ -405,7 +405,14 @@ def __init__(
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

if temb_channels is not None:
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
if self.time_embedding_norm == "default":
time_emb_proj_out_channels = out_channels
elif self.time_embedding_norm == "scale_shift":
time_emb_proj_out_channels = out_channels * 2
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
else:
self.time_emb_proj = None

@@ -465,9 +472,16 @@ def forward(self, input_tensor, temb):

if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]

if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb

hidden_states = self.norm2(hidden_states)

if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift

hidden_states = self.nonlinearity(hidden_states)

hidden_states = self.dropout(hidden_states)
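
The new `scale_shift` branch is an AdaGN-style (scale-and-shift) conditioning: `time_emb_proj` now outputs twice the channel count, which is split into a per-channel scale and shift applied right after `norm2`. A standalone sketch of just that step (illustrative shapes, not the library code path):

# Hedged sketch of the "scale_shift" conditioning applied inside ResnetBlock2D.
import torch

batch, channels, height, width = 2, 8, 16, 16
hidden_states = torch.randn(batch, channels, height, width)  # stands in for the output of norm2
temb = torch.randn(batch, channels * 2)[:, :, None, None]    # time_emb_proj output, broadcast over H and W

scale, shift = torch.chunk(temb, 2, dim=1)                   # (batch, channels, 1, 1) each
hidden_states = hidden_states * (1 + scale) + shift
print(hidden_states.shape)                                   # torch.Size([2, 8, 16, 16])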
11 changes: 10 additions & 1 deletion src/diffusers/models/unet_2d.py
@@ -55,6 +55,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
@@ -66,6 +68,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
"""

@register_to_config
Expand All @@ -88,6 +92,8 @@ def __init__(
attention_head_dim: int = 8,
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
resnet_time_scale_shift: str = "default",
add_attention: bool = True,
):
super().__init__()

@@ -130,6 +136,7 @@ def __init__(
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)

@@ -140,9 +147,10 @@ def __init__(
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
resnet_time_scale_shift=resnet_time_scale_shift,
attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups,
add_attention=add_attention,
)

# up
@@ -167,6 +175,7 @@ def __init__(
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
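
Two new constructor options are threaded through this file: `resnet_time_scale_shift`, which selects the resnet time-conditioning mode for the down, mid, and up blocks, and `add_attention`, which controls whether the mid block keeps its self-attention layer. A hedged construction sketch with toy sizes (not the Karlo decoder or super-resolution config):

# Hedged sketch: building a small UNet2DModel with the options added in this diff.
import torch
from diffusers import UNet2DModel

unet = UNet2DModel(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    resnet_time_scale_shift="scale_shift",  # new: scale-and-shift time conditioning in the resnet blocks
    add_attention=False,                    # new: mid block without self-attention
)

sample = torch.randn(1, 3, 64, 64)
out = unet(sample, timestep=1).sample
print(out.shape)  # torch.Size([1, 3, 64, 64])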