From 13d723556424cd1958b54c69e0647254106f8ad7 Mon Sep 17 00:00:00 2001 From: William Berman Date: Sun, 23 Oct 2022 20:00:05 -0700 Subject: [PATCH 01/21] Changes for VQ-diffusion VQVAE Add specify dimension of embeddings to VQModel: `VQModel` will by default set the dimension of embeddings to the number of latent channels. The VQ-diffusion VQVAE has a smaller embedding dimension, 128, than number of latent channels, 256. Add AttnDownEncoderBlock2D and AttnUpDecoderBlock2D to the up and down unet block helpers. VQ-diffusion's VQVAE uses those two block types. --- src/diffusers/models/unet_2d_blocks.py | 24 ++++++++++++++++++++++++ src/diffusers/models/vae.py | 11 ++++++----- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index f4081c5c1cac..838085613627 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -109,6 +109,19 @@ def get_down_block( resnet_groups=resnet_groups, downsample_padding=downsample_padding, ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + raise ValueError(f"{down_block_type} does not exist.") def get_up_block( @@ -200,6 +213,17 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attn_num_head_channels=attn_num_head_channels, + ) raise ValueError(f"{up_block_type} does not exist.") diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 5f5a47dada0f..220e8869f8bc 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -425,6 +425,7 @@ def __init__( sample_size: int = 32, num_vq_embeddings: int = 256, norm_num_groups: int = 32, + e_dim: Optional[int] = None, ): super().__init__() @@ -440,11 +441,11 @@ def __init__( double_z=False, ) - self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) - self.quantize = VectorQuantizer( - num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False - ) - self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + e_dim = e_dim if e_dim is not None else latent_channels + + self.quant_conv = torch.nn.Conv2d(latent_channels, e_dim, 1) + self.quantize = VectorQuantizer(num_vq_embeddings, e_dim, beta=0.25, remap=None, sane_index_shape=False) + self.post_quant_conv = torch.nn.Conv2d(e_dim, latent_channels, 1) # pass init params to Decoder self.decoder = Decoder( From 41db6aaa34fde627883a6f71553e838396633553 Mon Sep 17 00:00:00 2001 From: William Berman Date: Sun, 23 Oct 2022 20:23:29 -0700 Subject: [PATCH 02/21] Changes for VQ-diffusion transformer Modify attention.py so SpatialTransformer can be used for VQ-diffusion's transformer. SpatialTransformer: - Can now operate over discrete inputs (classes of vector embeddings) as well as continuous. 
- `in_channels` was made optional in the constructor so two locations where it was passed as a positional arg were moved to kwargs - modified forward pass to take optional timestep embeddings ImagePositionalEmbeddings: - added to provide positional embeddings to discrete inputs for latent pixels BasicTransformerBlock: - norm layers were made configurable so that the VQ-diffusion could use AdaLayerNorm with timestep embeddings - modified forward pass to take optional timestep embeddings CrossAttention: - now may optionally take a bias parameter for its query, key, and value linear layers FeedForward: - Internal layers are now configurable ApproximateGELU: - Activation function in VQ-diffusion's feedforward layer AdaLayerNorm: - Norm layer modified to incorporate timestep embeddings --- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 1 + src/diffusers/models/attention.py | 305 +++++++++++++++++++++---- src/diffusers/models/embeddings.py | 65 ++++++ src/diffusers/models/unet_2d_blocks.py | 6 +- 5 files changed, 331 insertions(+), 48 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 49c3e82b8e7b..5a66450bd36c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -18,7 +18,7 @@ if is_torch_available(): from .modeling_utils import ModelMixin - from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel + from .models import AutoencoderKL, SpatialTransformer, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index c5d53b2feb4b..d303937e077e 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,6 +16,7 @@ if is_torch_available(): + from .attention import SpatialTransformer from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index af441ef86181..eb44c77c7461 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Optional +from typing import List, Optional import torch import torch.nn.functional as F from torch import nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.modeling_utils import ModelMixin +from diffusers.models.embeddings import ImagePositionalEmbeddings + class AttentionBlock(nn.Module): """ @@ -104,65 +108,170 @@ def forward(self, hidden_states): return hidden_states -class SpatialTransformer(nn.Module): +class SpatialTransformer(ModelMixin, ConfigMixin): """ - Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image. + Transformer block for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. 
Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. Parameters: - in_channels (:obj:`int`): The number of channels in the input and output. n_heads (:obj:`int`): The number of heads to use for multi-head attention. d_head (:obj:`int`): The number of channels in each head. + in_channels (: + obj:`int`, *optional*): Pass if the input is continuous. The number of channels in the input and output. depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. context_dim (:obj:`int`, *optional*): The number of context dimensions to use. + discrete (: + obj:`bool`, *optional*, defaults to False): Set to True if the input is discrete i.e. over classes of + vector embeddings for each pixel. See the beginning of the docstring for a more in-depth description. + height (:obj:`int`, *optional*): Pass if the input is discrete. The height of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + width (:obj:`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_embed (: + obj:`int`, *optional*): Pass if the input is discrete. The number of classes of the vector embeddings of + the latent pixels. Includes the class for the masked latent pixel. + ff_layers (:obj:,`List[Literal["Dropout", "Linear", "ApproximateGELU", "GEGLU"]]` *optional*): + The layers to use in the TransformerBlocks' FeedForward block. + norm_layers (:obj: `List[Literal["LayerNorm", "AdaLayerNorm"]]`, *optional*): + The norm layers to use for the TransformerBlocks. + diffusion_steps (:obj: `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than `diffusion_steps`. + attention_bias (: + obj: `bool`, *optional*): Configure if the TransformerBlocks' attention should contain a bias parameter. 
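+
+    A minimal construction sketch for the discrete (VQ-diffusion style) mode. The sizes and layer choices below
+    are illustrative placeholders rather than values from a released checkpoint:
+
+    ```py
+    transformer = SpatialTransformer(
+        n_heads=16,
+        d_head=64,
+        depth=2,
+        discrete=True,
+        height=32,
+        width=32,
+        num_embed=4097,
+        norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"],
+        ff_layers=["Linear", "ApproximateGELU", "Linear", "Dropout"],
+        diffusion_steps=100,
+        attention_bias=True,
+    )
+    ```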
""" + @register_to_config def __init__( self, - in_channels: int, n_heads: int, d_head: int, + in_channels: Optional[int] = None, depth: int = 1, dropout: float = 0.0, num_groups: int = 32, context_dim: Optional[int] = None, + discrete: bool = False, + height: Optional[int] = None, + width: Optional[int] = None, + num_embed: Optional[int] = None, + ff_layers: Optional[List[str]] = None, + norm_layers: Optional[List[str]] = None, + diffusion_steps: Optional[int] = None, + attention_bias: Optional[bool] = None, ): super().__init__() self.n_heads = n_heads self.d_head = d_head - self.in_channels = in_channels inner_dim = n_heads * d_head - self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + self.discrete = discrete + + if self.discrete: + assert height is not None, "SpatialTransformer over discrete input must provide height" + assert width is not None, "SpatialTransformer over discrete input must provide width" + assert num_embed is not None, "SpatialTransformer over discrete input must provide num_embed" + + self.height = height + self.width = width + self.num_embed = num_embed + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=self.num_embed, embed_dim=inner_dim, height=self.height, width=self.width + ) + else: + assert in_channels is not None, "SpatialTransformer over continuous input must provide in_channels" + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) self.transformer_blocks = nn.ModuleList( [ - BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + ff_layers=ff_layers, + diffusion_steps=diffusion_steps, + attention_bias=attention_bias, + norm_layers=norm_layers, + ) for d in range(depth) ] ) - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + if self.discrete: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_embed - 1) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) def _set_attention_slice(self, slice_size): for block in self.transformer_blocks: block._set_attention_slice(slice_size) - def forward(self, hidden_states, context=None): - # note: if no context is given, cross-attention defaults to self-attention - batch, channel, height, weight = hidden_states.shape - residual = hidden_states - hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + def forward(self, hidden_states, context=None, timestep=None): + """ + Args: + hidden_states (:obj: When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + context (:obj: `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. 
+ timestep (:obj: `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + Returns: + [`torch.FloatTensor` of shape `(batch size, num embed - 1, num latent pixels)`] if discrete or + [`torch.FloatTensor` of shape `(batch size, channel, height, width)`] if continuous : + If discrete, returns probability distributions for the unnoised latent pixels. Note that it does not + output a prediction for the masked class. + """ + if self.discrete: + hidden_states = self.latent_image_embedding(hidden_states) + else: + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=context) - hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) - hidden_states = self.proj_out(hidden_states) - return hidden_states + residual + hidden_states = block(hidden_states, context=context, timestep=timestep) + + if self.discrete: + logits = self.out(self.norm_out(hidden_states)) + # (batch, self.num_embed - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + return_value = F.log_softmax(logits.double(), dim=1).float() + else: + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) + hidden_states = self.proj_out(hidden_states) + return_value = hidden_states + residual + + return return_value class BasicTransformerBlock(nn.Module): @@ -177,6 +286,13 @@ class BasicTransformerBlock(nn.Module): context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. + ff_layers (:obj:,`List[Literal["Dropout", "Linear", "ApproximateGELU", "GEGLU"]]` *optional*): + The layers to use in the FeedForward block. + norm_layers (:obj: `List[Literal["LayerNorm", "AdaLayerNorm"]]`, *optional*): + The norm layers. Must be of length 3. Defaults to `["LayerNorm", "LayerNorm", "LayerNorm"]` + diffusion_steps (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `SpatialTransformer`. + attention_bias (:obj: `bool`, *optional*): Configure if the attentions should contain a bias parameter. 
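+
+    A minimal sketch of configuring the block with timestep-conditioned norms (the sizes are illustrative):
+
+    ```py
+    block = BasicTransformerBlock(
+        dim=1024,
+        n_heads=16,
+        d_head=64,
+        norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"],
+        diffusion_steps=100,
+        attention_bias=True,
+    )
+    # hidden_states: (batch size, sequence length, dim); timestep: scalar LongTensor
+    hidden_states = block(hidden_states, context=context, timestep=timestep)
+    ```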
""" def __init__( @@ -188,28 +304,58 @@ def __init__( context_dim: Optional[int] = None, gated_ff: bool = True, checkpoint: bool = True, + ff_layers: Optional[List[str]] = None, + norm_layers: Optional[List[str]] = None, + diffusion_steps: Optional[int] = None, + attention_bias: Optional[bool] = None, ): super().__init__() self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, bias=attention_bias ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, layers=ff_layers) self.attn2 = CrossAttention( - query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + bias=attention_bias, ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) + + norm_layers = ["LayerNorm", "LayerNorm", "LayerNorm"] if norm_layers is None else norm_layers + + assert len(norm_layers) == 3, "BasicTransformerBlock only supports 3 norm_layers" + + for idx, norm_layer in enumerate(norm_layers): + if norm_layer == "LayerNorm": + norm_layer_ = nn.LayerNorm(dim) + elif norm_layer == "AdaLayerNorm": + assert diffusion_steps is not None, "When using AdaLayerNorm, you must also pass diffusion_steps." + norm_layer_ = AdaLayerNorm(dim, diffusion_steps) + + if idx == 0: + self.norm1 = norm_layer_ + elif idx == 1: + self.norm2 = norm_layer_ + elif idx == 2: + self.norm3 = norm_layer_ + self.checkpoint = checkpoint def _set_attention_slice(self, slice_size): self.attn1._slice_size = slice_size self.attn2._slice_size = slice_size - def forward(self, hidden_states, context=None): - hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states - hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states - hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + def forward(self, hidden_states, context=None, timestep=None): + norm1_kwargs = {"timestep": timestep} if self.norm1.__class__ == AdaLayerNorm else {} + norm2_kwargs = {"timestep": timestep} if self.norm2.__class__ == AdaLayerNorm else {} + norm3_kwargs = {"timestep": timestep} if self.norm3.__class__ == AdaLayerNorm else {} + + hidden_states = self.attn1(self.norm1(hidden_states, **norm1_kwargs)) + hidden_states + hidden_states = self.attn2(self.norm2(hidden_states, **norm2_kwargs), context=context) + hidden_states + hidden_states = self.ff(self.norm3(hidden_states, **norm3_kwargs)) + hidden_states + return hidden_states @@ -224,10 +370,18 @@ class CrossAttention(nn.Module): heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (:obj:`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. 
""" def __init__( - self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 + self, + query_dim: int, + context_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=None, ): super().__init__() inner_dim = dim_head * heads @@ -240,9 +394,11 @@ def __init__( # You can set slice_size with `set_attention_slice` self._slice_size = None - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + bias = False if bias is None else bias + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(context_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(context_dim, inner_dim, bias=bias) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(inner_dim, query_dim)) @@ -352,22 +508,50 @@ class FeedForward(nn.Module): mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + layers (:obj:,`List[Literal["Dropout", "Linear", "ApproximateGELU", "GEGLU"]]` *optional*): + The list of layers to use. Note that the list must contain exactly two dimension changing layers (Linear + and GEGLU) but may contain as many non-dimension changing layers as you want (Dropout and ApproximateGELU). + The first dimension changing layer will project from the input dimension to the hidden dimension. The + second dimension changing layer will project from the hidden dimension to the output dimension. Defaults to + `["GEGLU", "Dropout", "Linear"]`. """ def __init__( - self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + glu: bool = False, + dropout: float = 0.0, + layers: Optional[List[str]] = None, ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim self.net = nn.ModuleList([]) - # project in - self.net.append(GEGLU(dim, inner_dim)) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(nn.Linear(inner_dim, dim_out)) + layers = ["GEGLU", "Dropout", "Linear"] if layers is None else layers + + dim_idx = 0 + dims = [[dim, inner_dim], [inner_dim, dim_out]] + + error_string = "FeedForward must have exactly two dimension changing layers (Linear and GEGLU)." + + for layer in layers: + if layer == "Dropout": + self.net.append(nn.Dropout(dropout)) + elif layer == "Linear": + assert dim_idx < 2, f"Too many dimension changes. {error_string}" + self.net.append(nn.Linear(*dims[dim_idx])) + dim_idx += 1 + elif layer == "ApproximateGELU": + self.net.append(ApproximateGELU()) + elif layer == "GEGLU": + assert dim_idx < 2, f"Too many dimension changes. {error_string}" + self.net.append(GEGLU(*dims[dim_idx])) + dim_idx += 1 + + assert dim_idx == 2, f"Too few dimension changes. 
{error_string}" def forward(self, hidden_states): for module in self.net: @@ -398,3 +582,36 @@ def gelu(self, gate): def forward(self, hidden_states): hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + """ + The approximate form of Gaussian Error Linear Unit (GELU) + + For more details, see section 2: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class AdaLayerNorm(nn.Module): + """ + Norm layer modified to incorporate timestep embeddings. + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x, timestep): + emb = self.linear(self.silu(self.emb(timestep))) + scale, shift = torch.chunk(emb, 2) + x = self.norm(x) * (1 + scale) + shift + return x diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 35715e17fc47..b09d43fc2edf 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -126,3 +126,68 @@ def forward(self, x): else: out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) return out + + +class ImagePositionalEmbeddings(nn.Module): + """ + Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the + height and width of the latent space. + + For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 + + For VQ-diffusion: + + Output vector embeddings are used as input for the transformer. + + Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. + + Args: + num_embed (`int`): + Number of embeddings for the latent pixels embeddings. + height (`int`): + Height of the latent image i.e. the number of height embeddings. + width (`int`): + Width of the latent image i.e. the number of width embeddings. + embed_dim (`int`): + Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. 
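+
+    A shape-only sketch (the sizes are illustrative):
+
+    ```py
+    emb = ImagePositionalEmbeddings(num_embed=4097, height=32, width=32, embed_dim=1024)
+    latent_classes = torch.randint(0, 4097, (1, 32 * 32))  # (batch size, num latent pixels)
+    hidden_states = emb(latent_classes)  # (1, 1024, 1024) == (batch size, height * width, embed_dim)
+    ```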
+ """ + + def __init__( + self, + num_embed: int, + height: int, + width: int, + embed_dim: int, + ): + super().__init__() + + self.height = height + self.width = width + self.num_embed = num_embed + self.embed_dim = embed_dim + + self.emb = nn.Embedding(self.num_embed, embed_dim) + self.height_emb = nn.Embedding(self.height, embed_dim) + self.width_emb = nn.Embedding(self.width, embed_dim) + + def forward(self, index): + emb = self.emb(index) + + height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) + + # 1 x H x D -> 1 x H x 1 x D + height_emb = height_emb.unsqueeze(2) + + width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) + + # 1 x W x D -> 1 x 1 x W x D + width_emb = width_emb.unsqueeze(1) + + pos_emb = height_emb + width_emb + + # 1 x H x W x D -> 1 x L xD + pos_emb = pos_emb.view(1, self.height * self.width, -1) + + emb = emb + pos_emb[:, : emb.shape[1], :] + + return emb diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 838085613627..234eebfd971b 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -350,9 +350,9 @@ def __init__( for _ in range(num_layers): attentions.append( SpatialTransformer( - in_channels, attn_num_head_channels, in_channels // attn_num_head_channels, + in_channels=in_channels, depth=1, context_dim=cross_attention_dim, num_groups=resnet_groups, @@ -527,9 +527,9 @@ def __init__( ) attentions.append( SpatialTransformer( - out_channels, attn_num_head_channels, out_channels // attn_num_head_channels, + in_channels=out_channels, depth=1, context_dim=cross_attention_dim, num_groups=resnet_groups, @@ -1106,9 +1106,9 @@ def __init__( ) attentions.append( SpatialTransformer( - out_channels, attn_num_head_channels, out_channels // attn_num_head_channels, + in_channels=out_channels, depth=1, context_dim=cross_attention_dim, num_groups=resnet_groups, From c36960d2b75f50f1822e4e1715282f2c70d63cbc Mon Sep 17 00:00:00 2001 From: William Berman Date: Sun, 23 Oct 2022 20:30:06 -0700 Subject: [PATCH 03/21] Add VQ-diffusion scheduler --- src/diffusers/__init__.py | 1 + src/diffusers/schedulers/__init__.py | 1 + .../schedulers/scheduling_vq_diffusion.py | 484 ++++++++++++++++++ 3 files changed, 486 insertions(+) create mode 100644 src/diffusers/schedulers/scheduling_vq_diffusion.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5a66450bd36c..83b2e852d903 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -48,6 +48,7 @@ PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler, + VQDiffusionScheduler, ) from .training_utils import EMAModel else: diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index c3999d2cac61..809ec3fdd915 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -27,6 +27,7 @@ from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_utils import SchedulerMixin + from .scheduling_vq_diffusion import VQDiffusionScheduler else: from ..utils.dummy_pt_objects import * # noqa F403 diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py new file mode 100644 index 000000000000..078c2baf6283 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -0,0 +1,484 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, 
Union + +import numpy as np +import torch +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class VQDiffusionSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.LongTensor + + +def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor: + """ + Convert batch of vector of class indices into batch of log onehot vectors + + Args: + x (`torch.LongTensor` of shape `(batch size, vector length)`): + Batch of class indices + + num_classes (`int`): + number of classes to be used for the onehot vectors + + Returns: + `torch.FloatTensor` of shape `(batch size, num classes, vector length)`: + Log onehot vectors + """ + x_onehot = F.one_hot(x, num_classes) + x_onehot = x_onehot.permute(0, 2, 1) + log_x = torch.log(x_onehot.float().clamp(min=1e-30)) + return log_x + + +def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor: + """ + Apply gumbel noise to `logits` + """ + uniform = torch.rand(logits.shape, device=logits.device, generator=generator) + gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) + noised = gumbel_noise + logits + return noised + + +def alpha_schedules(num_diffusion_timesteps: int, a_cumulative_start=0.99999, a_cumulative_end=0.000009): + """ + Cumulative and non-cumulative alpha schedules. + + See section 4.1. + """ + att = ( + np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (a_cumulative_end - a_cumulative_start) + + a_cumulative_start + ) + att = np.concatenate(([1], att)) + at = att[1:] / att[:-1] + att = np.concatenate((att[1:], [1])) + return at, att + + +def gamma_schedules(num_diffusion_timesteps: int, c_cumulative_start=0.000009, c_cumulative_end=0.99999): + """ + Cumulative and non-cumulative gamma schedules. + + See section 4.1. + """ + ctt = ( + np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (c_cumulative_end - c_cumulative_start) + + c_cumulative_start + ) + ctt = np.concatenate(([0], ctt)) + one_minus_ctt = 1 - ctt + one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1] + ct = 1 - one_minus_ct + ctt = np.concatenate((ctt[1:], [0])) + return ct, ctt + + +class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + The VQ-diffusion transformer outputs predicted probabilities of the initial unnoised image. + + The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous + diffusion timestep. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2111.14822 + + Args: + num_embed (`int`): + The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked + latent pixel. 
+ + num_train_timesteps (`int`): + Number of diffusion steps used to train the model. + + a_cumulative_start (`float`): + The starting cumulative alpha value. + + a_cumulative_end (`float`): + The ending cumulative alpha value. + + c_cumulative_start (`float`): + The starting cumulative gamma value. + + c_cumulative_end (`float`): + The ending cumulative gamma value. + """ + + @register_to_config + def __init__( + self, + num_embed: int, + num_train_timesteps: int = 100, + a_cumulative_start: float = 0.99999, + a_cumulative_end: float = 0.000009, + c_cumulative_start: float = 0.000009, + c_cumulative_end: float = 0.99999, + ): + self.num_embed = num_embed + + # By convention, the index for the mask class is the last class index + self.mask_class = self.num_embed - 1 + + at, att = alpha_schedules( + num_train_timesteps, a_cumulative_start=a_cumulative_start, a_cumulative_end=a_cumulative_end + ) + ct, ctt = gamma_schedules( + num_train_timesteps, c_cumulative_start=c_cumulative_start, c_cumulative_end=c_cumulative_end + ) + + num_non_mask_classes = self.num_embed - 1 + bt = (1 - at - ct) / num_non_mask_classes + btt = (1 - att - ctt) / num_non_mask_classes + + at = torch.tensor(at.astype("float64")) + bt = torch.tensor(bt.astype("float64")) + ct = torch.tensor(ct.astype("float64")) + log_at = torch.log(at) + log_bt = torch.log(bt) + log_ct = torch.log(ct) + + att = torch.tensor(att.astype("float64")) + btt = torch.tensor(btt.astype("float64")) + ctt = torch.tensor(ctt.astype("float64")) + log_cumprod_at = torch.log(att) + log_cumprod_bt = torch.log(btt) + log_cumprod_ct = torch.log(ctt) + + self.log_at = log_at.float() + self.log_bt = log_bt.float() + self.log_ct = log_ct.float() + self.log_cumprod_at = log_cumprod_at.float() + self.log_cumprod_bt = log_cumprod_bt.float() + self.log_cumprod_ct = log_cumprod_ct.float() + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + + device (`str` or `torch.device`): + device to place the timesteps and the diffusion process parameters (alpha, beta, gamma) on. + """ + self.num_inference_steps = num_inference_steps + timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.log_at = self.log_at.to(device) + self.log_bt = self.log_bt.to(device) + self.log_ct = self.log_ct.to(device) + self.log_cumprod_at = self.log_cumprod_at.to(device) + self.log_cumprod_bt = self.log_cumprod_bt.to(device) + self.log_cumprod_ct = self.log_cumprod_ct.to(device) + + def step( + self, + log_p_x_0: torch.FloatTensor, + t: torch.long, + x_t: torch.LongTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[VQDiffusionSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep via the reverse transition distribution i.e. Equation (11). See the + docstring for `self.q_posterior` for more in depth docs on how Equation (11) is computed. + + Args: + log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`): + The log probabilities for the predicted classes of the initial latent pixels. 
Does not include a + prediction for the masked class as the initial unnoised image cannot be masked. + + t (`torch.long`): + The timestep that determines which transition matrices are used. + + x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t` + + generator: (`torch.Generator` or None): + RNG for the noise applied to p(x_{t-1} | x_t) before it is sampled from. + + return_dict (`bool`): + option for returning tuple rather than VQDiffusionSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + if t == 0: + log_p_x_t_min_1 = log_p_x_0 + else: + log_p_x_t_min_1 = self.q_posterior(log_p_x_0, x_t, t) + + log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator) + + x_t_min_1 = log_p_x_t_min_1.argmax(dim=1) + + if not return_dict: + return (x_t_min_1,) + + return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1) + + def q_posterior(self, log_p_x_0, x_t, t): + """ + Calculates the log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11). + + Instead of directly computing equation (11), we use Equation (5) to restate Equation (11) in terms of only + forward probabilities. + + Equation (11) stated in terms of forward probabilities via Equation (5): + + Where: + - the sum is over x_0 = {C_0 ... C_{k-1}} (classes for x_0) + + p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) ) + + Args: + log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`): + The log probabilities for the predicted classes of the initial latent pixels. Does not include a + prediction for the masked class as the initial unnoised image cannot be masked. + + x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t` + + t (torch.Long): + The timestep that determines which transition matrix is used. + + Returns: + `torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`: + The log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11). + """ + log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed) + + log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class( + t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True + ) + + log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class( + t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False + ) + + # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + # . . . + # . . . + # . . . + # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) + q = log_p_x_0 - log_q_x_t_given_x_0 + + # sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... , + # sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) + q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True) + + # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n + # . . . + # . . . + # . . . + # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... 
p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n + q = q - q_log_sum_exp + + # (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} + # . . . + # . . . + # . . . + # (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} + # c_cumulative_{t-1} ... c_cumulative_{t-1} + q = self.apply_cumulative_transitions(q, t - 1) + + # ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n + # . . . + # . . . + # . . . + # ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n + # c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 + log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp + + # For each column, there are two possible cases. + # + # Where: + # - sum(p_n(x_0))) is summing over all classes for x_0 + # - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's) + # - C_j is the class transitioning to + # + # 1. x_t is masked i.e. x_t = c_k + # + # Simplifying the expression, the column vector is: + # . + # . + # . + # (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0))) + # . + # . + # . + # (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0)) + # + # From equation (11) stated in terms of forward probabilities, the last row is trivially verified. + # + # For the other rows, we can state the equation as ... + # + # (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{k-1} * p(x_0=c_{k-1})] + # + # This verifies the other rows. + # + # 2. x_t is not masked + # + # Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i: + # . + # . + # . + # C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(c_0=C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) + # . + # . + # . + # C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) + # . + # . + # . + # 0 + # + # The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities. 
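+        #
+        # At this point `log_p_x_t_min_1` has shape `(batch size, num classes, num latent pixels)`, i.e. it also
+        # contains a row for the masked class, which is what `step` expects before applying gumbel noise and
+        # taking the per-pixel argmax.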
+ return log_p_x_t_min_1 + + def log_Q_t_transitioning_to_known_class( + self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool + ): + """ + Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each + latent pixel in `x_t`. + + See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix + is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs. + + Args: + t (torch.Long): + The timestep that determines which transition matrix is used. + + x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t`. + + log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`): + The log one-hot vectors of `x_t` + + cumulative (`bool`): + If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`, + we use the cumulative transition matrix `0`->`t`. + + Returns: + `torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`: + Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability + transition matrix. + + When non cumulative, returns `self.num_classes - 1` rows because the initial latent pixel cannot be + masked. + + Where: + - `q_n` is the probability distribution for the forward process of the `n`th latent pixel. + - C_0 is a class of a latent pixel embedding + - C_k is the class of the masked latent pixel + + non-cumulative result (omitting logarithms): + ``` + q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0) + . . . + . . . + . . . + q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k) + ``` + + cumulative result (omitting logarithms): + ``` + q_0_cumulative(x_t | x_0 = C_0) ... q_n_cumulative(x_t | x_0 = C_0) + . . . + . . . + . . . + q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1}) + ``` + """ + if cumulative: + a = self.log_cumprod_at[t] + b = self.log_cumprod_bt[t] + c = self.log_cumprod_ct[t] + else: + a = self.log_at[t] + b = self.log_bt[t] + c = self.log_ct[t] + + if not cumulative: + # The values in the onehot vector can also be used as the logprobs for transitioning + # from masked latent pixels. If we are not calculating the cumulative transitions, + # we need to save these vectors to be re-appended to the final matrix so the values + # aren't overwritten. + # + # `P(x_t!=mask|x_{t-1=mask}) = 0` and 0 will be the value of the last row of the onehot vector + # if x_t is not masked + # + # `P(x_t=mask|x_{t-1=mask}) = 1` and 1 will be the value of the last row of the onehot vector + # if x_t is masked + log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1) + + # `index_to_log_onehot` will add onehot vectors for masked pixels, + # so the default one hot matrix has one too many rows. See the doc string + # for an explanation of the dimensionality of the returned matrix. + log_onehot_x_t = log_onehot_x_t[:, :-1, :] + + # this is a cheeky trick to produce the transition probabilities using log one-hot vectors. + # + # Don't worry about what values this sets in the columns that mark transitions + # to masked latent pixels. They are overwrote later with the `mask_class_mask`. 
+ # + # Looking at the below logspace formula in non-logspace, each value will evaluate to either + # `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column + # or + # `0 * a + b = b` where `log_Q_t` has the 0 values in the column. + # + # See equation 7 for more details. + log_Q_t = (log_onehot_x_t + a).logaddexp(b) + + # The whole column of each masked pixel is `c` + mask_class_mask = x_t == self.mask_class + mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1) + log_Q_t[mask_class_mask] = c + + if not cumulative: + log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1) + + return log_Q_t + + def apply_cumulative_transitions(self, q, t): + bsz = q.shape[0] + a = self.log_cumprod_at[t] + b = self.log_cumprod_bt[t] + c = self.log_cumprod_ct[t] + + num_latent_pixels = q.shape[2] + c = c.expand(bsz, 1, num_latent_pixels) + + q = (q + a).logaddexp(b) + q = torch.cat((q, c), dim=1) + + return q From 88efc34702c2fba9cf5e1838ff8139582532d85c Mon Sep 17 00:00:00 2001 From: William Berman Date: Sun, 23 Oct 2022 20:32:55 -0700 Subject: [PATCH 04/21] Add VQ-diffusion pipeline --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 1 + .../pipelines/vq_diffusion/__init__.py | 1 + .../vq_diffusion/pipeline_vq_diffusion.py | 239 ++++++++++++++++++ 4 files changed, 242 insertions(+) create mode 100644 src/diffusers/pipelines/vq_diffusion/__init__.py create mode 100644 src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 83b2e852d903..875ee8e8c613 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -37,6 +37,7 @@ LDMPipeline, PNDMPipeline, ScoreSdeVePipeline, + VQDiffusionPipeline, ) from .schedulers import ( DDIMScheduler, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b3124af39077..d7fa75034f95 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -20,6 +20,7 @@ StableDiffusionInpaintPipelineLegacy, StableDiffusionPipeline, ) + from .vq_diffusion import VQDiffusionPipeline if is_transformers_available() and is_onnx_available(): from .stable_diffusion import ( diff --git a/src/diffusers/pipelines/vq_diffusion/__init__.py b/src/diffusers/pipelines/vq_diffusion/__init__.py new file mode 100644 index 000000000000..edf6f570f5bf --- /dev/null +++ b/src/diffusers/pipelines/vq_diffusion/__init__.py @@ -0,0 +1 @@ +from .pipeline_vq_diffusion import VQDiffusionPipeline diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py new file mode 100644 index 000000000000..86930fc95fdc --- /dev/null +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -0,0 +1,239 @@ +from typing import Callable, List, Optional, Tuple, Union + +import torch + +from diffusers import SpatialTransformer, VQModel +from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler +from transformers import CLIPTextModel, CLIPTokenizer + +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VQDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using VQ Diffusion + + This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vqvae ([`VQModel`]): + Vector Quantized Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent + representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. VQ Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + transformer (`SpatialTransformer`): + Conditional transformer to denoise the encoded image latents. + scheduler ([`VQDiffusionScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + vqvae: VQModel + text_encoder: CLIPTextModel + tokenizer: CLIPTokenizer + transformer: SpatialTransformer + scheduler: VQDiffusionScheduler + + def __init__( + self, + vqvae: VQModel, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + transformer: SpatialTransformer, + scheduler: VQDiffusionScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + num_inference_steps: int = 100, + truncation_rate: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)): + Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at + most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above + `truncation_rate` are set to zero. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor` of shape (batch), *optional*): + Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices. + Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will + be generated of completely masked latent pixels. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + batch_size = batch_size * num_images_per_prompt + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. + # While CLIP does normalize the pooled output of the text transformer when combining + # the image and text embeddings, CLIP does not directly normalize the last hidden state. + # + # CLIP normalizing the pooled output. + # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 + text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) + + # duplicate text embeddings for each generation per prompt + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + # get the initial completely masked latents unless the user supplied it + + latents_shape = (batch_size, self.transformer.num_latent_pixels) + if latents is None: + mask_class = self.transformer.num_embed - 1 + latents = torch.full(latents_shape, mask_class).to(self.device) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + if (latents < 0).any() or (latents >= self.transformer.num_embed).any(): + raise ValueError( + "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0," + f" {self.transformer.num_embed - 1} (inclusive)." 
+ ) + latents = latents.to(self.device) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + x_t = latents + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # predict the un-noised image + log_p_x_0 = self.transformer(hidden_states=x_t, context=text_embeddings, timestep=t) + + log_p_x_0 = self.truncate(log_p_x_0, truncation_rate) + + # remove `log(0)`'s (`-inf`s) + log_p_x_0 = log_p_x_0.clamp(-70) + + # compute the previous noisy sample x_t -> x_t-1 + x_t = self.scheduler.step(log_p_x_0, t, x_t, generator=generator).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, x_t) + + embedding_channels = self.vqvae.quantize.e_dim + embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels) + embeddings = self.vqvae.quantize.get_codebook_entry(x_t, shape=embeddings_shape) + image = self.vqvae.decode(embeddings, force_not_quantize=True).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor: + """ + Truncates log_p_x_0 such that for each column vector, the total cumulative probability is `truncation_rate` The + lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to zero. + """ + sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True) + sorted_p_x_0 = torch.exp(sorted_log_p_x_0) + keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate + + # Ensure that at least the largest probability is not zeroed out + all_true = torch.full_like(keep_mask[:, 0:1, :], True) + keep_mask = torch.cat((all_true, keep_mask), dim=1) + keep_mask = keep_mask[:, :-1, :] + + keep_mask = keep_mask.gather(1, indices.argsort(1)) + + rv = log_p_x_0.clone() + + rv[~keep_mask] = -torch.inf # -inf = log(0) + + return rv From 613a52cab6d92a8d81b49702a1d20c99c5bcbdce Mon Sep 17 00:00:00 2001 From: William Berman Date: Sun, 23 Oct 2022 20:46:18 -0700 Subject: [PATCH 05/21] Add VQ-diffusion convert script to diffusers --- scripts/convert_vq_diffusion_to_diffusers.py | 884 +++++++++++++++++++ 1 file changed, 884 insertions(+) create mode 100644 scripts/convert_vq_diffusion_to_diffusers.py diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py new file mode 100644 index 000000000000..f0b99443f5a4 --- /dev/null +++ b/scripts/convert_vq_diffusion_to_diffusers.py @@ -0,0 +1,884 @@ +""" +This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers. + +It currently only supports porting the ITHQ dataset. + +ITHQ dataset: +```sh +# From the root directory of diffusers. 
+ +# Download the VQVAE checkpoint +$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D -O ithq_vqvae.pth + +# Download the VQVAE config +# NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class +# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE` +# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml` +$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml + +# Download the main model checkpoint +$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D -O ithq_learnable.pth + +# Download the main model config +$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml + +# run the convert script +$ python ./scripts/convert_vq_diffusion_to_diffusers.py \ + --checkpoint_path ./ithq_learnable.pth \ + --original_config_file ./ithq.yaml \ + --vqvae_checkpoint_path ./ithq_vqvae.pth \ + --vqvae_original_config_file ./ithq_vqvae.yaml \ + --dump_path +``` +""" + +import argparse +import tempfile + +import torch + +import yaml +from accelerate import init_empty_weights, load_checkpoint_and_dispatch +from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel +from diffusers.models.attention import SpatialTransformer +from transformers import CLIPTextModel, CLIPTokenizer +from yaml.loader import FullLoader + + +try: + from omegaconf import OmegaConf +except ImportError: + raise ImportError( + "OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install" + " OmegaConf`." + ) + +# vqvae model + +PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"] + + +def vqvae_model_from_original_config(original_config): + assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers." 
+ + original_config = original_config.params + + original_encoder_config = original_config.encoder_config.params + original_decoder_config = original_config.decoder_config.params + + in_channels = original_encoder_config.in_channels + out_channels = original_decoder_config.out_ch + + down_block_types = get_down_block_types(original_encoder_config) + up_block_types = get_up_block_types(original_decoder_config) + + assert original_encoder_config.ch == original_decoder_config.ch + assert original_encoder_config.ch_mult == original_decoder_config.ch_mult + block_out_channels = tuple( + [original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult] + ) + + assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks + layers_per_block = original_encoder_config.num_res_blocks + + assert original_encoder_config.z_channels == original_decoder_config.z_channels + latent_channels = original_encoder_config.z_channels + + num_vq_embeddings = original_config.n_embed + + # Hard coded value for ResnetBlock.GoupNorm(num_groups) in VQ-diffusion + norm_num_groups = 32 + + e_dim = original_config.embed_dim + + model = VQModel( + in_channels=in_channels, + out_channels=out_channels, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + latent_channels=latent_channels, + num_vq_embeddings=num_vq_embeddings, + norm_num_groups=norm_num_groups, + e_dim=e_dim, + ) + + return model + + +def get_down_block_types(original_encoder_config): + attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions) + num_resolutions = len(original_encoder_config.ch_mult) + resolution = coerce_resolution(original_encoder_config.resolution) + + curr_res = resolution + down_block_types = [] + + for _ in range(num_resolutions): + if curr_res in attn_resolutions: + down_block_type = "AttnDownEncoderBlock2D" + else: + down_block_type = "DownEncoderBlock2D" + + down_block_types.append(down_block_type) + + curr_res = [r // 2 for r in curr_res] + + return down_block_types + + +def get_up_block_types(original_decoder_config): + attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions) + num_resolutions = len(original_decoder_config.ch_mult) + resolution = coerce_resolution(original_decoder_config.resolution) + + curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution] + up_block_types = [] + + for _ in reversed(range(num_resolutions)): + if curr_res in attn_resolutions: + up_block_type = "AttnUpDecoderBlock2D" + else: + up_block_type = "UpDecoderBlock2D" + + up_block_types.append(up_block_type) + + curr_res = [r * 2 for r in curr_res] + + return up_block_types + + +def coerce_attn_resolutions(attn_resolutions): + attn_resolutions = OmegaConf.to_object(attn_resolutions) + attn_resolutions_ = [] + for ar in attn_resolutions: + if isinstance(ar, (list, tuple)): + attn_resolutions_.append(list(ar)) + else: + attn_resolutions_.append([ar, ar]) + return attn_resolutions_ + + +def coerce_resolution(resolution): + resolution = OmegaConf.to_object(resolution) + if isinstance(resolution, int): + resolution = [resolution, resolution] # H, W + elif isinstance(resolution, (tuple, list)): + resolution = list(resolution) + else: + raise ValueError("Unknown type of resolution:", resolution) + return resolution + + +# done vqvae model + +# vqvae checkpoint + + +def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + 
diffusers_checkpoint = {} + + diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint)) + + # quant_conv + + diffusers_checkpoint.update( + { + "quant_conv.weight": checkpoint["quant_conv.weight"], + "quant_conv.bias": checkpoint["quant_conv.bias"], + } + ) + + # quantize + diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]}) + + # post_quant_conv + diffusers_checkpoint.update( + { + "post_quant_conv.weight": checkpoint["post_quant_conv.weight"], + "post_quant_conv.bias": checkpoint["post_quant_conv.bias"], + } + ) + + # decoder + diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint)) + + return diffusers_checkpoint + + +def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv_in + diffusers_checkpoint.update( + { + "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"], + "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"], + } + ) + + # down_blocks + for down_block_idx, down_block in enumerate(model.encoder.down_blocks): + diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}" + down_block_prefix = f"encoder.down.{down_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(down_block.resnets): + diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # downsample + + # do not include the downsample when on the last down block + # There is no downsample on the last down block + if down_block_idx != len(model.encoder.down_blocks) - 1: + # There's a single downsample in the original checkpoint but a list of downsamples + # in the diffusers model. 
+ diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv" + downsample_prefix = f"{down_block_prefix}.downsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(down_block, "attentions"): + for attention_idx, _ in enumerate(down_block.attentions): + diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{down_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder + diffusers_attention_prefix = "encoder.mid_block.attentions.0" + attention_prefix = "encoder.mid.attn_1" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion encoder + resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"], + "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"], + # conv_out + "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"], + "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv in + diffusers_checkpoint.update( + { + "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"], + "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"], + } + ) + + # up_blocks + + for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks): + # up_blocks are stored in reverse order in the VQ-diffusion checkpoint + orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx + + diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}" + up_block_prefix = f"decoder.up.{orig_up_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(up_block.resnets): + diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # upsample + + # there is no up sample on the last up block + if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1: + # There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples + # in the diffusers model. 
+ diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv" + downsample_prefix = f"{up_block_prefix}.upsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(up_block, "attentions"): + for attention_idx, _ in enumerate(up_block.attentions): + diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{up_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder + diffusers_attention_prefix = "decoder.mid_block.attentions.0" + attention_prefix = "decoder.mid.attn_1" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion decoder + resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"], + "decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"], + # conv_out + "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"], + "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"], + f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"], + f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], + } + ) + + return rv + + +def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # group_norm + 
f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], + f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + # query + f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], + # key + f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], + # value + f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], + # proj_attn + f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + :, :, 0, 0 + ], + f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + + +# done vqvae checkpoint + +# transformer model + +PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"] +PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"] +PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"] + + +def transformer_model_from_original_config( + original_diffusion_config, original_transformer_config, original_content_embedding_config +): + assert ( + original_diffusion_config.target in PORTED_DIFFUSIONS + ), f"{original_diffusion_config.target} has not yet been ported to diffusers." + assert ( + original_transformer_config.target in PORTED_TRANSFORMERS + ), f"{original_transformer_config.target} has not yet been ported to diffusers." + assert ( + original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS + ), f"{original_content_embedding_config.target} has not yet been ported to diffusers." + + original_diffusion_config = original_diffusion_config.params + original_transformer_config = original_transformer_config.params + original_content_embedding_config = original_content_embedding_config.params + + inner_dim = original_transformer_config["n_embd"] + + n_heads = original_transformer_config["n_head"] + + # VQ-Diffusion gives dimension of the multi-headed attention layers as the + # number of attention heads times the sequence length (the dimension) of a + # single head. We want to specify our attention blocks with those values + # specified separately + assert inner_dim % n_heads == 0 + d_head = inner_dim // n_heads + + depth = original_transformer_config["n_layer"] + context_dim = original_transformer_config["condition_dim"] + + num_embed = original_content_embedding_config["num_embed"] + # the number of embeddings in the transformer includes the mask embedding. + # the content embedding (the vqvae) does not include the mask embedding. 
+ num_embed = num_embed + 1 + + height = original_transformer_config["content_spatial_size"][0] + width = original_transformer_config["content_spatial_size"][1] + dropout = original_transformer_config["resid_pdrop"] + diffusion_steps = original_diffusion_config["diffusion_step"] + + model = SpatialTransformer( + n_heads=n_heads, + d_head=d_head, + depth=depth, + context_dim=context_dim, + discrete=True, + num_embed=num_embed, + height=height, + width=width, + dropout=dropout, + diffusion_steps=diffusion_steps, + ff_layers=["Linear", "ApproximateGELU", "Linear", "Dropout"], + norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], + attention_bias=True, + ) + + return model + + +# done transformer model + +# transformer checkpoint + + +def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + transformer_prefix = "transformer.transformer" + + diffusers_latent_image_embedding_prefix = "latent_image_embedding" + latent_image_embedding_prefix = f"{transformer_prefix}.content_emb" + + # DalleMaskImageEmbedding + diffusers_checkpoint.update( + { + f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[ + f"{latent_image_embedding_prefix}.emb.weight" + ], + f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[ + f"{latent_image_embedding_prefix}.height_emb.weight" + ], + f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[ + f"{latent_image_embedding_prefix}.width_emb.weight" + ], + } + ) + + # transformer blocks + for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks): + diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}" + transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}" + + # ada norm block + diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1" + ada_norm_prefix = f"{transformer_block_prefix}.ln1" + + diffusers_checkpoint.update( + transformer_ada_norm_to_diffusers_checkpoint( + checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix + ) + ) + + # attention block + diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1" + attention_prefix = f"{transformer_block_prefix}.attn1" + + diffusers_checkpoint.update( + transformer_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # ada norm block + diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2" + ada_norm_prefix = f"{transformer_block_prefix}.ln1_1" + + diffusers_checkpoint.update( + transformer_ada_norm_to_diffusers_checkpoint( + checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix + ) + ) + + # attention block + diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2" + attention_prefix = f"{transformer_block_prefix}.attn2" + + diffusers_checkpoint.update( + transformer_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # norm block + diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3" + norm_block_prefix = f"{transformer_block_prefix}.ln2" + + diffusers_checkpoint.update( + { + f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"], + f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"], + } + ) + + # 
feedforward block + diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff" + feedforward_prefix = f"{transformer_block_prefix}.mlp" + + diffusers_checkpoint.update( + transformer_feedforward_to_diffusers_checkpoint( + checkpoint, + diffusers_feedforward_prefix=diffusers_feedforward_prefix, + feedforward_prefix=feedforward_prefix, + ) + ) + + # to logits + + diffusers_norm_out_prefix = "norm_out" + norm_out_prefix = f"{transformer_prefix}.to_logits.0" + + diffusers_checkpoint.update( + { + f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"], + f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"], + } + ) + + diffusers_out_prefix = "out" + out_prefix = f"{transformer_prefix}.to_logits.1" + + diffusers_checkpoint.update( + { + f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"], + f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"], + } + ) + + return diffusers_checkpoint + + +def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix): + return { + f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"], + f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"], + f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"], + } + + +def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # key + f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"], + f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"], + # query + f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"], + f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"], + # value + f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"], + f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"], + # linear out + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"], + } + + +def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix): + return { + f"{diffusers_feedforward_prefix}.net.0.weight": checkpoint[f"{feedforward_prefix}.0.weight"], + f"{diffusers_feedforward_prefix}.net.0.bias": checkpoint[f"{feedforward_prefix}.0.bias"], + f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"], + f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"], + } + + +# done transformer checkpoint + + +def read_config_file(filename): + # The yaml file contains annotations that certain values should + # loaded as tuples. By default, OmegaConf will panic when reading + # these. Instead, we can manually read the yaml with the FullLoader and then + # construct the OmegaConf object. + with open(filename) as f: + original_config = yaml.load(f, FullLoader) + + return OmegaConf.create(original_config) + + +# We take separate arguments for the vqvae because the ITHQ vqvae config file +# is separate from the config file for the rest of the model. 
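For intuition about the `[:, :, 0, 0]` slicing in `vqvae_attention_to_diffusers_checkpoint` above: the original VQGAN attention projections are 1x1 `Conv2d` layers, while the diffusers `AttentionBlock` stores its query, key, value, and output projections as `nn.Linear`, so dropping the two trailing kernel dimensions gives an equivalent weight matrix. A minimal self-contained check of that equivalence (the layer names and sizes here are illustrative, not taken from the checkpoint):

```python
import torch

channels = 8

conv = torch.nn.Conv2d(channels, channels, kernel_size=1)
linear = torch.nn.Linear(channels, channels)

# Copy the 1x1 conv weights into the linear layer, mirroring the `[:, :, 0, 0]` squeeze above.
with torch.no_grad():
    linear.weight.copy_(conv.weight[:, :, 0, 0])
    linear.bias.copy_(conv.bias)

x = torch.randn(1, channels, 4, 4)

out_conv = conv(x)                                 # (1, C, H, W)
out_linear = linear(x.flatten(2).transpose(1, 2))  # (1, H*W, C)
out_linear = out_linear.transpose(1, 2).reshape_as(out_conv)

assert torch.allclose(out_conv, out_linear, atol=1e-5)
```

The same squeeze is applied to the query, key, value, and `proj_out` projections in the mapping above.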
+if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--vqvae_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the vqvae checkpoint to convert.", + ) + + parser.add_argument( + "--vqvae_original_config_file", + default=None, + type=str, + required=True, + help="The YAML config file corresponding to the original architecture for the vqvae.", + ) + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + + parser.add_argument( + "--original_config_file", + default=None, + type=str, + required=True, + help="The YAML config file corresponding to the original architecture.", + ) + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--checkpoint_load_device", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading checkpoints.", + ) + + # See link for how ema weights are always selected + # https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65 + parser.add_argument( + "--no_use_ema", + action="store_true", + required=False, + help=( + "Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to set" + " it as the original VQ-Diffusion always uses the ema weights when loading models." + ), + ) + + args = parser.parse_args() + + use_ema = not args.no_use_ema + + print(f"loading checkpoints to {args.checkpoint_load_device}") + + checkpoint_map_location = torch.device(args.checkpoint_load_device) + + # vqvae_model + + print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}") + + vqvae_original_config = read_config_file(args.vqvae_original_config_file).model + vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"] + + with init_empty_weights(): + vqvae_model = vqvae_model_from_original_config(vqvae_original_config) + + vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint) + + with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file: + torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name) + del vqvae_diffusers_checkpoint + del vqvae_checkpoint + load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto") + + print("done loading vqvae") + + # done vqvae_model + + # transformer_model + + print( + f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:" + f" {use_ema}" + ) + + original_config = read_config_file(args.original_config_file).model + + diffusion_config = original_config.params.diffusion_config + transformer_config = original_config.params.diffusion_config.params.transformer_config + content_embedding_config = original_config.params.diffusion_config.params.content_emb_config + + pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location) + + if use_ema: + if "ema" in pre_checkpoint: + checkpoint = {} + for k, v in pre_checkpoint["model"].items(): + checkpoint[k] = v + + for k, v in pre_checkpoint["ema"].items(): + # The ema weights are only used on the transformer. To mimic their key as if they came + # from the state_dict for the top level model, we prefix with an additional "transformer." 
+ # See the source linked in the args.use_ema config for more information. + checkpoint[f"transformer.{k}"] = v + else: + print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.") + checkpoint = pre_checkpoint["model"] + else: + checkpoint = pre_checkpoint["model"] + + del pre_checkpoint + + with init_empty_weights(): + transformer_model = transformer_model_from_original_config( + diffusion_config, transformer_config, content_embedding_config + ) + + diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint( + transformer_model, checkpoint + ) + + with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file: + torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name) + del diffusers_transformer_checkpoint + del checkpoint + load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto") + + print("done loading transformer") + + # done transformer_model + + # text encoder + + print("loading CLIP text encoder") + + clip_name = "openai/clip-vit-base-patch32" + + # The original VQ-Diffusion specifies the pad value by the int used in the + # returned tokens. Each model uses `0` as the pad value. The transformers clip api + # specifies the pad value via the token before it has been tokenized. The `!` pad + # token is the same as padding with the `0` pad value. + pad_token = "!" + + tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto") + + assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0 + + text_encoder_model = CLIPTextModel.from_pretrained( + clip_name, + # `CLIPTextModel` does not support device_map="auto" + # device_map="auto" + ) + + print("done loading CLIP text encoder") + + # done text encoder + + # scheduler + + scheduler_model = VQDiffusionScheduler( + # the scheduler has the same number of embeddings as the transformer + num_embed=transformer_model.num_embed + ) + + # done scheduler + + print(f"saving VQ diffusion model, path: {args.dump_path}") + + pipe = VQDiffusionPipeline( + vqvae=vqvae_model, + transformer=transformer_model, + tokenizer=tokenizer_model, + text_encoder=text_encoder_model, + scheduler=scheduler_model, + ) + pipe.save_pretrained(args.dump_path) + + print("done writing VQ diffusion model") From 0f3f0ed2e352244438be74bb6a85adf360f0f266 Mon Sep 17 00:00:00 2001 From: William Berman Date: Sun, 23 Oct 2022 20:55:17 -0700 Subject: [PATCH 06/21] Add VQ-diffusion dummy objects --- src/diffusers/utils/dummy_pt_objects.py | 45 +++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5dd583279708..9868aaf8e77f 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -34,6 +34,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class SpatialTransformer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class UNet1DModel(metaclass=DummyObject): _backends = ["torch"] @@ -242,6 +257,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class VQDiffusionPipeline(metaclass=DummyObject): + 
_backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DDIMScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -377,6 +407,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class VQDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class EMAModel(metaclass=DummyObject): _backends = ["torch"] From 71fb617a02211d66a6076a07e022a7f5d4736056 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Tue, 1 Nov 2022 11:21:14 -0700 Subject: [PATCH 07/21] Add VQ-diffusion markdown docs --- docs/source/_toctree.yml | 2 ++ docs/source/api/pipelines/overview.mdx | 1 + docs/source/api/pipelines/vq_diffusion.mdx | 35 ++++++++++++++++++++++ docs/source/api/schedulers.mdx | 6 ++++ 4 files changed, 44 insertions(+) create mode 100644 docs/source/api/pipelines/vq_diffusion.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 7e46d95a46f1..7ebe68ef9b59 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -96,5 +96,7 @@ title: "Stochastic Karras VE" - local: api/pipelines/dance_diffusion title: "Dance Diffusion" + - local: api/pipelines/vq_diffusion + title: "VQ Diffusion" title: "Pipelines" title: "API" diff --git a/docs/source/api/pipelines/overview.mdx b/docs/source/api/pipelines/overview.mdx index af711a02d9f3..ce4b46eba5a7 100644 --- a/docs/source/api/pipelines/overview.mdx +++ b/docs/source/api/pipelines/overview.mdx @@ -54,6 +54,7 @@ available a colab notebook to directly try them out. | [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [vq_diffusion](./vq_diffusion) | [**Vector Quantized Diffusion Model for Text-to-Image Synthesis**](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. 
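For orientation alongside the docs page added below, here is a minimal usage sketch of the pipeline. The local path is hypothetical (it assumes a pipeline saved via the conversion script's `--dump_path`, since this PR does not publish a hub checkpoint), and the parameter values are purely illustrative:

```python
import torch

from diffusers import VQDiffusionPipeline

# Hypothetical path written by `convert_vq_diffusion_to_diffusers.py --dump_path ./vq-diffusion-ithq`.
pipe = VQDiffusionPipeline.from_pretrained("./vq-diffusion-ithq")
pipe = pipe.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
output = pipe(
    "teddy bear playing in the pool",  # same prompt as the fast test added later in this series
    num_inference_steps=100,           # illustrative value
    truncation_rate=0.86,              # illustrative value for the truncation step described in the pipeline
    generator=generator,
    output_type="pil",
)
output.images[0].save("teddy_bear.png")
```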
diff --git a/docs/source/api/pipelines/vq_diffusion.mdx b/docs/source/api/pipelines/vq_diffusion.mdx new file mode 100644 index 000000000000..c2965c47d5fe --- /dev/null +++ b/docs/source/api/pipelines/vq_diffusion.mdx @@ -0,0 +1,35 @@ + + +# VQDiffusion + +## Overview + +[Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo + +The abstract of the paper is the following: + +We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality. + +The original codebase can be found [here](https://github.com/microsoft/VQ-Diffusion). + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|---|---|:---:| +| [pipeline_vq_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py) | *Text-to-Image Generation* | - | + + +## VQDiffusionPipeline +[[autodoc]] pipelines.vq_diffusion.pipeline_vq_diffusion.VQDiffusionPipeline + - __call__ + diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx index 3f88e563de19..b2b99a20fe3b 100644 --- a/docs/source/api/schedulers.mdx +++ b/docs/source/api/schedulers.mdx @@ -112,3 +112,9 @@ Score SDE-VP is under construction. 
[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler + +#### VQDiffusionScheduler + +Original paper can be found [here](https://arxiv.org/abs/2111.14822) + +[[autodoc]] VQDiffusionScheduler From f848dbb818a5b7aa97a78e11e2988b4310565d1e Mon Sep 17 00:00:00 2001 From: Will Berman Date: Tue, 1 Nov 2022 13:11:49 -0700 Subject: [PATCH 08/21] Add VQ-diffusion tests --- tests/pipelines/vq_diffusion/__init__.py | 0 .../vq_diffusion/test_vq_diffusion.py | 150 +++++++++++++++ tests/test_layers_utils.py | 181 +++++++++++++++++- tests/test_scheduler.py | 138 ++++++++++--- 4 files changed, 442 insertions(+), 27 deletions(-) create mode 100644 tests/pipelines/vq_diffusion/__init__.py create mode 100644 tests/pipelines/vq_diffusion/test_vq_diffusion.py diff --git a/tests/pipelines/vq_diffusion/__init__.py b/tests/pipelines/vq_diffusion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py new file mode 100644 index 000000000000..d71bc4ca4649 --- /dev/null +++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py @@ -0,0 +1,150 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch + +from diffusers import SpatialTransformer, VQDiffusionPipeline, VQDiffusionScheduler, VQModel +from diffusers.utils import slow +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def num_embed(self): + return 12 + + @property + def diffusion_steps(self): + return 12 + + @property + def dummy_vqvae(self): + torch.manual_seed(0) + model = VQModel( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=3, + num_vq_embeddings=self.num_embed, + ) + return model + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + @property + def dummy_transformer(self): + torch.manual_seed(0) + + height = 12 + width = 12 + + model = SpatialTransformer( + n_heads=1, + d_head=height * width, + context_dim=32, + discrete=True, + num_embed=self.num_embed, + height=height, + width=width, + 
diffusion_steps=self.diffusion_steps, + ff_layers=["Linear", "ApproximateGELU", "Linear", "Dropout"], + norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], + attention_bias=True, + ) + return model + + def test_vq_diffusion(self): + device = "cpu" + + vqvae = self.dummy_vqvae + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + transformer = self.dummy_transformer + scheduler = VQDiffusionScheduler(self.num_embed) + + pipe = VQDiffusionPipeline( + vqvae=vqvae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler + ) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + prompt = "teddy bear playing in the pool" + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = pipe( + [prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2 + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 24, 24, 3) + + expected_slice = np.array([0.6583, 0.6410, 0.5325, 0.5635, 0.5563, 0.4234, 0.6008, 0.5491, 0.4880]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch_gpu +class VQDiffusionPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @unittest.skip("VQ Diffusion model not saved to hub") + def test_vq_diffusion(self): + pass diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index cf531fbf3fd3..ee6a26b0656a 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -18,8 +18,10 @@ import numpy as np import torch +from torch import nn -from diffusers.models.attention import AttentionBlock, SpatialTransformer +import pytest +from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock, SpatialTransformer from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.resnet import Downsample2D, Upsample2D from diffusers.utils import torch_device @@ -323,6 +325,45 @@ def test_spatial_transformer_context_dim(self): ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + def test_spatial_transformer_timestep(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + diffusion_steps = 5 + + sample = torch.randn(1, 64, 64, 64).to(torch_device) + spatial_transformer_block = SpatialTransformer( + in_channels=64, + n_heads=2, + d_head=32, + dropout=0.0, + context_dim=64, + norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], + diffusion_steps=diffusion_steps, + ).to(torch_device) + with torch.no_grad(): + timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device) + timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device) + attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1) + attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2) + + assert attention_scores_1.shape == (1, 64, 64, 64) + assert attention_scores_2.shape == (1, 64, 64, 64) + + output_slice_1 = attention_scores_1[0, -1, -3:, -3:] + output_slice_2 = attention_scores_2[0, -1, -3:, -3:] + + expected_slice_1 
= torch.tensor( + [-0.1874, -0.9704, -1.4290, -1.3357, 1.5138, 0.3036, -0.0976, -1.1667, 0.1283], device=torch_device + ) + expected_slice_2 = torch.tensor( + [-0.3493, -1.0924, -1.6161, -1.5016, 1.4245, 0.1367, -0.2526, -1.3109, -0.0547], device=torch_device + ) + + assert torch.allclose(output_slice_1.flatten(), expected_slice_1, atol=1e-3) + assert torch.allclose(output_slice_2.flatten(), expected_slice_2, atol=1e-3) + def test_spatial_transformer_dropout(self): torch.manual_seed(0) if torch.cuda.is_available(): @@ -350,3 +391,141 @@ def test_spatial_transformer_dropout(self): [-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + @unittest.skipIf(torch_device == "mps", "MPS does not support float64") + def test_spatial_transformer_discrete(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_embed = 5 + + sample = torch.randint(0, num_embed, (1, 32)).to(torch_device) + spatial_transformer_block = ( + SpatialTransformer( + n_heads=1, + d_head=32, + discrete=True, + num_embed=num_embed, + height=16, + width=2, + ) + .to(torch_device) + .eval() + ) + + with torch.no_grad(): + attention_scores = spatial_transformer_block(sample) + + assert attention_scores.shape == (1, num_embed - 1, 32) + + output_slice = attention_scores[0, -3:, -3:] + + expected_slice = torch.tensor( + [-1.4105, -1.0337, -1.4915, -1.8912, -1.1228, -1.3155, -1.9766, -1.9487, -1.1841], device=torch_device + ) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + def test_spatial_transformer_default_norm_layers(self): + spatial_transformer_block = SpatialTransformer(n_heads=1, d_head=32, in_channels=32) + + assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == nn.LayerNorm + assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == nn.LayerNorm + assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm + + def test_spatial_transformer_ada_norm_layers(self): + spatial_transformer_block = SpatialTransformer( + n_heads=1, + d_head=32, + in_channels=32, + norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], + diffusion_steps=5, + ) + + assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == AdaLayerNorm + assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == AdaLayerNorm + assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm + + def test_spatial_transformer_ada_norm_layers_requires_diffusion_steps(self): + with pytest.raises(Exception) as e_info: + SpatialTransformer( + n_heads=1, + d_head=32, + in_channels=32, + norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], + ) + + assert e_info.value.args[0] == "When using AdaLayerNorm, you must also pass diffusion_steps." 
+ + def test_spatial_transformer_default_ff_layers(self): + spatial_transformer_block = SpatialTransformer( + n_heads=1, + d_head=32, + in_channels=32, + ) + + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU + assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear + + dim = 32 + inner_dim = 128 + + # First dimension change + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim + # NOTE: inner_dim * 2 because GEGLU + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim * 2 + + # Second dimension change + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim + + def test_spatial_transformer_vq_diffusion_ff_layers(self): + spatial_transformer_block = SpatialTransformer( + n_heads=1, + d_head=32, + in_channels=32, + ff_layers=["Linear", "ApproximateGELU", "Linear", "Dropout"], + ) + + dim = 32 + inner_dim = 128 + + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == nn.Linear + assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == ApproximateGELU + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear + assert spatial_transformer_block.transformer_blocks[0].ff.net[3].__class__ == nn.Dropout + + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].in_features == dim + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].out_features == inner_dim + + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim + + def test_spatial_transformer_ff_layers_too_few_dim_changes(self): + with pytest.raises(Exception) as e_info: + SpatialTransformer(n_heads=1, d_head=32, in_channels=32, ff_layers=["Linear"]) + + assert ( + e_info.value.args[0] + == "Too few dimension changes. FeedForward must have exactly two dimension changing layers (Linear and" + " GEGLU)." + ) + + def test_spatial_transformer_ff_layers_too_many_dim_changes(self): + for layer in ["Linear", "GEGLU"]: + with pytest.raises(Exception) as e_info: + SpatialTransformer(n_heads=1, d_head=32, in_channels=32, ff_layers=[layer] * 3) + + assert ( + e_info.value.args[0] + == "Too many dimension changes. FeedForward must have exactly two dimension changing layers (Linear" + " and GEGLU)." 
+ ) + + def test_spatial_transformer_attention_bias(self): + spatial_transformer_block = SpatialTransformer(n_heads=1, d_head=32, in_channels=32, attention_bias=True) + + assert spatial_transformer_block.transformer_blocks[0].attn1.to_q.bias is not None + assert spatial_transformer_block.transformer_blocks[0].attn1.to_k.bias is not None + assert spatial_transformer_block.transformer_blocks[0].attn1.to_v.bias is not None diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 9285eed20ff0..038278ee8748 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -19,6 +19,7 @@ import numpy as np import torch +import torch.nn.functional as F from diffusers import ( DDIMScheduler, @@ -29,6 +30,7 @@ LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler, + VQDiffusionScheduler, ) from diffusers.utils import torch_device @@ -85,12 +87,18 @@ def check_over_configs(self, time_step=0, **config): if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): time_step = float(time_step) - sample = self.dummy_sample - residual = 0.1 * sample - scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) + if scheduler_class == VQDiffusionScheduler: + num_embed = scheduler_config["num_embed"] + sample = self.dummy_sample(num_embed) + model = self.dummy_model(num_embed) + residual = model(sample, time_step) + else: + sample = self.dummy_sample + residual = 0.1 * sample + with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) @@ -122,12 +130,18 @@ def check_over_forward(self, time_step=0, **forward_kwargs): if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): time_step = float(time_step) - sample = self.dummy_sample - residual = 0.1 * sample - scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) + if scheduler_class == VQDiffusionScheduler: + num_embed = scheduler_config["num_embed"] + sample = self.dummy_sample(num_embed) + model = self.dummy_model(num_embed) + residual = model(sample, time_step) + else: + sample = self.dummy_sample + residual = 0.1 * sample + with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) @@ -154,15 +168,21 @@ def test_from_pretrained_save_pretrained(self): num_inference_steps = kwargs.pop("num_inference_steps", None) for scheduler_class in self.scheduler_classes: - sample = self.dummy_sample - residual = 0.1 * sample + timestep = 1 + if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): + timestep = float(timestep) scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - timestep = 1 - if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): - timestep = float(timestep) + if scheduler_class == VQDiffusionScheduler: + num_embed = scheduler_config["num_embed"] + sample = self.dummy_sample(num_embed) + model = self.dummy_model(num_embed) + residual = model(sample, timestep) + else: + sample = self.dummy_sample + residual = 0.1 * sample with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) @@ -200,8 +220,14 @@ def test_step_shape(self): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - sample = 
self.dummy_sample - residual = 0.1 * sample + if scheduler_class == VQDiffusionScheduler: + num_embed = scheduler_config["num_embed"] + sample = self.dummy_sample(num_embed) + model = self.dummy_model(num_embed) + residual = model(sample, timestep_0) + else: + sample = self.dummy_sample + residual = 0.1 * sample if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) @@ -255,8 +281,14 @@ def recursive_check(tuple_object, dict_object): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - sample = self.dummy_sample - residual = 0.1 * sample + if scheduler_class == VQDiffusionScheduler: + num_embed = scheduler_config["num_embed"] + sample = self.dummy_sample(num_embed) + model = self.dummy_model(num_embed) + residual = model(sample, timestep) + else: + sample = self.dummy_sample + residual = 0.1 * sample if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) @@ -284,22 +316,26 @@ def test_scheduler_public_api(self): for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - self.assertTrue( - hasattr(scheduler, "init_noise_sigma"), - f"{scheduler_class} does not implement a required attribute `init_noise_sigma`", - ) - self.assertTrue( - hasattr(scheduler, "scale_model_input"), - f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`", - ) + + if scheduler_class != VQDiffusionScheduler: + self.assertTrue( + hasattr(scheduler, "init_noise_sigma"), + f"{scheduler_class} does not implement a required attribute `init_noise_sigma`", + ) + self.assertTrue( + hasattr(scheduler, "scale_model_input"), + f"{scheduler_class} does not implement a required class method `scale_model_input(sample," + " timestep)`", + ) self.assertTrue( hasattr(scheduler, "step"), f"{scheduler_class} does not implement a required class method `step(...)`", ) - sample = self.dummy_sample - scaled_sample = scheduler.scale_model_input(sample, 0.0) - self.assertEqual(sample.shape, scaled_sample.shape) + if scheduler_class != VQDiffusionScheduler: + sample = self.dummy_sample + scaled_sample = scheduler.scale_model_input(sample, 0.0) + self.assertEqual(sample.shape, scaled_sample.shape) def test_add_noise_device(self): for scheduler_class in self.scheduler_classes: @@ -1238,3 +1274,53 @@ def test_full_loop_no_noise(self): result_mean = torch.mean(torch.abs(sample)) assert abs(result_mean.item() - 2540529) < 10 + + +class VQDiffusionSchedulerTest(SchedulerCommonTest): + scheduler_classes = (VQDiffusionScheduler,) + + def get_scheduler_config(self, **kwargs): + config = { + "num_embed": 4097, + "num_train_timesteps": 100, + } + + config.update(**kwargs) + return config + + def dummy_sample(self, num_embed): + batch_size = 4 + height = 8 + width = 8 + + sample = torch.randint(0, num_embed, (batch_size, height * width)) + + return sample + + @property + def dummy_sample_deter(self): + assert False + + def dummy_model(self, num_embed): + def model(sample, t, *args): + batch_size, num_latent_pixels = sample.shape + logits = torch.rand((batch_size, num_embed - 1, num_latent_pixels)) + return_value = F.log_softmax(logits.double(), dim=1).float() + return return_value + + return model + + def test_timesteps(self): + for timesteps in [2, 5, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_num_embed(self): + 
for num_embed in [5, 100, 1000, 4000]: + self.check_over_configs(num_embed=num_embed) + + def test_time_indices(self): + for t in [0, 50, 99]: + self.check_over_forward(time_step=t) + + def test_add_noise_device(self): + pass From 3d7eb277933a2bfdaf8de7d4519314ff16675104 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 2 Nov 2022 19:32:17 +0000 Subject: [PATCH 09/21] some renaming --- scripts/convert_vq_diffusion_to_diffusers.py | 8 +- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/attention.py | 198 +++++++++--------- src/diffusers/models/attention_flax.py | 2 +- src/diffusers/models/unet_2d_blocks.py | 8 +- src/diffusers/models/unet_2d_blocks_flax.py | 8 +- .../vq_diffusion/pipeline_vq_diffusion.py | 8 +- src/diffusers/utils/dummy_pt_objects.py | 2 +- .../vq_diffusion/test_vq_diffusion.py | 32 ++- tests/test_layers_utils.py | 40 ++-- 11 files changed, 164 insertions(+), 146 deletions(-) diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py index f0b99443f5a4..0fa5cefe1f8b 100644 --- a/scripts/convert_vq_diffusion_to_diffusers.py +++ b/scripts/convert_vq_diffusion_to_diffusers.py @@ -40,7 +40,7 @@ import yaml from accelerate import init_empty_weights, load_checkpoint_and_dispatch from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel -from diffusers.models.attention import SpatialTransformer +from diffusers.models.attention import Transformer2DModel from transformers import CLIPTextModel, CLIPTokenizer from yaml.loader import FullLoader @@ -507,9 +507,9 @@ def transformer_model_from_original_config( height = original_transformer_config["content_spatial_size"][0] width = original_transformer_config["content_spatial_size"][1] dropout = original_transformer_config["resid_pdrop"] - diffusion_steps = original_diffusion_config["diffusion_step"] + num_embeds_ada_norm = original_diffusion_config["diffusion_step"] - model = SpatialTransformer( + model = Transformer2DModel( n_heads=n_heads, d_head=d_head, depth=depth, @@ -519,7 +519,7 @@ def transformer_model_from_original_config( height=height, width=width, dropout=dropout, - diffusion_steps=diffusion_steps, + num_embeds_ada_norm=num_embeds_ada_norm, ff_layers=["Linear", "ApproximateGELU", "Linear", "Dropout"], norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], attention_bias=True, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 875ee8e8c613..5fcb467b6d10 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -18,7 +18,7 @@ if is_torch_available(): from .modeling_utils import ModelMixin - from .models import AutoencoderKL, SpatialTransformer, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel + from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d303937e077e..5b101d169148 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,7 +16,7 @@ if is_torch_available(): - from .attention import SpatialTransformer + from .attention import Transformer2DModel from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index eb44c77c7461..07be4a237096 100644 --- 
a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -23,92 +23,7 @@ from diffusers.models.embeddings import ImagePositionalEmbeddings -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention. - - Parameters: - channels (:obj:`int`): The number of channels in the input and output. - num_head_channels (:obj:`int`, *optional*): - The number of channels in each head. If None, then `num_heads` = 1. - num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. - """ - - def __init__( - self, - channels: int, - num_head_channels: Optional[int] = None, - num_groups: int = 32, - rescale_output_factor: float = 1.0, - eps: float = 1e-5, - ): - super().__init__() - self.channels = channels - - self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 - self.num_head_size = num_head_channels - self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) - - # define q,k,v as linear layers - self.query = nn.Linear(channels, channels) - self.key = nn.Linear(channels, channels) - self.value = nn.Linear(channels, channels) - - self.rescale_output_factor = rescale_output_factor - self.proj_attn = nn.Linear(channels, channels, 1) - - def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: - new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection - - def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - # transpose - query_states = self.transpose_for_scores(query_proj) - key_states = self.transpose_for_scores(key_proj) - value_states = self.transpose_for_scores(value_proj) - - # get scores - scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) - attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm - attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) - - # compute attention output - hidden_states = torch.matmul(attention_probs, value_states) - - hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() - new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) - hidden_states = hidden_states.view(new_hidden_states_shape) - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states 
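Note (illustrative aside, not part of the patch): the `AttentionBlock` removed in this hunk is re-added verbatim further down in the same file, after the renamed `Transformer2DModel`, so the hunk is a pure reordering of the module. One detail in its forward pass that is easy to misread is the attention scaling: the factor 1/sqrt(sqrt(head_dim)) is applied to both the query and the key before the matmul, which is numerically identical to the conventional 1/sqrt(head_dim) scaling of the scores while keeping the intermediate values smaller. A minimal standalone check, with illustrative tensor shapes only:

    import math

    import torch

    head_dim = 64
    q = torch.randn(2, 4, 16, head_dim)  # (batch, heads, tokens, head_dim)
    k = torch.randn(2, 4, 16, head_dim)

    # split scaling, as in AttentionBlock: multiply q and k by 1/sqrt(sqrt(d)) each
    scale = 1 / math.sqrt(math.sqrt(head_dim))
    scores_split = torch.matmul(q * scale, k.transpose(-1, -2) * scale)

    # conventional scaling: divide the raw scores by sqrt(d)
    scores_plain = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)

    assert torch.allclose(scores_split, scores_plain, atol=1e-5)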
- - -class SpatialTransformer(ModelMixin, ConfigMixin): +class Transformer2DModel(ModelMixin, ConfigMixin): """ Transformer block for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual embeddings) inputs. @@ -147,10 +62,10 @@ class SpatialTransformer(ModelMixin, ConfigMixin): The layers to use in the TransformerBlocks' FeedForward block. norm_layers (:obj: `List[Literal["LayerNorm", "AdaLayerNorm"]]`, *optional*): The norm layers to use for the TransformerBlocks. - diffusion_steps (:obj: `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + num_embeds_ada_norm (:obj: `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during training. Note that this is fixed at training time as it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for - up to but not more than `diffusion_steps`. + up to but not more than steps than `num_embeds_ada_norm`. attention_bias (: obj: `bool`, *optional*): Configure if the TransformerBlocks' attention should contain a bias parameter. """ @@ -171,7 +86,7 @@ def __init__( num_embed: Optional[int] = None, ff_layers: Optional[List[str]] = None, norm_layers: Optional[List[str]] = None, - diffusion_steps: Optional[int] = None, + num_embeds_ada_norm: Optional[int] = None, attention_bias: Optional[bool] = None, ): super().__init__() @@ -182,9 +97,9 @@ def __init__( self.discrete = discrete if self.discrete: - assert height is not None, "SpatialTransformer over discrete input must provide height" - assert width is not None, "SpatialTransformer over discrete input must provide width" - assert num_embed is not None, "SpatialTransformer over discrete input must provide num_embed" + assert height is not None, "Transformer2DModel over discrete input must provide height" + assert width is not None, "Transformer2DModel over discrete input must provide width" + assert num_embed is not None, "Transformer2DModel over discrete input must provide num_embed" self.height = height self.width = width @@ -195,7 +110,7 @@ def __init__( num_embed=self.num_embed, embed_dim=inner_dim, height=self.height, width=self.width ) else: - assert in_channels is not None, "SpatialTransformer over continuous input must provide in_channels" + assert in_channels is not None, "Transformer2DModel over continuous input must provide in_channels" self.in_channels = in_channels @@ -211,7 +126,7 @@ def __init__( dropout=dropout, context_dim=context_dim, ff_layers=ff_layers, - diffusion_steps=diffusion_steps, + num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, norm_layers=norm_layers, ) @@ -274,6 +189,91 @@ def forward(self, hidden_states, context=None, timestep=None): return return_value +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (:obj:`int`): The number of channels in the input and output. + num_head_channels (:obj:`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. 
+ rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + """ + + def __init__( + self, + channels: int, + num_head_channels: Optional[int] = None, + num_groups: int = 32, + rescale_output_factor: float = 1.0, + eps: float = 1e-5, + ): + super().__init__() + self.channels = channels + + self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) + + # define q,k,v as linear layers + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + + self.rescale_output_factor = rescale_output_factor + self.proj_attn = nn.Linear(channels, channels, 1) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + # transpose + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + # get scores + scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + class BasicTransformerBlock(nn.Module): r""" A basic Transformer block. @@ -290,8 +290,8 @@ class BasicTransformerBlock(nn.Module): The layers to use in the FeedForward block. norm_layers (:obj: `List[Literal["LayerNorm", "AdaLayerNorm"]]`, *optional*): The norm layers. Must be of length 3. Defaults to `["LayerNorm", "LayerNorm", "LayerNorm"]` - diffusion_steps (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `SpatialTransformer`. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (:obj: `bool`, *optional*): Configure if the attentions should contain a bias parameter. 
""" @@ -306,7 +306,7 @@ def __init__( checkpoint: bool = True, ff_layers: Optional[List[str]] = None, norm_layers: Optional[List[str]] = None, - diffusion_steps: Optional[int] = None, + num_embeds_ada_norm: Optional[int] = None, attention_bias: Optional[bool] = None, ): super().__init__() @@ -331,8 +331,8 @@ def __init__( if norm_layer == "LayerNorm": norm_layer_ = nn.LayerNorm(dim) elif norm_layer == "AdaLayerNorm": - assert diffusion_steps is not None, "When using AdaLayerNorm, you must also pass diffusion_steps." - norm_layer_ = AdaLayerNorm(dim, diffusion_steps) + assert num_embeds_ada_norm is not None, "When using AdaLayerNorm, you must also pass num_embeds_ada_norm." + norm_layer_ = AdaLayerNorm(dim, num_embeds_ada_norm) if idx == 0: self.norm1 = norm_layer_ diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 1745265b91e1..1b8609474750 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -142,7 +142,7 @@ def __call__(self, hidden_states, context, deterministic=True): return hidden_states -class FlaxSpatialTransformer(nn.Module): +class FlaxTransformer2DModel(nn.Module): r""" A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: https://arxiv.org/pdf/1506.02025.pdf diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 234eebfd971b..38589a26181b 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -15,7 +15,7 @@ import torch from torch import nn -from .attention import AttentionBlock, SpatialTransformer +from .attention import AttentionBlock, Transformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D @@ -349,7 +349,7 @@ def __init__( for _ in range(num_layers): attentions.append( - SpatialTransformer( + Transformer2DModel( attn_num_head_channels, in_channels // attn_num_head_channels, in_channels=in_channels, @@ -526,7 +526,7 @@ def __init__( ) ) attentions.append( - SpatialTransformer( + Transformer2DModel( attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels, @@ -1105,7 +1105,7 @@ def __init__( ) ) attentions.append( - SpatialTransformer( + Transformer2DModel( attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels, diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py index baa71beabe35..5798385b9d28 100644 --- a/src/diffusers/models/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unet_2d_blocks_flax.py @@ -15,7 +15,7 @@ import flax.linen as nn import jax.numpy as jnp -from .attention_flax import FlaxSpatialTransformer +from .attention_flax import FlaxTransformer2DModel from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D @@ -63,7 +63,7 @@ def setup(self): ) resnets.append(res_block) - attn_block = FlaxSpatialTransformer( + attn_block = FlaxTransformer2DModel( in_channels=self.out_channels, n_heads=self.attn_num_head_channels, d_head=self.out_channels // self.attn_num_head_channels, @@ -196,7 +196,7 @@ def setup(self): ) resnets.append(res_block) - attn_block = FlaxSpatialTransformer( + attn_block = FlaxTransformer2DModel( in_channels=self.out_channels, n_heads=self.attn_num_head_channels, d_head=self.out_channels // self.attn_num_head_channels, @@ -326,7 +326,7 @@ def setup(self): attentions = [] for _ in range(self.num_layers): - attn_block = FlaxSpatialTransformer( 
+ attn_block = FlaxTransformer2DModel( in_channels=self.in_channels, n_heads=self.attn_num_head_channels, d_head=self.in_channels // self.attn_num_head_channels, diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py index 86930fc95fdc..9e24ed6bbcbe 100644 --- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -2,7 +2,7 @@ import torch -from diffusers import SpatialTransformer, VQModel +from diffusers import Transformer2DModel, VQModel from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler from transformers import CLIPTextModel, CLIPTokenizer @@ -31,7 +31,7 @@ class VQDiffusionPipeline(DiffusionPipeline): tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - transformer (`SpatialTransformer`): + transformer (`Transformer2DModel`): Conditional transformer to denoise the encoded image latents. scheduler ([`VQDiffusionScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. @@ -40,7 +40,7 @@ class VQDiffusionPipeline(DiffusionPipeline): vqvae: VQModel text_encoder: CLIPTextModel tokenizer: CLIPTokenizer - transformer: SpatialTransformer + transformer: Transformer2DModel scheduler: VQDiffusionScheduler def __init__( @@ -48,7 +48,7 @@ def __init__( vqvae: VQModel, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - transformer: SpatialTransformer, + transformer: Transformer2DModel, scheduler: VQDiffusionScheduler, ): super().__init__() diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 9868aaf8e77f..da37c833f9e2 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -34,7 +34,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class SpatialTransformer(metaclass=DummyObject): +class Transformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py index d71bc4ca4649..fd775d38b63a 100644 --- a/tests/pipelines/vq_diffusion/test_vq_diffusion.py +++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py @@ -19,14 +19,17 @@ import numpy as np import torch -from diffusers import SpatialTransformer, VQDiffusionPipeline, VQDiffusionScheduler, VQModel -from diffusers.utils import slow +from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel +from diffusers.utils import slow, torch_device, load_image from diffusers.utils.testing_utils import require_torch_gpu from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from ...test_pipelines_common import PipelineTesterMixin +torch.backends.cuda.matmul.allow_tf32 = False + + class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def tearDown(self): # clean up the VRAM after each test @@ -39,7 +42,7 @@ def num_embed(self): return 12 @property - def diffusion_steps(self): + def num_embeds_ada_norm(self): return 12 @property @@ -84,7 +87,7 @@ def dummy_transformer(self): height = 12 width = 12 - model = SpatialTransformer( + model = Transformer2DModel( n_heads=1, d_head=height * width, context_dim=32, @@ -92,7 +95,7 @@ def dummy_transformer(self): num_embed=self.num_embed, 
height=height, width=width, - diffusion_steps=self.diffusion_steps, + num_embeds_ada_norm=self.num_embeds_ada_norm, ff_layers=["Linear", "ApproximateGELU", "Linear", "Dropout"], norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], attention_bias=True, @@ -145,6 +148,21 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() - @unittest.skip("VQ Diffusion model not saved to hub") def test_vq_diffusion(self): - pass + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/vq_diffusion/teddy_bear_pool.png" + ) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 + + pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq") + pipeline = pipeline.to(torch_device) + pipeline.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipeline("teddy bear playing in the pool", truncation_rate=0.86, num_images_per_prompt=1, generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (256, 256, 3) + assert np.abs(expected_image - image).max() < 1e-2 diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index ee6a26b0656a..e7a8fbd55aea 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -21,7 +21,7 @@ from torch import nn import pytest -from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock, SpatialTransformer +from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock, Transformer2DModel from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.resnet import Downsample2D, Upsample2D from diffusers.utils import torch_device @@ -275,14 +275,14 @@ def test_attention_block_sd(self): assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) -class SpatialTransformerTests(unittest.TestCase): +class Transformer2DModelTests(unittest.TestCase): def test_spatial_transformer_default(self): torch.manual_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) sample = torch.randn(1, 32, 64, 64).to(torch_device) - spatial_transformer_block = SpatialTransformer( + spatial_transformer_block = Transformer2DModel( in_channels=32, n_heads=1, d_head=32, @@ -306,7 +306,7 @@ def test_spatial_transformer_context_dim(self): torch.cuda.manual_seed_all(0) sample = torch.randn(1, 64, 64, 64).to(torch_device) - spatial_transformer_block = SpatialTransformer( + spatial_transformer_block = Transformer2DModel( in_channels=64, n_heads=2, d_head=32, @@ -330,17 +330,17 @@ def test_spatial_transformer_timestep(self): if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) - diffusion_steps = 5 + num_embeds_ada_norm = 5 sample = torch.randn(1, 64, 64, 64).to(torch_device) - spatial_transformer_block = SpatialTransformer( + spatial_transformer_block = Transformer2DModel( in_channels=64, n_heads=2, d_head=32, dropout=0.0, context_dim=64, norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], - diffusion_steps=diffusion_steps, + num_embeds_ada_norm=num_embeds_ada_norm, ).to(torch_device) with torch.no_grad(): timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device) @@ -371,7 +371,7 @@ def test_spatial_transformer_dropout(self): sample = torch.randn(1, 32, 64, 64).to(torch_device) spatial_transformer_block = ( - SpatialTransformer( + Transformer2DModel( in_channels=32, n_heads=2, d_head=16, @@ -402,7 +402,7 @@ def test_spatial_transformer_discrete(self): sample = 
torch.randint(0, num_embed, (1, 32)).to(torch_device) spatial_transformer_block = ( - SpatialTransformer( + Transformer2DModel( n_heads=1, d_head=32, discrete=True, @@ -427,38 +427,38 @@ def test_spatial_transformer_discrete(self): assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) def test_spatial_transformer_default_norm_layers(self): - spatial_transformer_block = SpatialTransformer(n_heads=1, d_head=32, in_channels=32) + spatial_transformer_block = Transformer2DModel(n_heads=1, d_head=32, in_channels=32) assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == nn.LayerNorm assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == nn.LayerNorm assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm def test_spatial_transformer_ada_norm_layers(self): - spatial_transformer_block = SpatialTransformer( + spatial_transformer_block = Transformer2DModel( n_heads=1, d_head=32, in_channels=32, norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], - diffusion_steps=5, + num_embeds_ada_norm=5, ) assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == AdaLayerNorm assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == AdaLayerNorm assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm - def test_spatial_transformer_ada_norm_layers_requires_diffusion_steps(self): + def test_spatial_transformer_ada_norm_layers_requires_num_embeds_ada_norm(self): with pytest.raises(Exception) as e_info: - SpatialTransformer( + Transformer2DModel( n_heads=1, d_head=32, in_channels=32, norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], ) - assert e_info.value.args[0] == "When using AdaLayerNorm, you must also pass diffusion_steps." + assert e_info.value.args[0] == "When using AdaLayerNorm, you must also pass num_embeds_ada_norm." 
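Aside on the two AdaLayerNorm tests above (illustrative, not part of the patch): `num_embeds_ada_norm` is required because AdaLayerNorm keeps one learned embedding per training timestep and projects it to a per-channel scale and shift applied around a LayerNorm, so the size of that embedding table has to be known when the block is constructed. A minimal sketch of the idea follows; the class name and exact layer details are for orientation only, not the patch's implementation:

    import torch
    from torch import nn


    class AdaLayerNormSketch(nn.Module):
        """Sketch of timestep-conditioned layer norm (illustrative, not the patch's exact code)."""

        def __init__(self, embedding_dim: int, num_embeddings: int):
            super().__init__()
            # one learned embedding per diffusion timestep -- this is why
            # `num_embeds_ada_norm` must be provided when AdaLayerNorm is requested
            self.emb = nn.Embedding(num_embeddings, embedding_dim)
            self.silu = nn.SiLU()
            self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

        def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
            # map the timestep embedding to a per-channel scale and shift
            emb = self.linear(self.silu(self.emb(timestep)))
            scale, shift = torch.chunk(emb, 2, dim=-1)
            return self.norm(x) * (1 + scale) + shift


    # usage: normalize a (batch, tokens, dim) tensor conditioned on a timestep index
    x = torch.randn(1, 32, 64)
    norm = AdaLayerNormSketch(embedding_dim=64, num_embeddings=5)
    out = norm(x, torch.tensor(1))
    assert out.shape == x.shape

With this picture, the preceding test simply checks that requesting "AdaLayerNorm" without saying how many timestep embeddings to allocate raises the assertion defined in `BasicTransformerBlock`.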
def test_spatial_transformer_default_ff_layers(self): - spatial_transformer_block = SpatialTransformer( + spatial_transformer_block = Transformer2DModel( n_heads=1, d_head=32, in_channels=32, @@ -481,7 +481,7 @@ def test_spatial_transformer_default_ff_layers(self): assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim def test_spatial_transformer_vq_diffusion_ff_layers(self): - spatial_transformer_block = SpatialTransformer( + spatial_transformer_block = Transformer2DModel( n_heads=1, d_head=32, in_channels=32, @@ -504,7 +504,7 @@ def test_spatial_transformer_vq_diffusion_ff_layers(self): def test_spatial_transformer_ff_layers_too_few_dim_changes(self): with pytest.raises(Exception) as e_info: - SpatialTransformer(n_heads=1, d_head=32, in_channels=32, ff_layers=["Linear"]) + Transformer2DModel(n_heads=1, d_head=32, in_channels=32, ff_layers=["Linear"]) assert ( e_info.value.args[0] @@ -515,7 +515,7 @@ def test_spatial_transformer_ff_layers_too_few_dim_changes(self): def test_spatial_transformer_ff_layers_too_many_dim_changes(self): for layer in ["Linear", "GEGLU"]: with pytest.raises(Exception) as e_info: - SpatialTransformer(n_heads=1, d_head=32, in_channels=32, ff_layers=[layer] * 3) + Transformer2DModel(n_heads=1, d_head=32, in_channels=32, ff_layers=[layer] * 3) assert ( e_info.value.args[0] @@ -524,7 +524,7 @@ def test_spatial_transformer_ff_layers_too_many_dim_changes(self): ) def test_spatial_transformer_attention_bias(self): - spatial_transformer_block = SpatialTransformer(n_heads=1, d_head=32, in_channels=32, attention_bias=True) + spatial_transformer_block = Transformer2DModel(n_heads=1, d_head=32, in_channels=32, attention_bias=True) assert spatial_transformer_block.transformer_blocks[0].attn1.to_q.bias is not None assert spatial_transformer_block.transformer_blocks[0].attn1.to_k.bias is not None From c981f090cccf87509dcc0f0e3c6d0b775840da7b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 2 Nov 2022 19:37:40 +0000 Subject: [PATCH 10/21] some fixes --- src/diffusers/models/attention.py | 51 +++++++++---------- .../vq_diffusion/test_vq_diffusion.py | 10 +++- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 3d2652650441..205fd7107cd3 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -171,18 +171,10 @@ def forward(self, hidden_states, context=None, timestep=None): """ if self.discrete: hidden_states = self.latent_image_embedding(hidden_states) - else: - batch, channel, height, weight = hidden_states.shape - residual = hidden_states - hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=context, timestep=timestep) + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=context, timestep=timestep) - if self.discrete: logits = self.out(self.norm_out(hidden_states)) # (batch, self.num_embed - 1, self.num_latent_pixels) logits = logits.permute(0, 2, 1) @@ -190,6 +182,16 @@ def forward(self, hidden_states, context=None, timestep=None): # log(p(x_0)) return_value = F.log_softmax(logits.double(), dim=1).float() else: + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + hidden_states = self.norm(hidden_states) 
+ hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=context, timestep=timestep) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) hidden_states = self.proj_out(hidden_states) return_value = hidden_states + residual @@ -343,7 +345,9 @@ def __init__( if norm_layer == "LayerNorm": norm_layer_ = nn.LayerNorm(dim) elif norm_layer == "AdaLayerNorm": - assert num_embeds_ada_norm is not None, "When using AdaLayerNorm, you must also pass num_embeds_ada_norm." + assert ( + num_embeds_ada_norm is not None + ), "When using AdaLayerNorm, you must also pass num_embeds_ada_norm." norm_layer_ = AdaLayerNorm(dim, num_embeds_ada_norm) if idx == 0: @@ -359,16 +363,6 @@ def _set_attention_slice(self, slice_size): self.attn1._slice_size = slice_size self.attn2._slice_size = slice_size - def forward(self, hidden_states, context=None, timestep=None): - norm1_kwargs = {"timestep": timestep} if self.norm1.__class__ == AdaLayerNorm else {} - norm2_kwargs = {"timestep": timestep} if self.norm2.__class__ == AdaLayerNorm else {} - norm3_kwargs = {"timestep": timestep} if self.norm3.__class__ == AdaLayerNorm else {} - - hidden_states = self.attn1(self.norm1(hidden_states, **norm1_kwargs)) + hidden_states - hidden_states = self.attn2(self.norm2(hidden_states, **norm2_kwargs), context=context) + hidden_states - hidden_states = self.ff(self.norm3(hidden_states, **norm3_kwargs)) + hidden_states - return hidden_states - def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): if not is_xformers_available(): print("Here is how to install it") @@ -395,11 +389,16 @@ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atte self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - def forward(self, hidden_states, context=None): - hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states - hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states - hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states - + def forward(self, hidden_states, context=None, timestep=None): + norm1_kwargs = {"timestep": timestep} if self.norm1.__class__ == AdaLayerNorm else {} + norm2_kwargs = {"timestep": timestep} if self.norm2.__class__ == AdaLayerNorm else {} + norm3_kwargs = {"timestep": timestep} if self.norm3.__class__ == AdaLayerNorm else {} + + hidden_states = self.attn1(self.norm1(hidden_states, **norm1_kwargs)) + hidden_states + hidden_states = self.attn2(self.norm2(hidden_states, **norm2_kwargs), context=context) + hidden_states + hidden_states = self.ff(self.norm3(hidden_states, **norm3_kwargs)) + hidden_states + return hidden_states + class CrossAttention(nn.Module): r""" diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py index fd775d38b63a..ed762dd35acb 100644 --- a/tests/pipelines/vq_diffusion/test_vq_diffusion.py +++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py @@ -20,7 +20,7 @@ import torch from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel -from diffusers.utils import slow, torch_device, load_image +from 
diffusers.utils import load_image, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -160,7 +160,13 @@ def test_vq_diffusion(self): pipeline.set_progress_bar_config(disable=None) generator = torch.Generator(device=torch_device).manual_seed(0) - output = pipeline("teddy bear playing in the pool", truncation_rate=0.86, num_images_per_prompt=1, generator=generator, output_type="np") + output = pipeline( + "teddy bear playing in the pool", + truncation_rate=0.86, + num_images_per_prompt=1, + generator=generator, + output_type="np", + ) image = output.images[0] From f88e6a24baa95ef0e3d8747c3cfbd5143fa2fd0d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 10:05:15 +0000 Subject: [PATCH 11/21] more renaming --- ...3c81220be0a04e4543e6fd9b0f290547749cc06cfb | 324 ++++++++++++++++++ ...20be0a04e4543e6fd9b0f290547749cc06cfb.json | 1 + .../bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb | 102 ++++++ .../refs/main | 1 + .../pipeline.py | 1 + src/diffusers/models/attention.py | 247 ++++++------- src/diffusers/models/unet_2d_blocks.py | 28 +- .../vq_diffusion/pipeline_vq_diffusion.py | 6 +- 8 files changed, 573 insertions(+), 137 deletions(-) create mode 100644 clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb create mode 100644 clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb.json create mode 100644 hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb create mode 100644 hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/refs/main create mode 120000 hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/snapshots/b8fa12635e53eebebc22f95ee863e7af4fc2fb07/pipeline.py diff --git a/clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb b/clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb new file mode 100644 index 000000000000..2c86e9130fdc --- /dev/null +++ b/clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb @@ -0,0 +1,324 @@ +import inspect +from typing import List, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput +from torchvision import transforms +from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer + + +class MakeCutouts(nn.Module): + def __init__(self, cut_size, cut_power=1.0): + super().__init__() + + self.cut_size = cut_size + self.cut_power = cut_power + + def forward(self, pixel_values, num_cutouts): + sideY, sideX = pixel_values.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(num_cutouts): + size = 
int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size] + cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) + return torch.cat(cutouts) + + +def spherical_dist_loss(x, y): + x = F.normalize(x, dim=-1) + y = F.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + + +def set_requires_grad(model, value): + for param in model.parameters(): + param.requires_grad = value + + +class CLIPGuidedStableDiffusion(DiffusionPipeline): + """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000 + - https://github.com/Jack000/glid-3-xl + - https://github.dev/crowsonkb/k-diffusion + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + clip_model: CLIPModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[PNDMScheduler, LMSDiscreteScheduler], + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + clip_model=clip_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + + self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std) + self.make_cutouts = MakeCutouts(feature_extractor.size) + + set_requires_grad(self.text_encoder, False) + set_requires_grad(self.clip_model, False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + self.enable_attention_slicing(None) + + def freeze_vae(self): + set_requires_grad(self.vae, False) + + def unfreeze_vae(self): + set_requires_grad(self.vae, True) + + def freeze_unet(self): + set_requires_grad(self.unet, False) + + def unfreeze_unet(self): + set_requires_grad(self.unet, True) + + @torch.enable_grad() + def cond_fn( + self, + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + text_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts=True, + ): + latents = latents.detach().requires_grad_() + + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latents / ((sigma**2 + 1) ** 0.5) + else: + latent_model_input = latents + + # predict the noise residual + noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + + if isinstance(self.scheduler, PNDMScheduler): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + # compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + fac = torch.sqrt(beta_prod_t) + sample = pred_original_sample * (fac) + latents * (1 - fac) + elif isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + sample = latents - sigma * noise_pred + else: + raise ValueError(f"scheduler 
type {type(self.scheduler)} not supported") + + sample = 1 / 0.18215 * sample + image = self.vae.decode(sample).sample + image = (image / 2 + 0.5).clamp(0, 1) + + if use_cutouts: + image = self.make_cutouts(image, num_cutouts) + else: + image = transforms.Resize(self.feature_extractor.size)(image) + image = self.normalize(image).to(latents.dtype) + + image_embeddings_clip = self.clip_model.get_image_features(image) + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + + if use_cutouts: + dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip) + dists = dists.view([num_cutouts, sample.shape[0], -1]) + loss = dists.sum(2).mean(0).sum() * clip_guidance_scale + else: + loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale + + grads = -torch.autograd.grad(loss, latents)[0] + + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents.detach() + grads * (sigma**2) + noise_pred = noise_pred_original + else: + noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads + return noise_pred, latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + clip_guidance_scale: Optional[float] = 100, + clip_prompt: Optional[Union[str, List[str]]] = None, + num_cutouts: Optional[int] = 4, + use_cutouts: Optional[bool] = True, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + # duplicate text embeddings for each generation per prompt + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + if clip_guidance_scale > 0: + if clip_prompt is not None: + clip_text_input = self.tokenizer( + clip_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids.to(self.device) + else: + clip_text_input = text_input.input_ids.to(self.device) + text_embeddings_clip = self.clip_model.get_text_features(clip_text_input) + text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + # duplicate text embeddings clip for each generation per prompt + text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
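Brief illustration of the comment above (not part of the patch; the tensors are dummies): with the combination used later in this pipeline, `noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)`, a guidance scale of 1 returns exactly the conditional prediction, which is why a value of 1 corresponds to doing no classifier-free guidance:

    import torch

    noise_pred_uncond = torch.randn(1, 4, 8, 8)
    noise_pred_text = torch.randn(1, 4, 8, 8)

    guided = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)    # typical w > 1
    unguided = noise_pred_uncond + 1.0 * (noise_pred_text - noise_pred_uncond)  # w == 1

    assert guided.shape == noise_pred_text.shape
    assert torch.allclose(unguided, noise_pred_text, atol=1e-5)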
+ do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt") + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + # duplicate unconditional embeddings for each generation per prompt + uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) + latents_dtype = text_embeddings.dtype + if latents is None: + if self.device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( + self.device + ) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + # Some schedulers like PNDM have timesteps as arrays + # It's more optimized to move all timesteps to correct device beforehand + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform classifier free guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # perform clip guidance + if clip_guidance_scale > 0: + text_embeddings_for_guidance = ( + text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings + ) + noise_pred, latents = self.cond_fn( + latents, + t, + i, + text_embeddings_for_guidance, + noise_pred, + text_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = 
self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, None) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) diff --git a/clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb.json b/clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb.json new file mode 100644 index 000000000000..ebadb9070e07 --- /dev/null +++ b/clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb.json @@ -0,0 +1 @@ +{"url": "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/clip_guided_stable_diffusion.py", "etag": "W/\"3e4886ba6cb31f36f75ec5127cd691e562bb04d1f0ff257edbe1c182fd6a210a\""} \ No newline at end of file diff --git a/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb new file mode 100644 index 000000000000..bbbcb9f65616 --- /dev/null +++ b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb @@ -0,0 +1,102 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + + +from typing import Optional, Tuple, Union + +import torch + +from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class CustomPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[torch.Generator] = None, + eta: float = 0.0, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. 
+ generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + eta (`float`, *optional*, defaults to 0.0): + The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + # Sample gaussian noise to begin loop + image = torch.randn( + (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + generator=generator, + ) + image = image.to(self.device) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. predict previous mean of image x_t-1 and add variance depending on eta + # eta corresponds to η in paper and should be between [0, 1] + # do x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, eta).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image), "This is a test" \ No newline at end of file diff --git a/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/refs/main b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/refs/main new file mode 100644 index 000000000000..152c8af6817e --- /dev/null +++ b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/refs/main @@ -0,0 +1 @@ +b8fa12635e53eebebc22f95ee863e7af4fc2fb07 \ No newline at end of file diff --git a/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/snapshots/b8fa12635e53eebebc22f95ee863e7af4fc2fb07/pipeline.py b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/snapshots/b8fa12635e53eebebc22f95ee863e7af4fc2fb07/pipeline.py new file mode 120000 index 000000000000..47bb96808073 --- /dev/null +++ b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/snapshots/b8fa12635e53eebebc22f95ee863e7af4fc2fb07/pipeline.py @@ -0,0 +1 @@ +../../blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb \ No newline at end of file diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 205fd7107cd3..73a88835a165 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -33,7 +33,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): """ - Transformer block for image-like data. 
Takes either discrete (classes of vector embeddings) or continuous (actual + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual embeddings) inputs. When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard @@ -47,29 +47,27 @@ class Transformer2DModel(ModelMixin, ConfigMixin): image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. Parameters: - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. + num_attentinon_heads (:obj:`int`): The number of heads to use for multi-head attention. + attention_head_dim (:obj:`int`): The number of channels in each head. in_channels (: obj:`int`, *optional*): Pass if the input is continuous. The number of channels in the input and output. - depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + num_layers (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The number of context dimensions to use. + cross_attention_dim (:obj:`int`, *optional*): The number of context dimensions to use. discrete (: obj:`bool`, *optional*, defaults to False): Set to True if the input is discrete i.e. over classes of - vector embeddings for each pixel. See the beginning of the docstring for a more in-depth description. + vector embeddings for each pixel. See the beginning of the docstring for a more in-num_layers description. height (:obj:`int`, *optional*): Pass if the input is discrete. The height of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. width (:obj:`int`, *optional*): Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. - num_embed (: + num_vector_embeds (: obj:`int`, *optional*): Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel. ff_layers (:obj:,`List[Literal["Dropout", "Linear", "ApproximateGELU", "GEGLU"]]` *optional*): The layers to use in the TransformerBlocks' FeedForward block. - norm_layers (:obj: `List[Literal["LayerNorm", "AdaLayerNorm"]]`, *optional*): - The norm layers to use for the TransformerBlocks. num_embeds_ada_norm (:obj: `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during training. Note that this is fixed at training time as it is used to learn a number of embeddings that are added to the hidden states. 
During inference, you can denoise for @@ -81,72 +79,82 @@ class Transformer2DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - n_heads: int, - d_head: int, + num_attentinon_heads: int = 16, + attention_head_dim: int = 88, in_channels: Optional[int] = None, - depth: int = 1, + num_layers: int = 1, dropout: float = 0.0, - num_groups: int = 32, - context_dim: Optional[int] = None, - discrete: bool = False, - height: Optional[int] = None, - width: Optional[int] = None, - num_embed: Optional[int] = None, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, ff_layers: Optional[List[str]] = None, - norm_layers: Optional[List[str]] = None, num_embeds_ada_norm: Optional[int] = None, - attention_bias: Optional[bool] = None, ): super().__init__() - self.n_heads = n_heads - self.d_head = d_head - inner_dim = n_heads * d_head + self.num_attentinon_heads = num_attentinon_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attentinon_heads * attention_head_dim - self.discrete = discrete + # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = in_channels is not None + self.is_input_vectorized = num_vector_embeds is not None - if self.discrete: - assert height is not None, "Transformer2DModel over discrete input must provide height" - assert width is not None, "Transformer2DModel over discrete input must provide width" - assert num_embed is not None, "Transformer2DModel over discrete input must provide num_embed" - - self.height = height - self.width = width - self.num_embed = num_embed - self.num_latent_pixels = self.height * self.width - - self.latent_image_embedding = ImagePositionalEmbeddings( - num_embed=self.num_embed, embed_dim=inner_dim, height=self.height, width=self.width + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized: + raise ValueError( + f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is not None." ) - else: - assert in_channels is not None, "Transformer2DModel over continuous input must provide in_channels" + # 2. 
Define input layers + if self.is_input_continuous: self.in_channels = in_channels - self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + # 3. Define transformers blocks self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, - n_heads, - d_head, + num_attentinon_heads, + attention_head_dim, dropout=dropout, - context_dim=context_dim, + cross_attention_dim=cross_attention_dim, ff_layers=ff_layers, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, - norm_layers=norm_layers, ) - for d in range(depth) + for d in range(num_layers) ] ) - if self.discrete: - self.norm_out = nn.LayerNorm(inner_dim) - self.out = nn.Linear(inner_dim, self.num_embed - 1) - else: + # 4. Define output layers + if self.is_input_continuous: self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) def _set_attention_slice(self, slice_size): for block in self.transformer_blocks: @@ -169,34 +177,36 @@ def forward(self, hidden_states, context=None, timestep=None): If discrete, returns probability distributions for the unnoised latent pixels. Note that it does not output a prediction for the masked class. """ - if self.discrete: - hidden_states = self.latent_image_embedding(hidden_states) - - for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=context, timestep=timestep) - - logits = self.out(self.norm_out(hidden_states)) - # (batch, self.num_embed - 1, self.num_latent_pixels) - logits = logits.permute(0, 2, 1) - - # log(p(x_0)) - return_value = F.log_softmax(logits.double(), dim=1).float() - else: + # 1. Input + if self.is_input_continuous: batch, channel, height, weight = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) - for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=context, timestep=timestep) + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=context, timestep=timestep) + # 3. 
Output + if self.is_input_continuous: hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) hidden_states = self.proj_out(hidden_states) - return_value = hidden_states + residual + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() - return return_value + return output def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): for block in self.transformer_blocks: @@ -214,7 +224,7 @@ class AttentionBlock(nn.Module): channels (:obj:`int`): The number of channels in the input and output. num_head_channels (:obj:`int`, *optional*): The number of channels in each head. If None, then `num_heads` = 1. - num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. + norm_num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. """ @@ -223,7 +233,7 @@ def __init__( self, channels: int, num_head_channels: Optional[int] = None, - num_groups: int = 32, + norm_num_groups: int = 32, rescale_output_factor: float = 1.0, eps: float = 1e-5, ): @@ -232,7 +242,7 @@ def __init__( self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 self.num_head_size = num_head_channels - self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) # define q,k,v as linear layers self.query = nn.Linear(channels, channels) @@ -294,70 +304,59 @@ class BasicTransformerBlock(nn.Module): Parameters: dim (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. + num_attentinon_heads (:obj:`int`): The number of heads to use for multi-head attention. + attention_head_dim (:obj:`int`): The number of channels in each head. dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. + cross_attention_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. - checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. ff_layers (:obj:,`List[Literal["Dropout", "Linear", "ApproximateGELU", "GEGLU"]]` *optional*): The layers to use in the FeedForward block. - norm_layers (:obj: `List[Literal["LayerNorm", "AdaLayerNorm"]]`, *optional*): - The norm layers. Must be of length 3. Defaults to `["LayerNorm", "LayerNorm", "LayerNorm"]` num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (:obj: `bool`, *optional*): Configure if the attentions should contain a bias parameter. 
+ attention_bias (: + obj: `bool`, *optional*, defaults to :obj:`False`): Configure if the attentions should contain a bias + parameter. """ def __init__( self, dim: int, - n_heads: int, - d_head: int, + num_attentinon_heads: int, + attention_head_dim: int, dropout=0.0, - context_dim: Optional[int] = None, + cross_attention_dim: Optional[int] = None, gated_ff: bool = True, - checkpoint: bool = True, ff_layers: Optional[List[str]] = None, - norm_layers: Optional[List[str]] = None, num_embeds_ada_norm: Optional[int] = None, - attention_bias: Optional[bool] = None, + attention_bias: bool = False, ): super().__init__() self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, bias=attention_bias + query_dim=dim, + heads=num_attentinon_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, ) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, layers=ff_layers) self.attn2 = CrossAttention( query_dim=dim, - context_dim=context_dim, - heads=n_heads, - dim_head=d_head, + cross_attention_dim=cross_attention_dim, + heads=num_attentinon_heads, + dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, ) # is self-attn if context is none - norm_layers = ["LayerNorm", "LayerNorm", "LayerNorm"] if norm_layers is None else norm_layers - - assert len(norm_layers) == 3, "BasicTransformerBlock only supports 3 norm_layers" - - for idx, norm_layer in enumerate(norm_layers): - if norm_layer == "LayerNorm": - norm_layer_ = nn.LayerNorm(dim) - elif norm_layer == "AdaLayerNorm": - assert ( - num_embeds_ada_norm is not None - ), "When using AdaLayerNorm, you must also pass num_embeds_ada_norm." - norm_layer_ = AdaLayerNorm(dim, num_embeds_ada_norm) - - if idx == 0: - self.norm1 = norm_layer_ - elif idx == 1: - self.norm2 = norm_layer_ - elif idx == 2: - self.norm3 = norm_layer_ - - self.checkpoint = checkpoint + # layer norms + self.use_ada_layer_norm = num_embeds_ada_norm is not None + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) def _set_attention_slice(self, slice_size): self.attn1._slice_size = slice_size @@ -390,13 +389,21 @@ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atte self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers def forward(self, hidden_states, context=None, timestep=None): - norm1_kwargs = {"timestep": timestep} if self.norm1.__class__ == AdaLayerNorm else {} - norm2_kwargs = {"timestep": timestep} if self.norm2.__class__ == AdaLayerNorm else {} - norm3_kwargs = {"timestep": timestep} if self.norm3.__class__ == AdaLayerNorm else {} + # 1. Self-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + hidden_states = self.attn1(norm_hidden_states) + hidden_states + + # 2. Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + + # 3. 
Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states - hidden_states = self.attn1(self.norm1(hidden_states, **norm1_kwargs)) + hidden_states - hidden_states = self.attn2(self.norm2(hidden_states, **norm2_kwargs), context=context) + hidden_states - hidden_states = self.ff(self.norm3(hidden_states, **norm3_kwargs)) + hidden_states return hidden_states @@ -406,7 +413,7 @@ class CrossAttention(nn.Module): Parameters: query_dim (:obj:`int`): The number of channels in the query. - context_dim (:obj:`int`, *optional*): + cross_attention_dim (:obj:`int`, *optional*): The number of channels in the context. If not given, defaults to `query_dim`. heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. @@ -418,15 +425,15 @@ class CrossAttention(nn.Module): def __init__( self, query_dim: int, - context_dim: Optional[int] = None, + cross_attention_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0, - bias=None, + bias=False, ): super().__init__() inner_dim = dim_head * heads - context_dim = context_dim if context_dim is not None else query_dim + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.scale = dim_head**-0.5 self.heads = heads @@ -436,11 +443,9 @@ def __init__( self._slice_size = None self._use_memory_efficient_attention_xformers = False - bias = False if bias is None else bias - self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) - self.to_k = nn.Linear(context_dim, inner_dim, bias=bias) - self.to_v = nn.Linear(context_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(inner_dim, query_dim)) @@ -643,6 +648,8 @@ class ApproximateGELU(nn.Module): def __init__(self): super().__init__() + # self.linear = nn.Linear(jk + def forward(self, x): return x * torch.sigmoid(1.702 * x) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index e445a8132993..495baa6980ff 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -273,7 +273,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) resnets.append( @@ -353,9 +353,9 @@ def __init__( attn_num_head_channels, in_channels // attn_num_head_channels, in_channels=in_channels, - depth=1, - context_dim=cross_attention_dim, - num_groups=resnet_groups, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, ) ) resnets.append( @@ -451,7 +451,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) @@ -534,9 +534,9 @@ def __init__( attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels, - depth=1, - context_dim=cross_attention_dim, - num_groups=resnet_groups, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) @@ -787,7 +787,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - 
num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) @@ -1038,7 +1038,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) @@ -1117,9 +1117,9 @@ def __init__( attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels, - depth=1, - context_dim=cross_attention_dim, - num_groups=resnet_groups, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) @@ -1361,7 +1361,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py index 9e24ed6bbcbe..2c11f3de3698 100644 --- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -165,15 +165,15 @@ def __call__( latents_shape = (batch_size, self.transformer.num_latent_pixels) if latents is None: - mask_class = self.transformer.num_embed - 1 + mask_class = self.transformer.num_vector_embeds - 1 latents = torch.full(latents_shape, mask_class).to(self.device) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - if (latents < 0).any() or (latents >= self.transformer.num_embed).any(): + if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any(): raise ValueError( "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0," - f" {self.transformer.num_embed - 1} (inclusive)." + f" {self.transformer.num_vector_embeds - 1} (inclusive)." 
) latents = latents.to(self.device) From 1ed3752b91c5464a1afa282102b3d2da37d71c90 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 10:05:26 +0000 Subject: [PATCH 12/21] correct --- ...3c81220be0a04e4543e6fd9b0f290547749cc06cfb | 324 ------------------ ...20be0a04e4543e6fd9b0f290547749cc06cfb.json | 1 - .../bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb | 102 ------ .../refs/main | 1 - .../pipeline.py | 1 - 5 files changed, 429 deletions(-) delete mode 100644 clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb delete mode 100644 clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb.json delete mode 100644 hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb delete mode 100644 hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/refs/main delete mode 120000 hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/snapshots/b8fa12635e53eebebc22f95ee863e7af4fc2fb07/pipeline.py diff --git a/clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb b/clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb deleted file mode 100644 index 2c86e9130fdc..000000000000 --- a/clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb +++ /dev/null @@ -1,324 +0,0 @@ -import inspect -from typing import List, Optional, Union - -import torch -from torch import nn -from torch.nn import functional as F - -from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput -from torchvision import transforms -from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer - - -class MakeCutouts(nn.Module): - def __init__(self, cut_size, cut_power=1.0): - super().__init__() - - self.cut_size = cut_size - self.cut_power = cut_power - - def forward(self, pixel_values, num_cutouts): - sideY, sideX = pixel_values.shape[2:4] - max_size = min(sideX, sideY) - min_size = min(sideX, sideY, self.cut_size) - cutouts = [] - for _ in range(num_cutouts): - size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) - offsetx = torch.randint(0, sideX - size + 1, ()) - offsety = torch.randint(0, sideY - size + 1, ()) - cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size] - cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) - return torch.cat(cutouts) - - -def spherical_dist_loss(x, y): - x = F.normalize(x, dim=-1) - y = F.normalize(y, dim=-1) - return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) - - -def set_requires_grad(model, value): - for param in model.parameters(): - param.requires_grad = value - - -class CLIPGuidedStableDiffusion(DiffusionPipeline): - """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000 - - https://github.com/Jack000/glid-3-xl - - 
https://github.dev/crowsonkb/k-diffusion - """ - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - clip_model: CLIPModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[PNDMScheduler, LMSDiscreteScheduler], - feature_extractor: CLIPFeatureExtractor, - ): - super().__init__() - self.register_modules( - vae=vae, - text_encoder=text_encoder, - clip_model=clip_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - feature_extractor=feature_extractor, - ) - - self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std) - self.make_cutouts = MakeCutouts(feature_extractor.size) - - set_requires_grad(self.text_encoder, False) - set_requires_grad(self.clip_model, False) - - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - self.enable_attention_slicing(None) - - def freeze_vae(self): - set_requires_grad(self.vae, False) - - def unfreeze_vae(self): - set_requires_grad(self.vae, True) - - def freeze_unet(self): - set_requires_grad(self.unet, False) - - def unfreeze_unet(self): - set_requires_grad(self.unet, True) - - @torch.enable_grad() - def cond_fn( - self, - latents, - timestep, - index, - text_embeddings, - noise_pred_original, - text_embeddings_clip, - clip_guidance_scale, - num_cutouts, - use_cutouts=True, - ): - latents = latents.detach().requires_grad_() - - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latents / ((sigma**2 + 1) ** 0.5) - else: - latent_model_input = latents - - # predict the noise residual - noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample - - if isinstance(self.scheduler, PNDMScheduler): - alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - beta_prod_t = 1 - alpha_prod_t - # compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) - - fac = torch.sqrt(beta_prod_t) - sample = pred_original_sample * (fac) + latents * (1 - fac) - elif isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - sample = latents - sigma * noise_pred - else: - raise ValueError(f"scheduler type {type(self.scheduler)} not supported") - - sample = 1 / 0.18215 * sample - image = self.vae.decode(sample).sample - image = (image / 2 + 0.5).clamp(0, 1) - - if use_cutouts: - image = self.make_cutouts(image, num_cutouts) - else: - image = transforms.Resize(self.feature_extractor.size)(image) - image = self.normalize(image).to(latents.dtype) - - image_embeddings_clip = self.clip_model.get_image_features(image) - image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) - - if use_cutouts: - dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip) - dists = dists.view([num_cutouts, sample.shape[0], -1]) - loss = dists.sum(2).mean(0).sum() * clip_guidance_scale - else: - loss = spherical_dist_loss(image_embeddings_clip, 
text_embeddings_clip).mean() * clip_guidance_scale - - grads = -torch.autograd.grad(loss, latents)[0] - - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents.detach() + grads * (sigma**2) - noise_pred = noise_pred_original - else: - noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads - return noise_pred, latents - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - clip_guidance_scale: Optional[float] = 100, - clip_prompt: Optional[Union[str, List[str]]] = None, - num_cutouts: Optional[int] = 4, - use_cutouts: Optional[bool] = True, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ): - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - # get prompt text embeddings - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] - # duplicate text embeddings for each generation per prompt - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) - - if clip_guidance_scale > 0: - if clip_prompt is not None: - clip_text_input = self.tokenizer( - clip_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ).input_ids.to(self.device) - else: - clip_text_input = text_input.input_ids.to(self.device) - text_embeddings_clip = self.clip_model.get_text_features(clip_text_input) - text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) - # duplicate text embeddings clip for each generation per prompt - text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - max_length = text_input.input_ids.shape[-1] - uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt") - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] - # duplicate unconditional embeddings for each generation per prompt - uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0) - - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) - latents_dtype = text_embeddings.dtype - if latents is None: - if self.device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( - self.device - ) - else: - latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) - - # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - - # Some schedulers like PNDM have timesteps as arrays - # It's more optimized to move all timesteps to correct device beforehand - timesteps_tensor = self.scheduler.timesteps.to(self.device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - - for i, t in enumerate(self.progress_bar(timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform classifier free guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # perform clip guidance - if clip_guidance_scale > 0: - text_embeddings_for_guidance = ( - text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings - ) - noise_pred, latents = self.cond_fn( - latents, - t, - i, - text_embeddings_for_guidance, - noise_pred, - text_embeddings_clip, - clip_guidance_scale, - num_cutouts, - use_cutouts, - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents).prev_sample - - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, None) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) diff --git a/clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb.json 
b/clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb.json deleted file mode 100644 index ebadb9070e07..000000000000 --- a/clip_guided_stable_diffusion/72392adcdf265e793b0dc13d166393a9d1367724bb03f6faca8cfb1c91c30827.8d4a13da440f0a37b6d42d3c81220be0a04e4543e6fd9b0f290547749cc06cfb.json +++ /dev/null @@ -1 +0,0 @@ -{"url": "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/clip_guided_stable_diffusion.py", "etag": "W/\"3e4886ba6cb31f36f75ec5127cd691e562bb04d1f0ff257edbe1c182fd6a210a\""} \ No newline at end of file diff --git a/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb deleted file mode 100644 index bbbcb9f65616..000000000000 --- a/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -# limitations under the License. - - -from typing import Optional, Tuple, Union - -import torch - -from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -class CustomPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of - [`DDPMScheduler`], or [`DDIMScheduler`]. - """ - - def __init__(self, unet, scheduler): - super().__init__() - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - generator: Optional[torch.Generator] = None, - eta: float = 0.0, - num_inference_steps: int = 50, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[ImagePipelineOutput, Tuple]: - r""" - Args: - batch_size (`int`, *optional*, defaults to 1): - The number of images to generate. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - eta (`float`, *optional*, defaults to 0.0): - The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. 
- output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. - """ - - # Sample gaussian noise to begin loop - image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), - generator=generator, - ) - image = image.to(self.device) - - # set step values - self.scheduler.set_timesteps(num_inference_steps) - - for t in self.progress_bar(self.scheduler.timesteps): - # 1. predict noise model_output - model_output = self.unet(image, t).sample - - # 2. predict previous mean of image x_t-1 and add variance depending on eta - # eta corresponds to η in paper and should be between [0, 1] - # do x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, eta).prev_sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image), "This is a test" \ No newline at end of file diff --git a/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/refs/main b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/refs/main deleted file mode 100644 index 152c8af6817e..000000000000 --- a/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/refs/main +++ /dev/null @@ -1 +0,0 @@ -b8fa12635e53eebebc22f95ee863e7af4fc2fb07 \ No newline at end of file diff --git a/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/snapshots/b8fa12635e53eebebc22f95ee863e7af4fc2fb07/pipeline.py b/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/snapshots/b8fa12635e53eebebc22f95ee863e7af4fc2fb07/pipeline.py deleted file mode 120000 index 47bb96808073..000000000000 --- a/hf-internal-testing/diffusers-dummy-pipeline/models--hf-internal-testing--diffusers-dummy-pipeline/snapshots/b8fa12635e53eebebc22f95ee863e7af4fc2fb07/pipeline.py +++ /dev/null @@ -1 +0,0 @@ -../../blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb \ No newline at end of file From 5fe7cfadd7ffc1501c5b4e92b4e04b20ef81ecb2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 10:10:25 +0000 Subject: [PATCH 13/21] fix typo --- src/diffusers/models/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 73a88835a165..bbc5c923ddc4 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -397,7 +397,7 @@ def forward(self, hidden_states, context=None, timestep=None): # 2. 
Cross-Attention norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states From 8350615803a0af39d522c95e79fab3d487066143 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 10:55:44 +0000 Subject: [PATCH 14/21] correct weights --- scripts/convert_vq_diffusion_to_diffusers.py | 39 +++++---- src/diffusers/models/attention.py | 86 +++++++------------ src/diffusers/models/vae.py | 21 +++-- .../vq_diffusion/pipeline_vq_diffusion.py | 2 +- .../schedulers/scheduling_vq_diffusion.py | 42 ++++----- .../vq_diffusion/test_vq_diffusion.py | 3 +- 6 files changed, 86 insertions(+), 107 deletions(-) diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py index 0fa5cefe1f8b..ae105e30362e 100644 --- a/scripts/convert_vq_diffusion_to_diffusers.py +++ b/scripts/convert_vq_diffusion_to_diffusers.py @@ -101,7 +101,7 @@ def vqvae_model_from_original_config(original_config): latent_channels=latent_channels, num_vq_embeddings=num_vq_embeddings, norm_num_groups=norm_num_groups, - e_dim=e_dim, + vq_embed_dim=e_dim, ) return model @@ -506,25 +506,26 @@ def transformer_model_from_original_config( height = original_transformer_config["content_spatial_size"][0] width = original_transformer_config["content_spatial_size"][1] + + assert width == height, "width has to be equal to height" dropout = original_transformer_config["resid_pdrop"] num_embeds_ada_norm = original_diffusion_config["diffusion_step"] - model = Transformer2DModel( - n_heads=n_heads, - d_head=d_head, - depth=depth, - context_dim=context_dim, - discrete=True, - num_embed=num_embed, - height=height, - width=width, - dropout=dropout, - num_embeds_ada_norm=num_embeds_ada_norm, - ff_layers=["Linear", "ApproximateGELU", "Linear", "Dropout"], - norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], - attention_bias=True, - ) + model_kwargs = { + "attention_bias": True, + "cross_attention_dim": context_dim, + "attention_head_dim": d_head, + "num_layers": depth, + "dropout": dropout, + "num_attention_heads": n_heads, + "num_vector_embeds": num_embed, + "num_embeds_ada_norm": num_embeds_ada_norm, + "norm_num_groups": 32, + "sample_size": width, + "activation_fn": "geglu-approximate", + } + model = Transformer2DModel(**model_kwargs) return model @@ -676,8 +677,8 @@ def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_atten def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix): return { - f"{diffusers_feedforward_prefix}.net.0.weight": checkpoint[f"{feedforward_prefix}.0.weight"], - f"{diffusers_feedforward_prefix}.net.0.bias": checkpoint[f"{feedforward_prefix}.0.bias"], + f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"], + f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"], f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"], f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"], } @@ -865,7 +866,7 @@ def read_config_file(filename): scheduler_model = VQDiffusionScheduler( # the scheduler has the same number of embeddings as the transformer - num_embed=transformer_model.num_embed + 
num_vec_classes=transformer_model.num_vector_embeds ) # done scheduler diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index bbc5c923ddc4..6dbea4c538dd 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import List, Optional +from typing import Optional import torch import torch.nn.functional as F @@ -47,7 +47,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. Parameters: - num_attentinon_heads (:obj:`int`): The number of heads to use for multi-head attention. + num_attention_heads (:obj:`int`): The number of heads to use for multi-head attention. attention_head_dim (:obj:`int`): The number of channels in each head. in_channels (: obj:`int`, *optional*): Pass if the input is continuous. The number of channels in the input and output. @@ -66,8 +66,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): num_vector_embeds (: obj:`int`, *optional*): Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel. - ff_layers (:obj:,`List[Literal["Dropout", "Linear", "ApproximateGELU", "GEGLU"]]` *optional*): - The layers to use in the TransformerBlocks' FeedForward block. + activation_fn (:obj:`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (:obj: `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during training. Note that this is fixed at training time as it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for @@ -79,7 +78,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - num_attentinon_heads: int = 16, + num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, num_layers: int = 1, @@ -89,13 +88,13 @@ def __init__( attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, - ff_layers: Optional[List[str]] = None, + activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, ): super().__init__() - self.num_attentinon_heads = num_attentinon_heads + self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim - inner_dim = num_attentinon_heads * attention_head_dim + inner_dim = num_attention_heads * attention_head_dim # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # Define whether input is continuous or discrete depending on configuration @@ -137,11 +136,11 @@ def __init__( [ BasicTransformerBlock( inner_dim, - num_attentinon_heads, + num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, - ff_layers=ff_layers, + activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, ) @@ -304,13 +303,11 @@ class BasicTransformerBlock(nn.Module): Parameters: dim (:obj:`int`): The number of channels in the input and output. - num_attentinon_heads (:obj:`int`): The number of heads to use for multi-head attention. 
+ num_attention_heads (:obj:`int`): The number of heads to use for multi-head attention. attention_head_dim (:obj:`int`): The number of channels in each head. dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. - gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. - ff_layers (:obj:,`List[Literal["Dropout", "Linear", "ApproximateGELU", "GEGLU"]]` *optional*): - The layers to use in the FeedForward block. + activation_fn (:obj:`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (: @@ -321,28 +318,27 @@ class BasicTransformerBlock(nn.Module): def __init__( self, dim: int, - num_attentinon_heads: int, + num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, - gated_ff: bool = True, - ff_layers: Optional[List[str]] = None, + activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, ): super().__init__() self.attn1 = CrossAttention( query_dim=dim, - heads=num_attentinon_heads, + heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, layers=ff_layers) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) self.attn2 = CrossAttention( query_dim=dim, cross_attention_dim=cross_attention_dim, - heads=num_attentinon_heads, + heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, @@ -560,14 +556,8 @@ class FeedForward(nn.Module): dim (:obj:`int`): The number of channels in the input. dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - layers (:obj:,`List[Literal["Dropout", "Linear", "ApproximateGELU", "GEGLU"]]` *optional*): - The list of layers to use. Note that the list must contain exactly two dimension changing layers (Linear - and GEGLU) but may contain as many non-dimension changing layers as you want (Dropout and ApproximateGELU). - The first dimension changing layer will project from the input dimension to the hidden dimension. The - second dimension changing layer will project from the hidden dimension to the output dimension. Defaults to - `["GEGLU", "Dropout", "Linear"]`. + activation_fn (:obj:`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 
""" def __init__( @@ -575,37 +565,25 @@ def __init__( dim: int, dim_out: Optional[int] = None, mult: int = 4, - glu: bool = False, dropout: float = 0.0, - layers: Optional[List[str]] = None, + activation_fn: str = "geglu", ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - self.net = nn.ModuleList([]) - - layers = ["GEGLU", "Dropout", "Linear"] if layers is None else layers - - dim_idx = 0 - dims = [[dim, inner_dim], [inner_dim, dim_out]] - - error_string = "FeedForward must have exactly two dimension changing layers (Linear and GEGLU)." - for layer in layers: - if layer == "Dropout": - self.net.append(nn.Dropout(dropout)) - elif layer == "Linear": - assert dim_idx < 2, f"Too many dimension changes. {error_string}" - self.net.append(nn.Linear(*dims[dim_idx])) - dim_idx += 1 - elif layer == "ApproximateGELU": - self.net.append(ApproximateGELU()) - elif layer == "GEGLU": - assert dim_idx < 2, f"Too many dimension changes. {error_string}" - self.net.append(GEGLU(*dims[dim_idx])) - dim_idx += 1 + if activation_fn == "geglu": + geglu = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + geglu = ApproximateGELU(dim, inner_dim) - assert dim_idx == 2, f"Too few dimension changes. {error_string}" + self.net = nn.ModuleList([]) + # project in + self.net.append(geglu) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out)) def forward(self, hidden_states): for module in self.net: @@ -645,12 +623,12 @@ class ApproximateGELU(nn.Module): For more details, see section 2: https://arxiv.org/abs/1606.08415 """ - def __init__(self): + def __init__(self, dim_in: int, dim_out: int): super().__init__() - - # self.linear = nn.Linear(jk + self.proj = nn.Linear(dim_in, dim_out) def forward(self, x): + x = self.proj(x) return x * torch.sigmoid(1.702 * x) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 220e8869f8bc..30de343d08ee 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -233,14 +233,16 @@ class VectorQuantizer(nn.Module): # NOTE: due to a bug the beta term was applied to the wrong term. for # backwards compatibility we use the buggy version by default, but you can # specify legacy=False to fix it. - def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): + def __init__( + self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True + ): super().__init__() self.n_e = n_e - self.e_dim = e_dim + self.vq_embed_dim = vq_embed_dim self.beta = beta self.legacy = legacy - self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) self.remap = remap @@ -287,7 +289,7 @@ def unmap_to_all(self, inds): def forward(self, z): # reshape z -> (batch, height, width, channel) and flatten z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.e_dim) + z_flattened = z.view(-1, self.vq_embed_dim) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z d = ( @@ -409,6 +411,7 @@ class VQModel(ModelMixin, ConfigMixin): latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. sample_size (`int`, *optional*, defaults to `32`): TODO num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. 
+ vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. """ @register_to_config @@ -425,7 +428,7 @@ def __init__( sample_size: int = 32, num_vq_embeddings: int = 256, norm_num_groups: int = 32, - e_dim: Optional[int] = None, + vq_embed_dim: Optional[int] = None, ): super().__init__() @@ -441,11 +444,11 @@ def __init__( double_z=False, ) - e_dim = e_dim if e_dim is not None else latent_channels + vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels - self.quant_conv = torch.nn.Conv2d(latent_channels, e_dim, 1) - self.quantize = VectorQuantizer(num_vq_embeddings, e_dim, beta=0.25, remap=None, sane_index_shape=False) - self.post_quant_conv = torch.nn.Conv2d(e_dim, latent_channels, 1) + self.quant_conv = torch.nn.Conv2d(latent_channels, vq_embed_dim, 1) + self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) + self.post_quant_conv = torch.nn.Conv2d(vq_embed_dim, latent_channels, 1) # pass init params to Decoder self.decoder = Decoder( diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py index 2c11f3de3698..a1399ec885e1 100644 --- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -200,7 +200,7 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, x_t) - embedding_channels = self.vqvae.quantize.e_dim + embedding_channels = self.vqvae.config.vq_embed_dim embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels) embeddings = self.vqvae.quantize.get_codebook_entry(x_t, shape=embeddings_shape) image = self.vqvae.decode(embeddings, force_not_quantize=True).sample diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py index 078c2baf6283..132f34488abd 100644 --- a/src/diffusers/schedulers/scheduling_vq_diffusion.py +++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -55,15 +55,15 @@ def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator return noised -def alpha_schedules(num_diffusion_timesteps: int, a_cumulative_start=0.99999, a_cumulative_end=0.000009): +def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009): """ Cumulative and non-cumulative alpha schedules. See section 4.1. """ att = ( - np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (a_cumulative_end - a_cumulative_start) - + a_cumulative_start + np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start) + + alpha_cum_start ) att = np.concatenate(([1], att)) at = att[1:] / att[:-1] @@ -71,15 +71,15 @@ def alpha_schedules(num_diffusion_timesteps: int, a_cumulative_start=0.99999, a_ return at, att -def gamma_schedules(num_diffusion_timesteps: int, c_cumulative_start=0.000009, c_cumulative_end=0.99999): +def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999): """ Cumulative and non-cumulative gamma schedules. See section 4.1. 
""" ctt = ( - np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (c_cumulative_end - c_cumulative_start) - + c_cumulative_start + np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start) + + gamma_cum_start ) ctt = np.concatenate(([0], ctt)) one_minus_ctt = 1 - ctt @@ -104,47 +104,43 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): For more details, see the original paper: https://arxiv.org/abs/2111.14822 Args: - num_embed (`int`): + num_vec_classes (`int`): The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel. num_train_timesteps (`int`): Number of diffusion steps used to train the model. - a_cumulative_start (`float`): + alpha_cum_start (`float`): The starting cumulative alpha value. - a_cumulative_end (`float`): + alpha_cum_end (`float`): The ending cumulative alpha value. - c_cumulative_start (`float`): + gamma_cum_start (`float`): The starting cumulative gamma value. - c_cumulative_end (`float`): + gamma_cum_end (`float`): The ending cumulative gamma value. """ @register_to_config def __init__( self, - num_embed: int, + num_vec_classes: int, num_train_timesteps: int = 100, - a_cumulative_start: float = 0.99999, - a_cumulative_end: float = 0.000009, - c_cumulative_start: float = 0.000009, - c_cumulative_end: float = 0.99999, + alpha_cum_start: float = 0.99999, + alpha_cum_end: float = 0.000009, + gamma_cum_start: float = 0.000009, + gamma_cum_end: float = 0.99999, ): - self.num_embed = num_embed + self.num_embed = num_vec_classes # By convention, the index for the mask class is the last class index self.mask_class = self.num_embed - 1 - at, att = alpha_schedules( - num_train_timesteps, a_cumulative_start=a_cumulative_start, a_cumulative_end=a_cumulative_end - ) - ct, ctt = gamma_schedules( - num_train_timesteps, c_cumulative_start=c_cumulative_start, c_cumulative_end=c_cumulative_end - ) + at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end) + ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end) num_non_mask_classes = self.num_embed - 1 bt = (1 - at - ct) / num_non_mask_classes diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py index ed762dd35acb..79d2b21aabad 100644 --- a/tests/pipelines/vq_diffusion/test_vq_diffusion.py +++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py @@ -155,7 +155,8 @@ def test_vq_diffusion(self): ) expected_image = np.array(expected_image, dtype=np.float32) / 255.0 - pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq") + # pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq") + pipeline = VQDiffusionPipeline.from_pretrained("/home/patrick_huggingface_co/vq-diffusion-ithq") pipeline = pipeline.to(torch_device) pipeline.set_progress_bar_config(disable=None) From 558195d5bce010275a697d4e3f047a35592991a9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 11:20:18 +0000 Subject: [PATCH 15/21] finalize --- src/diffusers/models/attention.py | 43 +++++++++++++------ src/diffusers/models/unet_2d_blocks.py | 28 +++++++----- .../vq_diffusion/pipeline_vq_diffusion.py | 15 ++++--- .../schedulers/scheduling_vq_diffusion.py | 12 +++--- .../vq_diffusion/test_vq_diffusion.py | 3 +- 5 files changed, 63 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/attention.py 
b/src/diffusers/models/attention.py index 6dbea4c538dd..104e113e246f 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -12,16 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +from dataclasses import dataclass from typing import Optional import torch import torch.nn.functional as F from torch import nn -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.modeling_utils import ModelMixin -from diffusers.models.embeddings import ImagePositionalEmbeddings -from diffusers.utils.import_utils import is_xformers_available +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..models.embeddings import ImagePositionalEmbeddings +from ..utils import BaseOutput +from ..utils.import_utils import is_xformers_available + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if Transformer2DModel is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: torch.FloatTensor if is_xformers_available(): @@ -159,22 +173,24 @@ def _set_attention_slice(self, slice_size): for block in self.transformer_blocks: block._set_attention_slice(slice_size) - def forward(self, hidden_states, context=None, timestep=None): + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): """ Args: hidden_states (:obj: When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input hidden_states - context (:obj: `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): + encoder_hidden_states (:obj: `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep (:obj: `torch.long`, *optional*): Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + Returns: - [`torch.FloatTensor` of shape `(batch size, num embed - 1, num latent pixels)`] if discrete or - [`torch.FloatTensor` of shape `(batch size, channel, height, width)`] if continuous : - If discrete, returns probability distributions for the unnoised latent pixels. Note that it does not - output a prediction for the masked class. + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. """ # 1. Input if self.is_input_continuous: @@ -189,7 +205,7 @@ def forward(self, hidden_states, context=None, timestep=None): # 2. Blocks for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=context, timestep=timestep) + hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) # 3. 
Output if self.is_input_continuous: @@ -205,7 +221,10 @@ def forward(self, hidden_states, context=None, timestep=None): # log(p(x_0)) output = F.log_softmax(logits.double(), dim=1).float() - return output + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): for block in self.transformer_blocks: diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 495baa6980ff..4132ccbd0ce0 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -398,7 +398,7 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten def forward(self, hidden_states, temb=None, encoder_hidden_states=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states) + hidden_states = attn(hidden_states, encoder_hidden_states).sample hidden_states = resnet(hidden_states, temb) return hidden_states @@ -580,19 +580,22 @@ def forward(self, hidden_states, temb=None, encoder_hidden_states=None): for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: - def create_custom_forward(module): + def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): - return module(*inputs) + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn), hidden_states, encoder_hidden_states - ) + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample output_states += (hidden_states,) @@ -1169,19 +1172,22 @@ def forward( if self.training and self.gradient_checkpointing: - def create_custom_forward(module): + def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): - return module(*inputs) + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn), hidden_states, encoder_hidden_states - ) + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py index a1399ec885e1..104adb62a9eb 100644 --- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -182,27 +182,28 @@ def __call__( timesteps_tensor = 
self.scheduler.timesteps.to(self.device) - x_t = latents + sample = latents for i, t in enumerate(self.progress_bar(timesteps_tensor)): # predict the un-noised image - log_p_x_0 = self.transformer(hidden_states=x_t, context=text_embeddings, timestep=t) + # model_output == `log_p_x_0` + model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample - log_p_x_0 = self.truncate(log_p_x_0, truncation_rate) + model_output = self.truncate(model_output, truncation_rate) # remove `log(0)`'s (`-inf`s) - log_p_x_0 = log_p_x_0.clamp(-70) + model_output = model_output.clamp(-70) # compute the previous noisy sample x_t -> x_t-1 - x_t = self.scheduler.step(log_p_x_0, t, x_t, generator=generator).prev_sample + sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: - callback(i, t, x_t) + callback(i, t, sample) embedding_channels = self.vqvae.config.vq_embed_dim embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels) - embeddings = self.vqvae.quantize.get_codebook_entry(x_t, shape=embeddings_shape) + embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape) image = self.vqvae.decode(embeddings, force_not_quantize=True).sample image = (image / 2 + 0.5).clamp(0, 1) diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py index 132f34488abd..c4f039b7bde1 100644 --- a/src/diffusers/schedulers/scheduling_vq_diffusion.py +++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -195,9 +195,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic def step( self, - log_p_x_0: torch.FloatTensor, - t: torch.long, - x_t: torch.LongTensor, + model_output: torch.FloatTensor, + timestep: torch.long, + sample: torch.LongTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[VQDiffusionSchedulerOutput, Tuple]: @@ -227,10 +227,10 @@ def step( [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
""" - if t == 0: - log_p_x_t_min_1 = log_p_x_0 + if timestep == 0: + log_p_x_t_min_1 = model_output else: - log_p_x_t_min_1 = self.q_posterior(log_p_x_0, x_t, t) + log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep) log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator) diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py index 79d2b21aabad..ed762dd35acb 100644 --- a/tests/pipelines/vq_diffusion/test_vq_diffusion.py +++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py @@ -155,8 +155,7 @@ def test_vq_diffusion(self): ) expected_image = np.array(expected_image, dtype=np.float32) / 255.0 - # pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq") - pipeline = VQDiffusionPipeline.from_pretrained("/home/patrick_huggingface_co/vq-diffusion-ithq") + pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq") pipeline = pipeline.to(torch_device) pipeline.set_progress_bar_config(disable=None) From c72b8e9c9129155108bb29de9292d1e6ccc5490f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 12:53:06 +0000 Subject: [PATCH 16/21] fix tests --- tests/test_layers_utils.py | 130 ++++++++++++++----------------------- tests/test_scheduler.py | 46 ++++++------- 2 files changed, 70 insertions(+), 106 deletions(-) diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index e7a8fbd55aea..911ec548b3bd 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -20,7 +20,6 @@ import torch from torch import nn -import pytest from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock, Transformer2DModel from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.resnet import Downsample2D, Upsample2D @@ -237,7 +236,7 @@ def test_attention_block_default(self): num_head_channels=1, rescale_output_factor=1.0, eps=1e-6, - num_groups=32, + norm_num_groups=32, ).to(torch_device) with torch.no_grad(): attention_scores = attentionBlock(sample) @@ -261,7 +260,7 @@ def test_attention_block_sd(self): channels=512, rescale_output_factor=1.0, eps=1e-6, - num_groups=32, + norm_num_groups=32, ).to(torch_device) with torch.no_grad(): attention_scores = attentionBlock(sample) @@ -284,13 +283,13 @@ def test_spatial_transformer_default(self): sample = torch.randn(1, 32, 64, 64).to(torch_device) spatial_transformer_block = Transformer2DModel( in_channels=32, - n_heads=1, - d_head=32, + num_attention_heads=1, + attention_head_dim=32, dropout=0.0, - context_dim=None, + cross_attention_dim=None, ).to(torch_device) with torch.no_grad(): - attention_scores = spatial_transformer_block(sample) + attention_scores = spatial_transformer_block(sample).sample assert attention_scores.shape == (1, 32, 64, 64) output_slice = attention_scores[0, -1, -3:, -3:] @@ -300,7 +299,7 @@ def test_spatial_transformer_default(self): ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) - def test_spatial_transformer_context_dim(self): + def test_spatial_transformer_cross_attention_dim(self): torch.manual_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) @@ -308,14 +307,14 @@ def test_spatial_transformer_context_dim(self): sample = torch.randn(1, 64, 64, 64).to(torch_device) spatial_transformer_block = Transformer2DModel( in_channels=64, - n_heads=2, - d_head=32, + num_attention_heads=2, + attention_head_dim=32, dropout=0.0, - context_dim=64, + cross_attention_dim=64, ).to(torch_device) with 
torch.no_grad(): context = torch.randn(1, 4, 64).to(torch_device) - attention_scores = spatial_transformer_block(sample, context) + attention_scores = spatial_transformer_block(sample, context).sample assert attention_scores.shape == (1, 64, 64, 64) output_slice = attention_scores[0, -1, -3:, -3:] @@ -335,18 +334,17 @@ def test_spatial_transformer_timestep(self): sample = torch.randn(1, 64, 64, 64).to(torch_device) spatial_transformer_block = Transformer2DModel( in_channels=64, - n_heads=2, - d_head=32, + num_attention_heads=2, + attention_head_dim=32, dropout=0.0, - context_dim=64, - norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], + cross_attention_dim=64, num_embeds_ada_norm=num_embeds_ada_norm, ).to(torch_device) with torch.no_grad(): timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device) timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device) - attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1) - attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2) + attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1).sample + attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2).sample assert attention_scores_1.shape == (1, 64, 64, 64) assert attention_scores_2.shape == (1, 64, 64, 64) @@ -373,16 +371,16 @@ def test_spatial_transformer_dropout(self): spatial_transformer_block = ( Transformer2DModel( in_channels=32, - n_heads=2, - d_head=16, + num_attention_heads=2, + attention_head_dim=16, dropout=0.3, - context_dim=None, + cross_attention_dim=None, ) .to(torch_device) .eval() ) with torch.no_grad(): - attention_scores = spatial_transformer_block(sample) + attention_scores = spatial_transformer_block(sample).sample assert attention_scores.shape == (1, 32, 64, 64) output_slice = attention_scores[0, -1, -3:, -3:] @@ -403,31 +401,27 @@ def test_spatial_transformer_discrete(self): sample = torch.randint(0, num_embed, (1, 32)).to(torch_device) spatial_transformer_block = ( Transformer2DModel( - n_heads=1, - d_head=32, - discrete=True, - num_embed=num_embed, - height=16, - width=2, + num_attention_heads=1, + attention_head_dim=32, + num_vector_embeds=num_embed, + sample_size=16, ) .to(torch_device) .eval() ) with torch.no_grad(): - attention_scores = spatial_transformer_block(sample) + attention_scores = spatial_transformer_block(sample).sample assert attention_scores.shape == (1, num_embed - 1, 32) - output_slice = attention_scores[0, -3:, -3:] + output_slice = attention_scores[0, -2:, -3:] - expected_slice = torch.tensor( - [-1.4105, -1.0337, -1.4915, -1.8912, -1.1228, -1.3155, -1.9766, -1.9487, -1.1841], device=torch_device - ) + expected_slice = torch.tensor([-0.8957, -1.8370, -1.3390, -0.9152, -0.5187, -1.1702], device=torch_device) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) def test_spatial_transformer_default_norm_layers(self): - spatial_transformer_block = Transformer2DModel(n_heads=1, d_head=32, in_channels=32) + spatial_transformer_block = Transformer2DModel(num_attention_heads=1, attention_head_dim=32, in_channels=32) assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == nn.LayerNorm assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == nn.LayerNorm @@ -435,10 +429,9 @@ def test_spatial_transformer_default_norm_layers(self): def test_spatial_transformer_ada_norm_layers(self): spatial_transformer_block = Transformer2DModel( - n_heads=1, - d_head=32, + num_attention_heads=1, + attention_head_dim=32, 
in_channels=32, - norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], num_embeds_ada_norm=5, ) @@ -446,21 +439,10 @@ def test_spatial_transformer_ada_norm_layers(self): assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == AdaLayerNorm assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm - def test_spatial_transformer_ada_norm_layers_requires_num_embeds_ada_norm(self): - with pytest.raises(Exception) as e_info: - Transformer2DModel( - n_heads=1, - d_head=32, - in_channels=32, - norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], - ) - - assert e_info.value.args[0] == "When using AdaLayerNorm, you must also pass num_embeds_ada_norm." - def test_spatial_transformer_default_ff_layers(self): spatial_transformer_block = Transformer2DModel( - n_heads=1, - d_head=32, + num_attention_heads=1, + attention_head_dim=32, in_channels=32, ) @@ -480,51 +462,33 @@ def test_spatial_transformer_default_ff_layers(self): assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim - def test_spatial_transformer_vq_diffusion_ff_layers(self): + def test_spatial_transformer_geglu_approx_ff_layers(self): spatial_transformer_block = Transformer2DModel( - n_heads=1, - d_head=32, + num_attention_heads=1, + attention_head_dim=32, in_channels=32, - ff_layers=["Linear", "ApproximateGELU", "Linear", "Dropout"], + activation_fn="geglu-approximate", ) + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU + assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear + dim = 32 inner_dim = 128 - assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == nn.Linear - assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == ApproximateGELU - assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear - assert spatial_transformer_block.transformer_blocks[0].ff.net[3].__class__ == nn.Dropout - - assert spatial_transformer_block.transformer_blocks[0].ff.net[0].in_features == dim - assert spatial_transformer_block.transformer_blocks[0].ff.net[0].out_features == inner_dim + # First dimension change + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim + # Second dimension change assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim - def test_spatial_transformer_ff_layers_too_few_dim_changes(self): - with pytest.raises(Exception) as e_info: - Transformer2DModel(n_heads=1, d_head=32, in_channels=32, ff_layers=["Linear"]) - - assert ( - e_info.value.args[0] - == "Too few dimension changes. FeedForward must have exactly two dimension changing layers (Linear and" - " GEGLU)." - ) - - def test_spatial_transformer_ff_layers_too_many_dim_changes(self): - for layer in ["Linear", "GEGLU"]: - with pytest.raises(Exception) as e_info: - Transformer2DModel(n_heads=1, d_head=32, in_channels=32, ff_layers=[layer] * 3) - - assert ( - e_info.value.args[0] - == "Too many dimension changes. FeedForward must have exactly two dimension changing layers (Linear" - " and GEGLU)." 
- ) - def test_spatial_transformer_attention_bias(self): - spatial_transformer_block = Transformer2DModel(n_heads=1, d_head=32, in_channels=32, attention_bias=True) + spatial_transformer_block = Transformer2DModel( + num_attention_heads=1, attention_head_dim=32, in_channels=32, attention_bias=True + ) assert spatial_transformer_block.transformer_blocks[0].attn1.to_q.bias is not None assert spatial_transformer_block.transformer_blocks[0].attn1.to_k.bias is not None diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 038278ee8748..29186aaac99b 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -91,9 +91,9 @@ def check_over_configs(self, time_step=0, **config): scheduler = scheduler_class(**scheduler_config) if scheduler_class == VQDiffusionScheduler: - num_embed = scheduler_config["num_embed"] - sample = self.dummy_sample(num_embed) - model = self.dummy_model(num_embed) + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) residual = model(sample, time_step) else: sample = self.dummy_sample @@ -134,9 +134,9 @@ def check_over_forward(self, time_step=0, **forward_kwargs): scheduler = scheduler_class(**scheduler_config) if scheduler_class == VQDiffusionScheduler: - num_embed = scheduler_config["num_embed"] - sample = self.dummy_sample(num_embed) - model = self.dummy_model(num_embed) + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) residual = model(sample, time_step) else: sample = self.dummy_sample @@ -176,9 +176,9 @@ def test_from_pretrained_save_pretrained(self): scheduler = scheduler_class(**scheduler_config) if scheduler_class == VQDiffusionScheduler: - num_embed = scheduler_config["num_embed"] - sample = self.dummy_sample(num_embed) - model = self.dummy_model(num_embed) + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) residual = model(sample, timestep) else: sample = self.dummy_sample @@ -221,9 +221,9 @@ def test_step_shape(self): scheduler = scheduler_class(**scheduler_config) if scheduler_class == VQDiffusionScheduler: - num_embed = scheduler_config["num_embed"] - sample = self.dummy_sample(num_embed) - model = self.dummy_model(num_embed) + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) residual = model(sample, timestep_0) else: sample = self.dummy_sample @@ -282,9 +282,9 @@ def recursive_check(tuple_object, dict_object): scheduler = scheduler_class(**scheduler_config) if scheduler_class == VQDiffusionScheduler: - num_embed = scheduler_config["num_embed"] - sample = self.dummy_sample(num_embed) - model = self.dummy_model(num_embed) + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) residual = model(sample, timestep) else: sample = self.dummy_sample @@ -1281,19 +1281,19 @@ class VQDiffusionSchedulerTest(SchedulerCommonTest): def get_scheduler_config(self, **kwargs): config = { - "num_embed": 4097, + "num_vec_classes": 4097, "num_train_timesteps": 100, } config.update(**kwargs) return config - def dummy_sample(self, num_embed): + def dummy_sample(self, num_vec_classes): batch_size = 4 height = 8 width = 8 - sample = torch.randint(0, num_embed, (batch_size, height * width)) + sample = 
torch.randint(0, num_vec_classes, (batch_size, height * width)) return sample @@ -1301,10 +1301,10 @@ def dummy_sample(self, num_embed): def dummy_sample_deter(self): assert False - def dummy_model(self, num_embed): + def dummy_model(self, num_vec_classes): def model(sample, t, *args): batch_size, num_latent_pixels = sample.shape - logits = torch.rand((batch_size, num_embed - 1, num_latent_pixels)) + logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels)) return_value = F.log_softmax(logits.double(), dim=1).float() return return_value @@ -1314,9 +1314,9 @@ def test_timesteps(self): for timesteps in [2, 5, 100, 1000]: self.check_over_configs(num_train_timesteps=timesteps) - def test_num_embed(self): - for num_embed in [5, 100, 1000, 4000]: - self.check_over_configs(num_embed=num_embed) + def test_num_vec_classes(self): + for num_vec_classes in [5, 100, 1000, 4000]: + self.check_over_configs(num_vec_classes=num_vec_classes) def test_time_indices(self): for t in [0, 50, 99]: From cc303e48db0506a0f110c7151f45e58fa49502da Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 13:54:36 +0100 Subject: [PATCH 17/21] Apply suggestions from code review Co-authored-by: Anton Lozhkov --- .../vq_diffusion/pipeline_vq_diffusion.py | 14 ++++++++++++++ .../schedulers/scheduling_vq_diffusion.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py index 104adb62a9eb..9cf8ae455db8 100644 --- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -1,3 +1,17 @@ +# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Callable, List, Optional, Tuple, Union import torch diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py index c4f039b7bde1..dbe320d998a3 100644 --- a/src/diffusers/schedulers/scheduling_vq_diffusion.py +++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -1,3 +1,17 @@ +# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from dataclasses import dataclass from typing import Optional, Tuple, Union From 303a365b8fb7c736eb9a7eb48709974cec452d7a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 15:26:07 +0100 Subject: [PATCH 18/21] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py index 9cf8ae455db8..7655f5926f38 100644 --- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -115,7 +115,7 @@ def __call__( The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be @@ -125,7 +125,7 @@ def __call__( called at every step. Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ From 963c71afcb382d28d306ec5b6d8a9450f2dc8b2b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 15:46:09 +0100 Subject: [PATCH 19/21] finish --- docs/source/api/models.mdx | 6 ++ docs/source/api/pipelines/vq_diffusion.mdx | 1 - docs/source/index.mdx | 2 + src/diffusers/models/attention.py | 93 +++++++++---------- .../vq_diffusion/pipeline_vq_diffusion.py | 5 +- 5 files changed, 53 insertions(+), 54 deletions(-) diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index c3f5e65edfbd..2e1e8798a7ef 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -49,6 +49,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## AutoencoderKL [[autodoc]] AutoencoderKL +## Transformer2DModel +[[autodoc]] Transformer2DModel + +## Transformer2DModelOutput +[[autodoc]] models.attention.Transformer2DModelOutput + ## FlaxModelMixin [[autodoc]] FlaxModelMixin diff --git a/docs/source/api/pipelines/vq_diffusion.mdx b/docs/source/api/pipelines/vq_diffusion.mdx index c2965c47d5fe..92cc903eee79 100644 --- a/docs/source/api/pipelines/vq_diffusion.mdx +++ b/docs/source/api/pipelines/vq_diffusion.mdx @@ -32,4 +32,3 @@ The original codebase can be found [here](https://github.com/microsoft/VQ-Diffus ## VQDiffusionPipeline [[autodoc]] pipelines.vq_diffusion.pipeline_vq_diffusion.VQDiffusionPipeline - __call__ - diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 392b22399908..62a3e88f173d 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -34,6 +34,7 @@ available a colab notebook to directly try them out. 
| Pipeline | Paper | Tasks | Colab |---|---|:---:|:---:| +| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | | [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | @@ -45,5 +46,6 @@ available a colab notebook to directly try them out. | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 104e113e246f..b8f569e46742 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -61,32 +61,26 @@ class Transformer2DModel(ModelMixin, ConfigMixin): image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. Parameters: - num_attention_heads (:obj:`int`): The number of heads to use for multi-head attention. - attention_head_dim (:obj:`int`): The number of channels in each head. - in_channels (: - obj:`int`, *optional*): Pass if the input is continuous. The number of channels in the input and output. - num_layers (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. - cross_attention_dim (:obj:`int`, *optional*): The number of context dimensions to use. - discrete (: - obj:`bool`, *optional*, defaults to False): Set to True if the input is discrete i.e. over classes of - vector embeddings for each pixel. See the beginning of the docstring for a more in-num_layers description. - height (:obj:`int`, *optional*): Pass if the input is discrete. The height of the latent images. + num_attention_heads (`int`, *optional*): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*): The number of channels in each head. 
+ in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of context dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. - width (:obj:`int`, *optional*): Pass if the input is discrete. The width of the latent images. - Note that this is fixed at training time as it is used for learning a number of position embeddings. See - `ImagePositionalEmbeddings`. - num_vector_embeds (: - obj:`int`, *optional*): Pass if the input is discrete. The number of classes of the vector embeddings of - the latent pixels. Includes the class for the masked latent pixel. - activation_fn (:obj:`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (:obj: `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during training. Note that this is fixed at training time as it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more than steps than `num_embeds_ada_norm`. - attention_bias (: - obj: `bool`, *optional*): Configure if the TransformerBlocks' attention should contain a bias parameter. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. """ @register_to_config @@ -176,13 +170,13 @@ def _set_attention_slice(self, slice_size): def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): """ Args: - hidden_states (:obj: When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input hidden_states - encoder_hidden_states (:obj: `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. - timestep (:obj: `torch.long`, *optional*): + timestep ( `torch.long`, *optional*): Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. @@ -239,12 +233,12 @@ class AttentionBlock(nn.Module): Uses three q, k, v linear layers to compute attention. Parameters: - channels (:obj:`int`): The number of channels in the input and output. 
- num_head_channels (:obj:`int`, *optional*): + channels (`int`): The number of channels in the input and output. + num_head_channels (`int`, *optional*): The number of channels in each head. If None, then `num_heads` = 1. - norm_num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. """ def __init__( @@ -321,17 +315,16 @@ class BasicTransformerBlock(nn.Module): A basic Transformer block. Parameters: - dim (:obj:`int`): The number of channels in the input and output. - num_attention_heads (:obj:`int`): The number of heads to use for multi-head attention. - attention_head_dim (:obj:`int`): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. - activation_fn (:obj:`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (: - obj: `bool`, *optional*, defaults to :obj:`False`): Configure if the attentions should contain a bias - parameter. + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. """ def __init__( @@ -427,13 +420,13 @@ class CrossAttention(nn.Module): A cross attention layer. Parameters: - query_dim (:obj:`int`): The number of channels in the query. - cross_attention_dim (:obj:`int`, *optional*): + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): The number of channels in the context. If not given, defaults to `query_dim`. - heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - bias (:obj:`bool`, *optional*, defaults to False): + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): Set to `True` for the query, key, and value linear layers to contain a bias parameter. """ @@ -572,11 +565,11 @@ class FeedForward(nn.Module): A feed-forward layer. 
Parameters: - dim (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (:obj:`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. """ def __init__( @@ -616,8 +609,8 @@ class GEGLU(nn.Module): A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. Parameters: - dim_in (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`): The number of channels in the output. + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. """ def __init__(self, dim_in: int, dim_out: int): diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py index 7655f5926f38..6e5325ba7ef5 100644 --- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -45,7 +45,7 @@ class VQDiffusionPipeline(DiffusionPipeline): tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - transformer (`Transformer2DModel`): + transformer ([`Transformer2DModel`]): Conditional transformer to denoise the encoded image latents. scheduler ([`VQDiffusionScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. @@ -115,8 +115,7 @@ def __call__( The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a - plain tuple. + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 
From 213af1f8dbf819c45160f2c8d4db289d9cdd1762 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 15:50:04 +0100 Subject: [PATCH 20/21] finish --- src/diffusers/models/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b8f569e46742..bac85e2f39cf 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -30,7 +30,7 @@ class Transformer2DModelOutput(BaseOutput): """ Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if Transformer2DModel is discrete): + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions for the unnoised latent pixels. """ @@ -61,8 +61,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin): image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. Parameters: - num_attention_heads (`int`, *optional*): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): Pass if the input is continuous. The number of channels in the input and output. num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. From 8d2d88f5efa5381a871dbae42ebba9b0632cd9cd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 3 Nov 2022 16:05:57 +0100 Subject: [PATCH 21/21] up --- .../vq_diffusion/test_vq_diffusion.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py index ed762dd35acb..5eb32d40d4f3 100644 --- a/tests/pipelines/vq_diffusion/test_vq_diffusion.py +++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py @@ -56,6 +56,7 @@ def dummy_vqvae(self): up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], latent_channels=3, num_vq_embeddings=self.num_embed, + vq_embed_dim=3, ) return model @@ -87,19 +88,19 @@ def dummy_transformer(self): height = 12 width = 12 - model = Transformer2DModel( - n_heads=1, - d_head=height * width, - context_dim=32, - discrete=True, - num_embed=self.num_embed, - height=height, - width=width, - num_embeds_ada_norm=self.num_embeds_ada_norm, - ff_layers=["Linear", "ApproximateGELU", "Linear", "Dropout"], - norm_layers=["AdaLayerNorm", "AdaLayerNorm", "LayerNorm"], - attention_bias=True, - ) + model_kwargs = { + "attention_bias": True, + "cross_attention_dim": 32, + "attention_head_dim": height * width, + "num_attention_heads": 1, + "num_vector_embeds": self.num_embed, + "num_embeds_ada_norm": self.num_embeds_ada_norm, + "norm_num_groups": 32, + "sample_size": width, + "activation_fn": "geglu-approximate", + } + + model = Transformer2DModel(**model_kwargs) return model def test_vq_diffusion(self):
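
A minimal end-to-end sketch of the public API these patches converge on: a `VQModel` whose codebook dimension is set through the new `vq_embed_dim` argument, a discrete `Transformer2DModel` configured with `num_vector_embeds`/`sample_size`, and a `VQDiffusionScheduler` built with `num_vec_classes`. This is illustrative only -- the toy sizes, the random stand-in text embeddings, and the bare denoising loop are assumptions rather than values taken from the patches; the real pipeline additionally truncates and clamps the log-probabilities before each scheduler step (see `pipeline_vq_diffusion.py` above).

import torch

from diffusers import VQModel
from diffusers.models.attention import Transformer2DModel
from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler

num_embed = 32      # toy codebook size; the scheduler adds one extra class for the mask token
height = width = 4  # toy latent grid -> 16 latent pixels

# VQ-VAE whose codebook dimension (vq_embed_dim) differs from its latent channels.
vqvae = VQModel(latent_channels=3, num_vq_embeddings=num_embed, vq_embed_dim=3)

# Discrete transformer: consumes classes of latent pixels (a LongTensor), so it is
# configured with num_vector_embeds/sample_size instead of in_channels.
transformer = Transformer2DModel(
    num_attention_heads=1,
    attention_head_dim=8,
    num_vector_embeds=num_embed + 1,  # + 1 for the masked class
    sample_size=width,
    num_embeds_ada_norm=10,           # AdaLayerNorm embeddings for the 10 train timesteps below
    cross_attention_dim=8,
    attention_bias=True,
    activation_fn="geglu-approximate",
)

scheduler = VQDiffusionScheduler(num_vec_classes=num_embed + 1, num_train_timesteps=10)
scheduler.set_timesteps(10)

generator = torch.Generator().manual_seed(0)
encoder_hidden_states = torch.randn(1, 2, 8)  # stand-in for CLIP text hidden states

# Start from fully masked latent pixels, as the pipeline does.
sample = torch.full((1, height * width), scheduler.mask_class, dtype=torch.long)

with torch.no_grad():
    for t in scheduler.timesteps:
        # .sample holds log p(x_0 | x_t): (batch, num_vec_classes - 1, num_latent_pixels)
        log_p_x_0 = transformer(sample, encoder_hidden_states=encoder_hidden_states, timestep=t).sample
        sample = scheduler.step(log_p_x_0, timestep=t, sample=sample, generator=generator).prev_sample

    # Decode the final latent pixel classes through the codebook, as in the pipeline.
    embeddings = vqvae.quantize.get_codebook_entry(sample, shape=(1, height, width, 3))
    image = vqvae.decode(embeddings, force_not_quantize=True).sample  # decoded image tensor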