VQ-diffusion (huggingface#658)
* Changes for VQ-diffusion VQVAE

Allow specifying the embedding dimension of `VQModel`:
By default, `VQModel` sets the dimension of its embeddings to the number
of latent channels. The VQ-diffusion VQVAE uses a smaller embedding
dimension (128) than its number of latent channels (256).

Add AttnDownEncoderBlock2D and AttnUpDecoderBlock2D to the down and up
unet block helpers; VQ-diffusion's VQVAE uses these two block types (see the sketch below).
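
For illustration, a minimal configuration sketch covering both changes. This is not copied from the diff; the keyword names, in particular the embedding-dimension argument written here as `vq_embed_dim`, are assumptions and may differ from the final API:

```python
# Hedged sketch: kwarg names, especially `vq_embed_dim`, are assumed for illustration.
from diffusers import VQModel

vqvae = VQModel(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "AttnDownEncoderBlock2D"),  # new attn encoder block
    up_block_types=("AttnUpDecoderBlock2D", "UpDecoderBlock2D"),        # new attn decoder block
    block_out_channels=(128, 256),
    latent_channels=256,   # channels produced by the encoder
    vq_embed_dim=128,      # smaller codebook embedding dimension (assumed kwarg name)
)
```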

* Changes for VQ-diffusion transformer

Modify attention.py so SpatialTransformer can be used for
VQ-diffusion's transformer.

SpatialTransformer:
- Can now operate over discrete inputs (classes of vector embeddings) as well as continuous.
- `in_channels` was made optional in the constructor, so the two call sites that passed it as a positional argument now pass it as a keyword argument
- modified forward pass to take optional timestep embeddings

ImagePositionalEmbeddings:
- added to provide positional embeddings to discrete inputs for latent pixels

BasicTransformerBlock:
- norm layers were made configurable so that VQ-diffusion can use AdaLayerNorm with timestep embeddings
- modified forward pass to take optional timestep embeddings

CrossAttention:
- now may optionally take a bias parameter for its query, key, and value linear layers

FeedForward:
- Internal layers are now configurable

ApproximateGELU:
- Activation function used in VQ-diffusion's feed-forward layers (see the sketch below)

AdaLayerNorm:
- Norm layer modified to incorporate timestep embeddings
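
To make the last two items concrete, here is a minimal sketch of what an AdaLayerNorm and the sigmoid-based ApproximateGELU can look like. It follows the descriptions above rather than the diff itself, and the argument names are illustrative:

```python
import torch
from torch import nn


class AdaLayerNorm(nn.Module):
    """LayerNorm whose scale and shift are predicted from a learned timestep embedding."""

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        # the affine parameters come from the timestep, so the LayerNorm itself has none
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        # timestep: 0-d LongTensor index of the current diffusion step
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = emb.chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift


class ApproximateGELU(nn.Module):
    """Linear projection followed by the x * sigmoid(1.702 * x) GELU approximation."""

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


# example: normalize a (batch, seq, dim) tensor conditioned on diffusion step 7
x = torch.randn(2, 16, 64)
ada_norm = AdaLayerNorm(embedding_dim=64, num_embeddings=100)
y = ada_norm(x, torch.tensor(7))
```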

* Add VQ-diffusion scheduler

* Add VQ-diffusion pipeline
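
Once the scheduler and pipeline are in place, usage is expected to look roughly like the following. The checkpoint id and call signature are assumptions based on the accompanying docs rather than part of this diff:

```python
# Hedged usage sketch; checkpoint id and call arguments are assumed.
import torch
from diffusers import VQDiffusionPipeline

pipe = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = pipe(prompt="teddy bear playing in the pool", num_inference_steps=100).images[0]
image.save("teddy_bear.png")
```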

* Add VQ-diffusion convert script to diffusers

* Add VQ-diffusion dummy objects

* Add VQ-diffusion markdown docs

* Add VQ-diffusion tests

* some renaming

* some fixes

* more renaming

* correct

* fix typo

* correct weights

* finalize

* fix tests

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* finish

* finish

* up

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
4 people authored Nov 3, 2022
1 parent e211af5 commit 7da0950
Showing 14 changed files with 1,274 additions and 161 deletions.
4 changes: 3 additions & 1 deletion __init__.py
@@ -18,7 +18,7 @@

if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
@@ -38,6 +38,7 @@
PNDMPipeline,
RePaintPipeline,
ScoreSdeVePipeline,
VQDiffusionPipeline,
)
from .schedulers import (
DDIMScheduler,
@@ -50,6 +51,7 @@
RePaintScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
VQDiffusionScheduler,
)
from .training_utils import EMAModel
else:
1 change: 1 addition & 0 deletions models/__init__.py
@@ -16,6 +16,7 @@


if is_torch_available():
from .attention import Transformer2DModel
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
444 changes: 330 additions & 114 deletions models/attention.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion models/attention_flax.py
@@ -142,7 +142,7 @@ def __call__(self, hidden_states, context, deterministic=True):
return hidden_states


class FlaxSpatialTransformer(nn.Module):
class FlaxTransformer2DModel(nn.Module):
r"""
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
https://arxiv.org/pdf/1506.02025.pdf
65 changes: 65 additions & 0 deletions models/embeddings.py
@@ -126,3 +126,68 @@ def forward(self, x):
        else:
            out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return out


class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
    height and width of the latent space.

    For more details, see figure 10 of the DALL-E paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different from the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixel embeddings.
        height (`int`):
            Height of the latent image, i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image, i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)
    def forward(self, index):
        emb = self.emb(index)

        height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))

        # 1 x H x D -> 1 x H x 1 x D
        height_emb = height_emb.unsqueeze(2)

        width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))

        # 1 x W x D -> 1 x 1 x W x D
        width_emb = width_emb.unsqueeze(1)

        pos_emb = height_emb + width_emb

        # 1 x H x W x D -> 1 x L x D
        pos_emb = pos_emb.view(1, self.height * self.width, -1)

        emb = emb + pos_emb[:, : emb.shape[1], :]

        return emb
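
A quick usage sketch for the module above; the sizes are illustrative, not taken from a real checkpoint:

```python
import torch

num_embed, height, width, embed_dim = 512, 32, 32, 192
pos_embed = ImagePositionalEmbeddings(num_embed=num_embed, height=height, width=width, embed_dim=embed_dim)

# a batch of flattened latent images, each "pixel" being a class index in [0, num_embed)
latent_pixels = torch.randint(0, num_embed, (4, height * width))
hidden_states = pos_embed(latent_pixels)  # (4, height * width, embed_dim)
```
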
94 changes: 62 additions & 32 deletions models/unet_2d_blocks.py
@@ -15,7 +15,7 @@
import torch
from torch import nn

from .attention import AttentionBlock, SpatialTransformer
from .attention import AttentionBlock, Transformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D


@@ -109,6 +109,19 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
)
elif down_block_type == "AttnDownEncoderBlock2D":
return AttnDownEncoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
@@ -200,6 +213,17 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
)
elif up_block_type == "AttnUpDecoderBlock2D":
return AttnUpDecoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{up_block_type} does not exist.")


@@ -249,7 +273,7 @@ def __init__(
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
norm_num_groups=resnet_groups,
)
)
resnets.append(
@@ -325,13 +349,13 @@

for _ in range(num_layers):
attentions.append(
SpatialTransformer(
in_channels,
Transformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
resnets.append(
@@ -374,7 +398,7 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states)
hidden_states = attn(hidden_states, encoder_hidden_states).sample
hidden_states = resnet(hidden_states, temb)

return hidden_states
@@ -427,7 +451,7 @@ def __init__(
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
norm_num_groups=resnet_groups,
)
)

@@ -506,13 +530,13 @@ def __init__(
)
)
attentions.append(
SpatialTransformer(
out_channels,
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -556,19 +580,22 @@ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:

def create_custom_forward(module):
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
return module(*inputs)
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn), hidden_states, encoder_hidden_states
)
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

output_states += (hidden_states,)

@@ -763,7 +790,7 @@ def __init__(
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
norm_num_groups=resnet_groups,
)
)

@@ -1014,7 +1041,7 @@ def __init__(
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
norm_num_groups=resnet_groups,
)
)

@@ -1089,13 +1116,13 @@ def __init__(
)
)
attentions.append(
SpatialTransformer(
out_channels,
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -1145,19 +1172,22 @@ def forward(

if self.training and self.gradient_checkpointing:

def create_custom_forward(module):
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
return module(*inputs)
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn), hidden_states, encoder_hidden_states
)
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -1337,7 +1367,7 @@ def __init__(
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
norm_num_groups=resnet_groups,
)
)

8 changes: 4 additions & 4 deletions models/unet_2d_blocks_flax.py
@@ -15,7 +15,7 @@
import flax.linen as nn
import jax.numpy as jnp

from .attention_flax import FlaxSpatialTransformer
from .attention_flax import FlaxTransformer2DModel
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D


@@ -63,7 +63,7 @@ def setup(self):
)
resnets.append(res_block)

attn_block = FlaxSpatialTransformer(
attn_block = FlaxTransformer2DModel(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
@@ -196,7 +196,7 @@ def setup(self):
)
resnets.append(res_block)

attn_block = FlaxSpatialTransformer(
attn_block = FlaxTransformer2DModel(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
@@ -326,7 +326,7 @@ def setup(self):
attentions = []

for _ in range(self.num_layers):
attn_block = FlaxSpatialTransformer(
attn_block = FlaxTransformer2DModel(
in_channels=self.in_channels,
n_heads=self.attn_num_head_channels,
d_head=self.in_channels // self.attn_num_head_channels,
