diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 331d4fff7870..70d64b80de7a 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -96,6 +96,8 @@ title: "Stochastic Karras VE" - local: api/pipelines/dance_diffusion title: "Dance Diffusion" + - local: api/pipelines/vq_diffusion + title: "VQ Diffusion" - local: api/pipelines/repaint title: "RePaint" title: "Pipelines" diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index c3f5e65edfbd..2e1e8798a7ef 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -49,6 +49,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## AutoencoderKL [[autodoc]] AutoencoderKL +## Transformer2DModel +[[autodoc]] Transformer2DModel + +## Transformer2DModelOutput +[[autodoc]] models.attention.Transformer2DModelOutput + ## FlaxModelMixin [[autodoc]] FlaxModelMixin diff --git a/docs/source/api/pipelines/overview.mdx b/docs/source/api/pipelines/overview.mdx index a53a2f8b4137..5a15473cf19c 100644 --- a/docs/source/api/pipelines/overview.mdx +++ b/docs/source/api/pipelines/overview.mdx @@ -41,22 +41,22 @@ If you are looking for *official* training examples, please have a look at [exam The following table summarizes all officially supported pipelines, their corresponding paper, and if available a colab notebook to directly try them out. -| Pipeline | Paper | Tasks | Colab -|------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------:|:---:| -| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | -| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Text-to-Image Generation | -| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | -| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | -| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | -| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | -| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) -| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) -| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| Pipeline | Paper | Tasks | Colab +|---|---|:---:|:---:| +| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | +| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | +| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | +| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | +| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | +| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | +| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) +| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) +| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [vq_diffusion](./vq_diffusion) | [**Vector Quantized Diffusion Model for Text-to-Image Synthesis**](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | | [repaint](./repaint) | [**RePaint: Inpainting using Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2201.09865) | Image Inpainting | - **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. However, most of them can be adapted to use different scheduler components or even different model components. Some pipeline examples are shown in the [Examples](#examples) below. diff --git a/docs/source/api/pipelines/vq_diffusion.mdx b/docs/source/api/pipelines/vq_diffusion.mdx new file mode 100644 index 000000000000..92cc903eee79 --- /dev/null +++ b/docs/source/api/pipelines/vq_diffusion.mdx @@ -0,0 +1,34 @@ + + +# VQDiffusion + +## Overview + +[Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo + +The abstract of the paper is the following: + +We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality. + +The original codebase can be found [here](https://github.com/microsoft/VQ-Diffusion). + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|---|---|:---:| +| [pipeline_vq_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py) | *Text-to-Image Generation* | - | + + +## VQDiffusionPipeline +[[autodoc]] pipelines.vq_diffusion.pipeline_vq_diffusion.VQDiffusionPipeline + - __call__ diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx index 6e7da10e33e5..f073f6b37912 100644 --- a/docs/source/api/schedulers.mdx +++ b/docs/source/api/schedulers.mdx @@ -113,7 +113,6 @@ Score SDE-VP is under construction. [[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler - #### Euler scheduler Euler scheduler (Algorithm 2) from the paper [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) by Karras et al. (2022). Based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by Katherine Crowson. @@ -130,6 +129,12 @@ Fast scheduler which often times generates good outputs with 20-30 steps. [[autodoc]] EulerAncestralDiscreteScheduler +#### VQDiffusionScheduler + +Original paper can be found [here](https://arxiv.org/abs/2111.14822) + +[[autodoc]] VQDiffusionScheduler + #### RePaint scheduler DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks. @@ -137,4 +142,4 @@ Intended for use with [`RePaintPipeline`]. Based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865) and the original implementation by Andreas Lugmayr et al.: https://github.com/andreas128/RePaint -[[autodoc]] RePaintScheduler \ No newline at end of file +[[autodoc]] RePaintScheduler diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 392b22399908..62a3e88f173d 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -34,6 +34,7 @@ available a colab notebook to directly try them out. | Pipeline | Paper | Tasks | Colab |---|---|:---:|:---:| +| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | | [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | @@ -45,5 +46,6 @@ available a colab notebook to directly try them out. | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py new file mode 100644 index 000000000000..ae105e30362e --- /dev/null +++ b/scripts/convert_vq_diffusion_to_diffusers.py @@ -0,0 +1,885 @@ +""" +This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers. + +It currently only supports porting the ITHQ dataset. + +ITHQ dataset: +```sh +# From the root directory of diffusers. + +# Download the VQVAE checkpoint +$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D -O ithq_vqvae.pth + +# Download the VQVAE config +# NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class +# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE` +# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml` +$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml + +# Download the main model checkpoint +$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D -O ithq_learnable.pth + +# Download the main model config +$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml + +# run the convert script +$ python ./scripts/convert_vq_diffusion_to_diffusers.py \ + --checkpoint_path ./ithq_learnable.pth \ + --original_config_file ./ithq.yaml \ + --vqvae_checkpoint_path ./ithq_vqvae.pth \ + --vqvae_original_config_file ./ithq_vqvae.yaml \ + --dump_path +``` +""" + +import argparse +import tempfile + +import torch + +import yaml +from accelerate import init_empty_weights, load_checkpoint_and_dispatch +from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel +from diffusers.models.attention import Transformer2DModel +from transformers import CLIPTextModel, CLIPTokenizer +from yaml.loader import FullLoader + + +try: + from omegaconf import OmegaConf +except ImportError: + raise ImportError( + "OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install" + " OmegaConf`." + ) + +# vqvae model + +PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"] + + +def vqvae_model_from_original_config(original_config): + assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers." + + original_config = original_config.params + + original_encoder_config = original_config.encoder_config.params + original_decoder_config = original_config.decoder_config.params + + in_channels = original_encoder_config.in_channels + out_channels = original_decoder_config.out_ch + + down_block_types = get_down_block_types(original_encoder_config) + up_block_types = get_up_block_types(original_decoder_config) + + assert original_encoder_config.ch == original_decoder_config.ch + assert original_encoder_config.ch_mult == original_decoder_config.ch_mult + block_out_channels = tuple( + [original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult] + ) + + assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks + layers_per_block = original_encoder_config.num_res_blocks + + assert original_encoder_config.z_channels == original_decoder_config.z_channels + latent_channels = original_encoder_config.z_channels + + num_vq_embeddings = original_config.n_embed + + # Hard coded value for ResnetBlock.GoupNorm(num_groups) in VQ-diffusion + norm_num_groups = 32 + + e_dim = original_config.embed_dim + + model = VQModel( + in_channels=in_channels, + out_channels=out_channels, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + latent_channels=latent_channels, + num_vq_embeddings=num_vq_embeddings, + norm_num_groups=norm_num_groups, + vq_embed_dim=e_dim, + ) + + return model + + +def get_down_block_types(original_encoder_config): + attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions) + num_resolutions = len(original_encoder_config.ch_mult) + resolution = coerce_resolution(original_encoder_config.resolution) + + curr_res = resolution + down_block_types = [] + + for _ in range(num_resolutions): + if curr_res in attn_resolutions: + down_block_type = "AttnDownEncoderBlock2D" + else: + down_block_type = "DownEncoderBlock2D" + + down_block_types.append(down_block_type) + + curr_res = [r // 2 for r in curr_res] + + return down_block_types + + +def get_up_block_types(original_decoder_config): + attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions) + num_resolutions = len(original_decoder_config.ch_mult) + resolution = coerce_resolution(original_decoder_config.resolution) + + curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution] + up_block_types = [] + + for _ in reversed(range(num_resolutions)): + if curr_res in attn_resolutions: + up_block_type = "AttnUpDecoderBlock2D" + else: + up_block_type = "UpDecoderBlock2D" + + up_block_types.append(up_block_type) + + curr_res = [r * 2 for r in curr_res] + + return up_block_types + + +def coerce_attn_resolutions(attn_resolutions): + attn_resolutions = OmegaConf.to_object(attn_resolutions) + attn_resolutions_ = [] + for ar in attn_resolutions: + if isinstance(ar, (list, tuple)): + attn_resolutions_.append(list(ar)) + else: + attn_resolutions_.append([ar, ar]) + return attn_resolutions_ + + +def coerce_resolution(resolution): + resolution = OmegaConf.to_object(resolution) + if isinstance(resolution, int): + resolution = [resolution, resolution] # H, W + elif isinstance(resolution, (tuple, list)): + resolution = list(resolution) + else: + raise ValueError("Unknown type of resolution:", resolution) + return resolution + + +# done vqvae model + +# vqvae checkpoint + + +def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint)) + + # quant_conv + + diffusers_checkpoint.update( + { + "quant_conv.weight": checkpoint["quant_conv.weight"], + "quant_conv.bias": checkpoint["quant_conv.bias"], + } + ) + + # quantize + diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]}) + + # post_quant_conv + diffusers_checkpoint.update( + { + "post_quant_conv.weight": checkpoint["post_quant_conv.weight"], + "post_quant_conv.bias": checkpoint["post_quant_conv.bias"], + } + ) + + # decoder + diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint)) + + return diffusers_checkpoint + + +def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv_in + diffusers_checkpoint.update( + { + "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"], + "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"], + } + ) + + # down_blocks + for down_block_idx, down_block in enumerate(model.encoder.down_blocks): + diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}" + down_block_prefix = f"encoder.down.{down_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(down_block.resnets): + diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # downsample + + # do not include the downsample when on the last down block + # There is no downsample on the last down block + if down_block_idx != len(model.encoder.down_blocks) - 1: + # There's a single downsample in the original checkpoint but a list of downsamples + # in the diffusers model. + diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv" + downsample_prefix = f"{down_block_prefix}.downsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(down_block, "attentions"): + for attention_idx, _ in enumerate(down_block.attentions): + diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{down_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder + diffusers_attention_prefix = "encoder.mid_block.attentions.0" + attention_prefix = "encoder.mid.attn_1" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion encoder + resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"], + "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"], + # conv_out + "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"], + "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv in + diffusers_checkpoint.update( + { + "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"], + "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"], + } + ) + + # up_blocks + + for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks): + # up_blocks are stored in reverse order in the VQ-diffusion checkpoint + orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx + + diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}" + up_block_prefix = f"decoder.up.{orig_up_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(up_block.resnets): + diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # upsample + + # there is no up sample on the last up block + if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1: + # There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples + # in the diffusers model. + diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv" + downsample_prefix = f"{up_block_prefix}.upsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(up_block, "attentions"): + for attention_idx, _ in enumerate(up_block.attentions): + diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{up_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder + diffusers_attention_prefix = "decoder.mid_block.attentions.0" + attention_prefix = "decoder.mid.attn_1" + diffusers_checkpoint.update( + vqvae_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion decoder + resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + vqvae_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"], + "decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"], + # conv_out + "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"], + "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"], + f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"], + f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], + } + ) + + return rv + + +def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # group_norm + f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], + f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + # query + f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], + # key + f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], + # value + f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], + # proj_attn + f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + :, :, 0, 0 + ], + f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + + +# done vqvae checkpoint + +# transformer model + +PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"] +PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"] +PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"] + + +def transformer_model_from_original_config( + original_diffusion_config, original_transformer_config, original_content_embedding_config +): + assert ( + original_diffusion_config.target in PORTED_DIFFUSIONS + ), f"{original_diffusion_config.target} has not yet been ported to diffusers." + assert ( + original_transformer_config.target in PORTED_TRANSFORMERS + ), f"{original_transformer_config.target} has not yet been ported to diffusers." + assert ( + original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS + ), f"{original_content_embedding_config.target} has not yet been ported to diffusers." + + original_diffusion_config = original_diffusion_config.params + original_transformer_config = original_transformer_config.params + original_content_embedding_config = original_content_embedding_config.params + + inner_dim = original_transformer_config["n_embd"] + + n_heads = original_transformer_config["n_head"] + + # VQ-Diffusion gives dimension of the multi-headed attention layers as the + # number of attention heads times the sequence length (the dimension) of a + # single head. We want to specify our attention blocks with those values + # specified separately + assert inner_dim % n_heads == 0 + d_head = inner_dim // n_heads + + depth = original_transformer_config["n_layer"] + context_dim = original_transformer_config["condition_dim"] + + num_embed = original_content_embedding_config["num_embed"] + # the number of embeddings in the transformer includes the mask embedding. + # the content embedding (the vqvae) does not include the mask embedding. + num_embed = num_embed + 1 + + height = original_transformer_config["content_spatial_size"][0] + width = original_transformer_config["content_spatial_size"][1] + + assert width == height, "width has to be equal to height" + dropout = original_transformer_config["resid_pdrop"] + num_embeds_ada_norm = original_diffusion_config["diffusion_step"] + + model_kwargs = { + "attention_bias": True, + "cross_attention_dim": context_dim, + "attention_head_dim": d_head, + "num_layers": depth, + "dropout": dropout, + "num_attention_heads": n_heads, + "num_vector_embeds": num_embed, + "num_embeds_ada_norm": num_embeds_ada_norm, + "norm_num_groups": 32, + "sample_size": width, + "activation_fn": "geglu-approximate", + } + + model = Transformer2DModel(**model_kwargs) + return model + + +# done transformer model + +# transformer checkpoint + + +def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + transformer_prefix = "transformer.transformer" + + diffusers_latent_image_embedding_prefix = "latent_image_embedding" + latent_image_embedding_prefix = f"{transformer_prefix}.content_emb" + + # DalleMaskImageEmbedding + diffusers_checkpoint.update( + { + f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[ + f"{latent_image_embedding_prefix}.emb.weight" + ], + f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[ + f"{latent_image_embedding_prefix}.height_emb.weight" + ], + f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[ + f"{latent_image_embedding_prefix}.width_emb.weight" + ], + } + ) + + # transformer blocks + for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks): + diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}" + transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}" + + # ada norm block + diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1" + ada_norm_prefix = f"{transformer_block_prefix}.ln1" + + diffusers_checkpoint.update( + transformer_ada_norm_to_diffusers_checkpoint( + checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix + ) + ) + + # attention block + diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1" + attention_prefix = f"{transformer_block_prefix}.attn1" + + diffusers_checkpoint.update( + transformer_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # ada norm block + diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2" + ada_norm_prefix = f"{transformer_block_prefix}.ln1_1" + + diffusers_checkpoint.update( + transformer_ada_norm_to_diffusers_checkpoint( + checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix + ) + ) + + # attention block + diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2" + attention_prefix = f"{transformer_block_prefix}.attn2" + + diffusers_checkpoint.update( + transformer_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # norm block + diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3" + norm_block_prefix = f"{transformer_block_prefix}.ln2" + + diffusers_checkpoint.update( + { + f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"], + f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"], + } + ) + + # feedforward block + diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff" + feedforward_prefix = f"{transformer_block_prefix}.mlp" + + diffusers_checkpoint.update( + transformer_feedforward_to_diffusers_checkpoint( + checkpoint, + diffusers_feedforward_prefix=diffusers_feedforward_prefix, + feedforward_prefix=feedforward_prefix, + ) + ) + + # to logits + + diffusers_norm_out_prefix = "norm_out" + norm_out_prefix = f"{transformer_prefix}.to_logits.0" + + diffusers_checkpoint.update( + { + f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"], + f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"], + } + ) + + diffusers_out_prefix = "out" + out_prefix = f"{transformer_prefix}.to_logits.1" + + diffusers_checkpoint.update( + { + f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"], + f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"], + } + ) + + return diffusers_checkpoint + + +def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix): + return { + f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"], + f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"], + f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"], + } + + +def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # key + f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"], + f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"], + # query + f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"], + f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"], + # value + f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"], + f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"], + # linear out + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"], + } + + +def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix): + return { + f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"], + f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"], + f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"], + f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"], + } + + +# done transformer checkpoint + + +def read_config_file(filename): + # The yaml file contains annotations that certain values should + # loaded as tuples. By default, OmegaConf will panic when reading + # these. Instead, we can manually read the yaml with the FullLoader and then + # construct the OmegaConf object. + with open(filename) as f: + original_config = yaml.load(f, FullLoader) + + return OmegaConf.create(original_config) + + +# We take separate arguments for the vqvae because the ITHQ vqvae config file +# is separate from the config file for the rest of the model. +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--vqvae_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the vqvae checkpoint to convert.", + ) + + parser.add_argument( + "--vqvae_original_config_file", + default=None, + type=str, + required=True, + help="The YAML config file corresponding to the original architecture for the vqvae.", + ) + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + + parser.add_argument( + "--original_config_file", + default=None, + type=str, + required=True, + help="The YAML config file corresponding to the original architecture.", + ) + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--checkpoint_load_device", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading checkpoints.", + ) + + # See link for how ema weights are always selected + # https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65 + parser.add_argument( + "--no_use_ema", + action="store_true", + required=False, + help=( + "Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to set" + " it as the original VQ-Diffusion always uses the ema weights when loading models." + ), + ) + + args = parser.parse_args() + + use_ema = not args.no_use_ema + + print(f"loading checkpoints to {args.checkpoint_load_device}") + + checkpoint_map_location = torch.device(args.checkpoint_load_device) + + # vqvae_model + + print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}") + + vqvae_original_config = read_config_file(args.vqvae_original_config_file).model + vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"] + + with init_empty_weights(): + vqvae_model = vqvae_model_from_original_config(vqvae_original_config) + + vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint) + + with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file: + torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name) + del vqvae_diffusers_checkpoint + del vqvae_checkpoint + load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto") + + print("done loading vqvae") + + # done vqvae_model + + # transformer_model + + print( + f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:" + f" {use_ema}" + ) + + original_config = read_config_file(args.original_config_file).model + + diffusion_config = original_config.params.diffusion_config + transformer_config = original_config.params.diffusion_config.params.transformer_config + content_embedding_config = original_config.params.diffusion_config.params.content_emb_config + + pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location) + + if use_ema: + if "ema" in pre_checkpoint: + checkpoint = {} + for k, v in pre_checkpoint["model"].items(): + checkpoint[k] = v + + for k, v in pre_checkpoint["ema"].items(): + # The ema weights are only used on the transformer. To mimic their key as if they came + # from the state_dict for the top level model, we prefix with an additional "transformer." + # See the source linked in the args.use_ema config for more information. + checkpoint[f"transformer.{k}"] = v + else: + print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.") + checkpoint = pre_checkpoint["model"] + else: + checkpoint = pre_checkpoint["model"] + + del pre_checkpoint + + with init_empty_weights(): + transformer_model = transformer_model_from_original_config( + diffusion_config, transformer_config, content_embedding_config + ) + + diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint( + transformer_model, checkpoint + ) + + with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file: + torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name) + del diffusers_transformer_checkpoint + del checkpoint + load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto") + + print("done loading transformer") + + # done transformer_model + + # text encoder + + print("loading CLIP text encoder") + + clip_name = "openai/clip-vit-base-patch32" + + # The original VQ-Diffusion specifies the pad value by the int used in the + # returned tokens. Each model uses `0` as the pad value. The transformers clip api + # specifies the pad value via the token before it has been tokenized. The `!` pad + # token is the same as padding with the `0` pad value. + pad_token = "!" + + tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto") + + assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0 + + text_encoder_model = CLIPTextModel.from_pretrained( + clip_name, + # `CLIPTextModel` does not support device_map="auto" + # device_map="auto" + ) + + print("done loading CLIP text encoder") + + # done text encoder + + # scheduler + + scheduler_model = VQDiffusionScheduler( + # the scheduler has the same number of embeddings as the transformer + num_vec_classes=transformer_model.num_vector_embeds + ) + + # done scheduler + + print(f"saving VQ diffusion model, path: {args.dump_path}") + + pipe = VQDiffusionPipeline( + vqvae=vqvae_model, + transformer=transformer_model, + tokenizer=tokenizer_model, + text_encoder=text_encoder_model, + scheduler=scheduler_model, + ) + pipe.save_pretrained(args.dump_path) + + print("done writing VQ diffusion model") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1a9a7d74ebee..00052109e32f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -18,7 +18,7 @@ if is_torch_available(): from .modeling_utils import ModelMixin - from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel + from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, @@ -38,6 +38,7 @@ PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, + VQDiffusionPipeline, ) from .schedulers import ( DDIMScheduler, @@ -50,6 +51,7 @@ RePaintScheduler, SchedulerMixin, ScoreSdeVeScheduler, + VQDiffusionScheduler, ) from .training_utils import EMAModel else: diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index c5d53b2feb4b..5b101d169148 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,6 +16,7 @@ if is_torch_available(): + from .attention import Transformer2DModel from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 372c8492b485..bac85e2f39cf 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -12,13 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +from dataclasses import dataclass from typing import Optional import torch import torch.nn.functional as F from torch import nn -from diffusers.utils.import_utils import is_xformers_available +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..models.embeddings import ImagePositionalEmbeddings +from ..utils import BaseOutput +from ..utils.import_utils import is_xformers_available + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: torch.FloatTensor if is_xformers_available(): @@ -28,6 +45,186 @@ xformers = None +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of context dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = in_channels is not None + self.is_input_vectorized = num_vector_embeds is not None + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized: + raise ValueError( + f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if self.is_input_continuous: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + + def _set_attention_slice(self, slice_size): + for block in self.transformer_blocks: + block._set_attention_slice(slice_size) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + # 1. Input + if self.is_input_continuous: + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) + + # 3. Output + if self.is_input_continuous: + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) + hidden_states = self.proj_out(hidden_states) + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for block in self.transformer_blocks: + block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted @@ -36,19 +233,19 @@ class AttentionBlock(nn.Module): Uses three q, k, v linear layers to compute attention. Parameters: - channels (:obj:`int`): The number of channels in the input and output. - num_head_channels (:obj:`int`, *optional*): + channels (`int`): The number of channels in the input and output. + num_head_channels (`int`, *optional*): The number of channels in each head. If None, then `num_heads` = 1. - num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. """ def __init__( self, channels: int, num_head_channels: Optional[int] = None, - num_groups: int = 32, + norm_num_groups: int = 32, rescale_output_factor: float = 1.0, eps: float = 1e-5, ): @@ -57,7 +254,7 @@ def __init__( self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 self.num_head_size = num_head_channels - self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) # define q,k,v as linear layers self.query = nn.Linear(channels, channels) @@ -113,107 +310,61 @@ def forward(self, hidden_states): return hidden_states -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image. - - Parameters: - in_channels (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. - depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The number of context dimensions to use. - """ - - def __init__( - self, - in_channels: int, - n_heads: int, - d_head: int, - depth: int = 1, - dropout: float = 0.0, - num_groups: int = 32, - context_dim: Optional[int] = None, - ): - super().__init__() - self.n_heads = n_heads - self.d_head = d_head - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) - for d in range(depth) - ] - ) - - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - - def _set_attention_slice(self, slice_size): - for block in self.transformer_blocks: - block._set_attention_slice(slice_size) - - def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - for block in self.transformer_blocks: - block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) - - def forward(self, hidden_states, context=None): - # note: if no context is given, cross-attention defaults to self-attention - batch, channel, height, width = hidden_states.shape - residual = hidden_states - hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=context) - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2) - hidden_states = self.proj_out(hidden_states) - return hidden_states + residual - - class BasicTransformerBlock(nn.Module): r""" A basic Transformer block. Parameters: - dim (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. - gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. - checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. """ def __init__( self, dim: int, - n_heads: int, - d_head: int, + num_attention_heads: int, + attention_head_dim: int, dropout=0.0, - context_dim: Optional[int] = None, - gated_ff: bool = True, - checkpoint: bool = True, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, ): super().__init__() self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) self.attn2 = CrossAttention( - query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) + + # layer norms + self.use_ada_layer_norm = num_embeds_ada_norm is not None + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint def _set_attention_slice(self, slice_size): self.attn1._slice_size = slice_size @@ -245,10 +396,22 @@ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atte self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - def forward(self, hidden_states, context=None): - hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states - hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states + def forward(self, hidden_states, context=None, timestep=None): + # 1. Self-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + hidden_states = self.attn1(norm_hidden_states) + hidden_states + + # 2. Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + + # 3. Feed-forward hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + return hidden_states @@ -257,20 +420,28 @@ class CrossAttention(nn.Module): A cross attention layer. Parameters: - query_dim (:obj:`int`): The number of channels in the query. - context_dim (:obj:`int`, *optional*): + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): The number of channels in the context. If not given, defaults to `query_dim`. - heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. """ def __init__( - self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, ): super().__init__() inner_dim = dim_head * heads - context_dim = context_dim if context_dim is not None else query_dim + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.scale = dim_head**-0.5 self.heads = heads @@ -280,9 +451,9 @@ def __init__( self._slice_size = None self._use_memory_efficient_attention_xformers = False - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(inner_dim, query_dim)) @@ -394,23 +565,33 @@ class FeedForward(nn.Module): A feed-forward layer. Parameters: - dim (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. """ def __init__( - self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - self.net = nn.ModuleList([]) + if activation_fn == "geglu": + geglu = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + geglu = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) # project in - self.net.append(GEGLU(dim, inner_dim)) + self.net.append(geglu) # project dropout self.net.append(nn.Dropout(dropout)) # project out @@ -428,8 +609,8 @@ class GEGLU(nn.Module): A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. Parameters: - dim_in (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`): The number of channels in the output. + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. """ def __init__(self, dim_in: int, dim_out: int): @@ -445,3 +626,38 @@ def gelu(self, gate): def forward(self, hidden_states): hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + """ + The approximate form of Gaussian Error Linear Unit (GELU) + + For more details, see section 2: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x): + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +class AdaLayerNorm(nn.Module): + """ + Norm layer modified to incorporate timestep embeddings. + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x, timestep): + emb = self.linear(self.silu(self.emb(timestep))) + scale, shift = torch.chunk(emb, 2) + x = self.norm(x) * (1 + scale) + shift + return x diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 1745265b91e1..1b8609474750 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -142,7 +142,7 @@ def __call__(self, hidden_states, context, deterministic=True): return hidden_states -class FlaxSpatialTransformer(nn.Module): +class FlaxTransformer2DModel(nn.Module): r""" A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: https://arxiv.org/pdf/1506.02025.pdf diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 35715e17fc47..b09d43fc2edf 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -126,3 +126,68 @@ def forward(self, x): else: out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) return out + + +class ImagePositionalEmbeddings(nn.Module): + """ + Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the + height and width of the latent space. + + For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 + + For VQ-diffusion: + + Output vector embeddings are used as input for the transformer. + + Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. + + Args: + num_embed (`int`): + Number of embeddings for the latent pixels embeddings. + height (`int`): + Height of the latent image i.e. the number of height embeddings. + width (`int`): + Width of the latent image i.e. the number of width embeddings. + embed_dim (`int`): + Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. + """ + + def __init__( + self, + num_embed: int, + height: int, + width: int, + embed_dim: int, + ): + super().__init__() + + self.height = height + self.width = width + self.num_embed = num_embed + self.embed_dim = embed_dim + + self.emb = nn.Embedding(self.num_embed, embed_dim) + self.height_emb = nn.Embedding(self.height, embed_dim) + self.width_emb = nn.Embedding(self.width, embed_dim) + + def forward(self, index): + emb = self.emb(index) + + height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) + + # 1 x H x D -> 1 x H x 1 x D + height_emb = height_emb.unsqueeze(2) + + width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) + + # 1 x W x D -> 1 x 1 x W x D + width_emb = width_emb.unsqueeze(1) + + pos_emb = height_emb + width_emb + + # 1 x H x W x D -> 1 x L xD + pos_emb = pos_emb.view(1, self.height * self.width, -1) + + emb = emb + pos_emb[:, : emb.shape[1], :] + + return emb diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index ae4fe2d8bba7..4132ccbd0ce0 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -15,7 +15,7 @@ import torch from torch import nn -from .attention import AttentionBlock, SpatialTransformer +from .attention import AttentionBlock, Transformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D @@ -109,6 +109,19 @@ def get_down_block( resnet_groups=resnet_groups, downsample_padding=downsample_padding, ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attn_num_head_channels=attn_num_head_channels, + ) + raise ValueError(f"{down_block_type} does not exist.") def get_up_block( @@ -200,6 +213,17 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attn_num_head_channels=attn_num_head_channels, + ) raise ValueError(f"{up_block_type} does not exist.") @@ -249,7 +273,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) resnets.append( @@ -325,13 +349,13 @@ def __init__( for _ in range(num_layers): attentions.append( - SpatialTransformer( - in_channels, + Transformer2DModel( attn_num_head_channels, in_channels // attn_num_head_channels, - depth=1, - context_dim=cross_attention_dim, - num_groups=resnet_groups, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, ) ) resnets.append( @@ -374,7 +398,7 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten def forward(self, hidden_states, temb=None, encoder_hidden_states=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states) + hidden_states = attn(hidden_states, encoder_hidden_states).sample hidden_states = resnet(hidden_states, temb) return hidden_states @@ -427,7 +451,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) @@ -506,13 +530,13 @@ def __init__( ) ) attentions.append( - SpatialTransformer( - out_channels, + Transformer2DModel( attn_num_head_channels, out_channels // attn_num_head_channels, - depth=1, - context_dim=cross_attention_dim, - num_groups=resnet_groups, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) @@ -556,19 +580,22 @@ def forward(self, hidden_states, temb=None, encoder_hidden_states=None): for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: - def create_custom_forward(module): + def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): - return module(*inputs) + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn), hidden_states, encoder_hidden_states - ) + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample output_states += (hidden_states,) @@ -763,7 +790,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) @@ -1014,7 +1041,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) @@ -1089,13 +1116,13 @@ def __init__( ) ) attentions.append( - SpatialTransformer( - out_channels, + Transformer2DModel( attn_num_head_channels, out_channels // attn_num_head_channels, - depth=1, - context_dim=cross_attention_dim, - num_groups=resnet_groups, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) @@ -1145,19 +1172,22 @@ def forward( if self.training and self.gradient_checkpointing: - def create_custom_forward(module): + def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): - return module(*inputs) + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn), hidden_states, encoder_hidden_states - ) + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1337,7 +1367,7 @@ def __init__( num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - num_groups=resnet_groups, + norm_num_groups=resnet_groups, ) ) diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py index baa71beabe35..5798385b9d28 100644 --- a/src/diffusers/models/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unet_2d_blocks_flax.py @@ -15,7 +15,7 @@ import flax.linen as nn import jax.numpy as jnp -from .attention_flax import FlaxSpatialTransformer +from .attention_flax import FlaxTransformer2DModel from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D @@ -63,7 +63,7 @@ def setup(self): ) resnets.append(res_block) - attn_block = FlaxSpatialTransformer( + attn_block = FlaxTransformer2DModel( in_channels=self.out_channels, n_heads=self.attn_num_head_channels, d_head=self.out_channels // self.attn_num_head_channels, @@ -196,7 +196,7 @@ def setup(self): ) resnets.append(res_block) - attn_block = FlaxSpatialTransformer( + attn_block = FlaxTransformer2DModel( in_channels=self.out_channels, n_heads=self.attn_num_head_channels, d_head=self.out_channels // self.attn_num_head_channels, @@ -326,7 +326,7 @@ def setup(self): attentions = [] for _ in range(self.num_layers): - attn_block = FlaxSpatialTransformer( + attn_block = FlaxTransformer2DModel( in_channels=self.in_channels, n_heads=self.attn_num_head_channels, d_head=self.in_channels // self.attn_num_head_channels, diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 5f5a47dada0f..30de343d08ee 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -233,14 +233,16 @@ class VectorQuantizer(nn.Module): # NOTE: due to a bug the beta term was applied to the wrong term. for # backwards compatibility we use the buggy version by default, but you can # specify legacy=False to fix it. - def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): + def __init__( + self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True + ): super().__init__() self.n_e = n_e - self.e_dim = e_dim + self.vq_embed_dim = vq_embed_dim self.beta = beta self.legacy = legacy - self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) self.remap = remap @@ -287,7 +289,7 @@ def unmap_to_all(self, inds): def forward(self, z): # reshape z -> (batch, height, width, channel) and flatten z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.e_dim) + z_flattened = z.view(-1, self.vq_embed_dim) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z d = ( @@ -409,6 +411,7 @@ class VQModel(ModelMixin, ConfigMixin): latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. sample_size (`int`, *optional*, defaults to `32`): TODO num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. """ @register_to_config @@ -425,6 +428,7 @@ def __init__( sample_size: int = 32, num_vq_embeddings: int = 256, norm_num_groups: int = 32, + vq_embed_dim: Optional[int] = None, ): super().__init__() @@ -440,11 +444,11 @@ def __init__( double_z=False, ) - self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) - self.quantize = VectorQuantizer( - num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False - ) - self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels + + self.quant_conv = torch.nn.Conv2d(latent_channels, vq_embed_dim, 1) + self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) + self.post_quant_conv = torch.nn.Conv2d(vq_embed_dim, latent_channels, 1) # pass init params to Decoder self.decoder = Decoder( diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 8015d4e11466..bb3440b2bfbc 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -21,6 +21,7 @@ StableDiffusionInpaintPipelineLegacy, StableDiffusionPipeline, ) + from .vq_diffusion import VQDiffusionPipeline if is_transformers_available() and is_onnx_available(): from .stable_diffusion import ( diff --git a/src/diffusers/pipelines/vq_diffusion/__init__.py b/src/diffusers/pipelines/vq_diffusion/__init__.py new file mode 100644 index 000000000000..edf6f570f5bf --- /dev/null +++ b/src/diffusers/pipelines/vq_diffusion/__init__.py @@ -0,0 +1 @@ +from .pipeline_vq_diffusion import VQDiffusionPipeline diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py new file mode 100644 index 000000000000..6e5325ba7ef5 --- /dev/null +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -0,0 +1,253 @@ +# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Tuple, Union + +import torch + +from diffusers import Transformer2DModel, VQModel +from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler +from transformers import CLIPTextModel, CLIPTokenizer + +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VQDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using VQ Diffusion + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vqvae ([`VQModel`]): + Vector Quantized Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent + representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. VQ Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + transformer ([`Transformer2DModel`]): + Conditional transformer to denoise the encoded image latents. + scheduler ([`VQDiffusionScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + vqvae: VQModel + text_encoder: CLIPTextModel + tokenizer: CLIPTokenizer + transformer: Transformer2DModel + scheduler: VQDiffusionScheduler + + def __init__( + self, + vqvae: VQModel, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + transformer: Transformer2DModel, + scheduler: VQDiffusionScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + num_inference_steps: int = 100, + truncation_rate: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)): + Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at + most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above + `truncation_rate` are set to zero. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor` of shape (batch), *optional*): + Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices. + Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will + be generated of completely masked latent pixels. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + batch_size = batch_size * num_images_per_prompt + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. + # While CLIP does normalize the pooled output of the text transformer when combining + # the image and text embeddings, CLIP does not directly normalize the last hidden state. + # + # CLIP normalizing the pooled output. + # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 + text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) + + # duplicate text embeddings for each generation per prompt + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + # get the initial completely masked latents unless the user supplied it + + latents_shape = (batch_size, self.transformer.num_latent_pixels) + if latents is None: + mask_class = self.transformer.num_vector_embeds - 1 + latents = torch.full(latents_shape, mask_class).to(self.device) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any(): + raise ValueError( + "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0," + f" {self.transformer.num_vector_embeds - 1} (inclusive)." + ) + latents = latents.to(self.device) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + sample = latents + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # predict the un-noised image + # model_output == `log_p_x_0` + model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample + + model_output = self.truncate(model_output, truncation_rate) + + # remove `log(0)`'s (`-inf`s) + model_output = model_output.clamp(-70) + + # compute the previous noisy sample x_t -> x_t-1 + sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) + + embedding_channels = self.vqvae.config.vq_embed_dim + embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels) + embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape) + image = self.vqvae.decode(embeddings, force_not_quantize=True).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor: + """ + Truncates log_p_x_0 such that for each column vector, the total cumulative probability is `truncation_rate` The + lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to zero. + """ + sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True) + sorted_p_x_0 = torch.exp(sorted_log_p_x_0) + keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate + + # Ensure that at least the largest probability is not zeroed out + all_true = torch.full_like(keep_mask[:, 0:1, :], True) + keep_mask = torch.cat((all_true, keep_mask), dim=1) + keep_mask = keep_mask[:, :-1, :] + + keep_mask = keep_mask.gather(1, indices.argsort(1)) + + rv = log_p_x_0.clone() + + rv[~keep_mask] = -torch.inf # -inf = log(0) + + return rv diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index a1915ed8d2b1..1be541ba8b66 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -28,6 +28,7 @@ from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_utils import SchedulerMixin + from .scheduling_vq_diffusion import VQDiffusionScheduler else: from ..utils.dummy_pt_objects import * # noqa F403 diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py new file mode 100644 index 000000000000..dbe320d998a3 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -0,0 +1,494 @@ +# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class VQDiffusionSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.LongTensor + + +def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor: + """ + Convert batch of vector of class indices into batch of log onehot vectors + + Args: + x (`torch.LongTensor` of shape `(batch size, vector length)`): + Batch of class indices + + num_classes (`int`): + number of classes to be used for the onehot vectors + + Returns: + `torch.FloatTensor` of shape `(batch size, num classes, vector length)`: + Log onehot vectors + """ + x_onehot = F.one_hot(x, num_classes) + x_onehot = x_onehot.permute(0, 2, 1) + log_x = torch.log(x_onehot.float().clamp(min=1e-30)) + return log_x + + +def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor: + """ + Apply gumbel noise to `logits` + """ + uniform = torch.rand(logits.shape, device=logits.device, generator=generator) + gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) + noised = gumbel_noise + logits + return noised + + +def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009): + """ + Cumulative and non-cumulative alpha schedules. + + See section 4.1. + """ + att = ( + np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start) + + alpha_cum_start + ) + att = np.concatenate(([1], att)) + at = att[1:] / att[:-1] + att = np.concatenate((att[1:], [1])) + return at, att + + +def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999): + """ + Cumulative and non-cumulative gamma schedules. + + See section 4.1. + """ + ctt = ( + np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start) + + gamma_cum_start + ) + ctt = np.concatenate(([0], ctt)) + one_minus_ctt = 1 - ctt + one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1] + ct = 1 - one_minus_ct + ctt = np.concatenate((ctt[1:], [0])) + return ct, ctt + + +class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + The VQ-diffusion transformer outputs predicted probabilities of the initial unnoised image. + + The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous + diffusion timestep. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2111.14822 + + Args: + num_vec_classes (`int`): + The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked + latent pixel. + + num_train_timesteps (`int`): + Number of diffusion steps used to train the model. + + alpha_cum_start (`float`): + The starting cumulative alpha value. + + alpha_cum_end (`float`): + The ending cumulative alpha value. + + gamma_cum_start (`float`): + The starting cumulative gamma value. + + gamma_cum_end (`float`): + The ending cumulative gamma value. + """ + + @register_to_config + def __init__( + self, + num_vec_classes: int, + num_train_timesteps: int = 100, + alpha_cum_start: float = 0.99999, + alpha_cum_end: float = 0.000009, + gamma_cum_start: float = 0.000009, + gamma_cum_end: float = 0.99999, + ): + self.num_embed = num_vec_classes + + # By convention, the index for the mask class is the last class index + self.mask_class = self.num_embed - 1 + + at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end) + ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end) + + num_non_mask_classes = self.num_embed - 1 + bt = (1 - at - ct) / num_non_mask_classes + btt = (1 - att - ctt) / num_non_mask_classes + + at = torch.tensor(at.astype("float64")) + bt = torch.tensor(bt.astype("float64")) + ct = torch.tensor(ct.astype("float64")) + log_at = torch.log(at) + log_bt = torch.log(bt) + log_ct = torch.log(ct) + + att = torch.tensor(att.astype("float64")) + btt = torch.tensor(btt.astype("float64")) + ctt = torch.tensor(ctt.astype("float64")) + log_cumprod_at = torch.log(att) + log_cumprod_bt = torch.log(btt) + log_cumprod_ct = torch.log(ctt) + + self.log_at = log_at.float() + self.log_bt = log_bt.float() + self.log_ct = log_ct.float() + self.log_cumprod_at = log_cumprod_at.float() + self.log_cumprod_bt = log_cumprod_bt.float() + self.log_cumprod_ct = log_cumprod_ct.float() + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + + device (`str` or `torch.device`): + device to place the timesteps and the diffusion process parameters (alpha, beta, gamma) on. + """ + self.num_inference_steps = num_inference_steps + timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.log_at = self.log_at.to(device) + self.log_bt = self.log_bt.to(device) + self.log_ct = self.log_ct.to(device) + self.log_cumprod_at = self.log_cumprod_at.to(device) + self.log_cumprod_bt = self.log_cumprod_bt.to(device) + self.log_cumprod_ct = self.log_cumprod_ct.to(device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: torch.long, + sample: torch.LongTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[VQDiffusionSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep via the reverse transition distribution i.e. Equation (11). See the + docstring for `self.q_posterior` for more in depth docs on how Equation (11) is computed. + + Args: + log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`): + The log probabilities for the predicted classes of the initial latent pixels. Does not include a + prediction for the masked class as the initial unnoised image cannot be masked. + + t (`torch.long`): + The timestep that determines which transition matrices are used. + + x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t` + + generator: (`torch.Generator` or None): + RNG for the noise applied to p(x_{t-1} | x_t) before it is sampled from. + + return_dict (`bool`): + option for returning tuple rather than VQDiffusionSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + if timestep == 0: + log_p_x_t_min_1 = model_output + else: + log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep) + + log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator) + + x_t_min_1 = log_p_x_t_min_1.argmax(dim=1) + + if not return_dict: + return (x_t_min_1,) + + return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1) + + def q_posterior(self, log_p_x_0, x_t, t): + """ + Calculates the log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11). + + Instead of directly computing equation (11), we use Equation (5) to restate Equation (11) in terms of only + forward probabilities. + + Equation (11) stated in terms of forward probabilities via Equation (5): + + Where: + - the sum is over x_0 = {C_0 ... C_{k-1}} (classes for x_0) + + p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) ) + + Args: + log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`): + The log probabilities for the predicted classes of the initial latent pixels. Does not include a + prediction for the masked class as the initial unnoised image cannot be masked. + + x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t` + + t (torch.Long): + The timestep that determines which transition matrix is used. + + Returns: + `torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`: + The log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11). + """ + log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed) + + log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class( + t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True + ) + + log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class( + t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False + ) + + # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + # . . . + # . . . + # . . . + # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) + q = log_p_x_0 - log_q_x_t_given_x_0 + + # sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... , + # sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) + q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True) + + # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n + # . . . + # . . . + # . . . + # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n + q = q - q_log_sum_exp + + # (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} + # . . . + # . . . + # . . . + # (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} + # c_cumulative_{t-1} ... c_cumulative_{t-1} + q = self.apply_cumulative_transitions(q, t - 1) + + # ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n + # . . . + # . . . + # . . . + # ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n + # c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 + log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp + + # For each column, there are two possible cases. + # + # Where: + # - sum(p_n(x_0))) is summing over all classes for x_0 + # - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's) + # - C_j is the class transitioning to + # + # 1. x_t is masked i.e. x_t = c_k + # + # Simplifying the expression, the column vector is: + # . + # . + # . + # (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0))) + # . + # . + # . + # (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0)) + # + # From equation (11) stated in terms of forward probabilities, the last row is trivially verified. + # + # For the other rows, we can state the equation as ... + # + # (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{k-1} * p(x_0=c_{k-1})] + # + # This verifies the other rows. + # + # 2. x_t is not masked + # + # Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i: + # . + # . + # . + # C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(c_0=C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) + # . + # . + # . + # C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) + # . + # . + # . + # 0 + # + # The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities. + return log_p_x_t_min_1 + + def log_Q_t_transitioning_to_known_class( + self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool + ): + """ + Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each + latent pixel in `x_t`. + + See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix + is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs. + + Args: + t (torch.Long): + The timestep that determines which transition matrix is used. + + x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`): + The classes of each latent pixel at time `t`. + + log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`): + The log one-hot vectors of `x_t` + + cumulative (`bool`): + If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`, + we use the cumulative transition matrix `0`->`t`. + + Returns: + `torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`: + Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability + transition matrix. + + When non cumulative, returns `self.num_classes - 1` rows because the initial latent pixel cannot be + masked. + + Where: + - `q_n` is the probability distribution for the forward process of the `n`th latent pixel. + - C_0 is a class of a latent pixel embedding + - C_k is the class of the masked latent pixel + + non-cumulative result (omitting logarithms): + ``` + q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0) + . . . + . . . + . . . + q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k) + ``` + + cumulative result (omitting logarithms): + ``` + q_0_cumulative(x_t | x_0 = C_0) ... q_n_cumulative(x_t | x_0 = C_0) + . . . + . . . + . . . + q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1}) + ``` + """ + if cumulative: + a = self.log_cumprod_at[t] + b = self.log_cumprod_bt[t] + c = self.log_cumprod_ct[t] + else: + a = self.log_at[t] + b = self.log_bt[t] + c = self.log_ct[t] + + if not cumulative: + # The values in the onehot vector can also be used as the logprobs for transitioning + # from masked latent pixels. If we are not calculating the cumulative transitions, + # we need to save these vectors to be re-appended to the final matrix so the values + # aren't overwritten. + # + # `P(x_t!=mask|x_{t-1=mask}) = 0` and 0 will be the value of the last row of the onehot vector + # if x_t is not masked + # + # `P(x_t=mask|x_{t-1=mask}) = 1` and 1 will be the value of the last row of the onehot vector + # if x_t is masked + log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1) + + # `index_to_log_onehot` will add onehot vectors for masked pixels, + # so the default one hot matrix has one too many rows. See the doc string + # for an explanation of the dimensionality of the returned matrix. + log_onehot_x_t = log_onehot_x_t[:, :-1, :] + + # this is a cheeky trick to produce the transition probabilities using log one-hot vectors. + # + # Don't worry about what values this sets in the columns that mark transitions + # to masked latent pixels. They are overwrote later with the `mask_class_mask`. + # + # Looking at the below logspace formula in non-logspace, each value will evaluate to either + # `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column + # or + # `0 * a + b = b` where `log_Q_t` has the 0 values in the column. + # + # See equation 7 for more details. + log_Q_t = (log_onehot_x_t + a).logaddexp(b) + + # The whole column of each masked pixel is `c` + mask_class_mask = x_t == self.mask_class + mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1) + log_Q_t[mask_class_mask] = c + + if not cumulative: + log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1) + + return log_Q_t + + def apply_cumulative_transitions(self, q, t): + bsz = q.shape[0] + a = self.log_cumprod_at[t] + b = self.log_cumprod_bt[t] + c = self.log_cumprod_ct[t] + + num_latent_pixels = q.shape[2] + c = c.expand(bsz, 1, num_latent_pixels) + + q = (q + a).logaddexp(b) + q = torch.cat((q, c), dim=1) + + return q diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 63aa20962f05..833f2b6c5074 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -34,6 +34,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class Transformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class UNet1DModel(metaclass=DummyObject): _backends = ["torch"] @@ -257,6 +272,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class VQDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DDIMScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -407,6 +437,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class VQDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class EMAModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/vq_diffusion/__init__.py b/tests/pipelines/vq_diffusion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py new file mode 100644 index 000000000000..5eb32d40d4f3 --- /dev/null +++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py @@ -0,0 +1,175 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch + +from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel +from diffusers.utils import load_image, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def num_embed(self): + return 12 + + @property + def num_embeds_ada_norm(self): + return 12 + + @property + def dummy_vqvae(self): + torch.manual_seed(0) + model = VQModel( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=3, + num_vq_embeddings=self.num_embed, + vq_embed_dim=3, + ) + return model + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + @property + def dummy_transformer(self): + torch.manual_seed(0) + + height = 12 + width = 12 + + model_kwargs = { + "attention_bias": True, + "cross_attention_dim": 32, + "attention_head_dim": height * width, + "num_attention_heads": 1, + "num_vector_embeds": self.num_embed, + "num_embeds_ada_norm": self.num_embeds_ada_norm, + "norm_num_groups": 32, + "sample_size": width, + "activation_fn": "geglu-approximate", + } + + model = Transformer2DModel(**model_kwargs) + return model + + def test_vq_diffusion(self): + device = "cpu" + + vqvae = self.dummy_vqvae + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + transformer = self.dummy_transformer + scheduler = VQDiffusionScheduler(self.num_embed) + + pipe = VQDiffusionPipeline( + vqvae=vqvae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler + ) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + prompt = "teddy bear playing in the pool" + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = pipe( + [prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2 + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 24, 24, 3) + + expected_slice = np.array([0.6583, 0.6410, 0.5325, 0.5635, 0.5563, 0.4234, 0.6008, 0.5491, 0.4880]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch_gpu +class VQDiffusionPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_vq_diffusion(self): + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/vq_diffusion/teddy_bear_pool.png" + ) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 + + pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq") + pipeline = pipeline.to(torch_device) + pipeline.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipeline( + "teddy bear playing in the pool", + truncation_rate=0.86, + num_images_per_prompt=1, + generator=generator, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (256, 256, 3) + assert np.abs(expected_image - image).max() < 1e-2 diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index cf531fbf3fd3..911ec548b3bd 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -18,8 +18,9 @@ import numpy as np import torch +from torch import nn -from diffusers.models.attention import AttentionBlock, SpatialTransformer +from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock, Transformer2DModel from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.resnet import Downsample2D, Upsample2D from diffusers.utils import torch_device @@ -235,7 +236,7 @@ def test_attention_block_default(self): num_head_channels=1, rescale_output_factor=1.0, eps=1e-6, - num_groups=32, + norm_num_groups=32, ).to(torch_device) with torch.no_grad(): attention_scores = attentionBlock(sample) @@ -259,7 +260,7 @@ def test_attention_block_sd(self): channels=512, rescale_output_factor=1.0, eps=1e-6, - num_groups=32, + norm_num_groups=32, ).to(torch_device) with torch.no_grad(): attention_scores = attentionBlock(sample) @@ -273,22 +274,22 @@ def test_attention_block_sd(self): assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) -class SpatialTransformerTests(unittest.TestCase): +class Transformer2DModelTests(unittest.TestCase): def test_spatial_transformer_default(self): torch.manual_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) sample = torch.randn(1, 32, 64, 64).to(torch_device) - spatial_transformer_block = SpatialTransformer( + spatial_transformer_block = Transformer2DModel( in_channels=32, - n_heads=1, - d_head=32, + num_attention_heads=1, + attention_head_dim=32, dropout=0.0, - context_dim=None, + cross_attention_dim=None, ).to(torch_device) with torch.no_grad(): - attention_scores = spatial_transformer_block(sample) + attention_scores = spatial_transformer_block(sample).sample assert attention_scores.shape == (1, 32, 64, 64) output_slice = attention_scores[0, -1, -3:, -3:] @@ -298,22 +299,22 @@ def test_spatial_transformer_default(self): ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) - def test_spatial_transformer_context_dim(self): + def test_spatial_transformer_cross_attention_dim(self): torch.manual_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) sample = torch.randn(1, 64, 64, 64).to(torch_device) - spatial_transformer_block = SpatialTransformer( + spatial_transformer_block = Transformer2DModel( in_channels=64, - n_heads=2, - d_head=32, + num_attention_heads=2, + attention_head_dim=32, dropout=0.0, - context_dim=64, + cross_attention_dim=64, ).to(torch_device) with torch.no_grad(): context = torch.randn(1, 4, 64).to(torch_device) - attention_scores = spatial_transformer_block(sample, context) + attention_scores = spatial_transformer_block(sample, context).sample assert attention_scores.shape == (1, 64, 64, 64) output_slice = attention_scores[0, -1, -3:, -3:] @@ -323,6 +324,44 @@ def test_spatial_transformer_context_dim(self): ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + def test_spatial_transformer_timestep(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_embeds_ada_norm = 5 + + sample = torch.randn(1, 64, 64, 64).to(torch_device) + spatial_transformer_block = Transformer2DModel( + in_channels=64, + num_attention_heads=2, + attention_head_dim=32, + dropout=0.0, + cross_attention_dim=64, + num_embeds_ada_norm=num_embeds_ada_norm, + ).to(torch_device) + with torch.no_grad(): + timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device) + timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device) + attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1).sample + attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2).sample + + assert attention_scores_1.shape == (1, 64, 64, 64) + assert attention_scores_2.shape == (1, 64, 64, 64) + + output_slice_1 = attention_scores_1[0, -1, -3:, -3:] + output_slice_2 = attention_scores_2[0, -1, -3:, -3:] + + expected_slice_1 = torch.tensor( + [-0.1874, -0.9704, -1.4290, -1.3357, 1.5138, 0.3036, -0.0976, -1.1667, 0.1283], device=torch_device + ) + expected_slice_2 = torch.tensor( + [-0.3493, -1.0924, -1.6161, -1.5016, 1.4245, 0.1367, -0.2526, -1.3109, -0.0547], device=torch_device + ) + + assert torch.allclose(output_slice_1.flatten(), expected_slice_1, atol=1e-3) + assert torch.allclose(output_slice_2.flatten(), expected_slice_2, atol=1e-3) + def test_spatial_transformer_dropout(self): torch.manual_seed(0) if torch.cuda.is_available(): @@ -330,18 +369,18 @@ def test_spatial_transformer_dropout(self): sample = torch.randn(1, 32, 64, 64).to(torch_device) spatial_transformer_block = ( - SpatialTransformer( + Transformer2DModel( in_channels=32, - n_heads=2, - d_head=16, + num_attention_heads=2, + attention_head_dim=16, dropout=0.3, - context_dim=None, + cross_attention_dim=None, ) .to(torch_device) .eval() ) with torch.no_grad(): - attention_scores = spatial_transformer_block(sample) + attention_scores = spatial_transformer_block(sample).sample assert attention_scores.shape == (1, 32, 64, 64) output_slice = attention_scores[0, -1, -3:, -3:] @@ -350,3 +389,107 @@ def test_spatial_transformer_dropout(self): [-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + @unittest.skipIf(torch_device == "mps", "MPS does not support float64") + def test_spatial_transformer_discrete(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_embed = 5 + + sample = torch.randint(0, num_embed, (1, 32)).to(torch_device) + spatial_transformer_block = ( + Transformer2DModel( + num_attention_heads=1, + attention_head_dim=32, + num_vector_embeds=num_embed, + sample_size=16, + ) + .to(torch_device) + .eval() + ) + + with torch.no_grad(): + attention_scores = spatial_transformer_block(sample).sample + + assert attention_scores.shape == (1, num_embed - 1, 32) + + output_slice = attention_scores[0, -2:, -3:] + + expected_slice = torch.tensor([-0.8957, -1.8370, -1.3390, -0.9152, -0.5187, -1.1702], device=torch_device) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + def test_spatial_transformer_default_norm_layers(self): + spatial_transformer_block = Transformer2DModel(num_attention_heads=1, attention_head_dim=32, in_channels=32) + + assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == nn.LayerNorm + assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == nn.LayerNorm + assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm + + def test_spatial_transformer_ada_norm_layers(self): + spatial_transformer_block = Transformer2DModel( + num_attention_heads=1, + attention_head_dim=32, + in_channels=32, + num_embeds_ada_norm=5, + ) + + assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == AdaLayerNorm + assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == AdaLayerNorm + assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm + + def test_spatial_transformer_default_ff_layers(self): + spatial_transformer_block = Transformer2DModel( + num_attention_heads=1, + attention_head_dim=32, + in_channels=32, + ) + + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU + assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear + + dim = 32 + inner_dim = 128 + + # First dimension change + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim + # NOTE: inner_dim * 2 because GEGLU + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim * 2 + + # Second dimension change + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim + + def test_spatial_transformer_geglu_approx_ff_layers(self): + spatial_transformer_block = Transformer2DModel( + num_attention_heads=1, + attention_head_dim=32, + in_channels=32, + activation_fn="geglu-approximate", + ) + + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU + assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear + + dim = 32 + inner_dim = 128 + + # First dimension change + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim + assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim + + # Second dimension change + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim + assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim + + def test_spatial_transformer_attention_bias(self): + spatial_transformer_block = Transformer2DModel( + num_attention_heads=1, attention_head_dim=32, in_channels=32, attention_bias=True + ) + + assert spatial_transformer_block.transformer_blocks[0].attn1.to_q.bias is not None + assert spatial_transformer_block.transformer_blocks[0].attn1.to_k.bias is not None + assert spatial_transformer_block.transformer_blocks[0].attn1.to_v.bias is not None diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 9285eed20ff0..29186aaac99b 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -19,6 +19,7 @@ import numpy as np import torch +import torch.nn.functional as F from diffusers import ( DDIMScheduler, @@ -29,6 +30,7 @@ LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler, + VQDiffusionScheduler, ) from diffusers.utils import torch_device @@ -85,12 +87,18 @@ def check_over_configs(self, time_step=0, **config): if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): time_step = float(time_step) - sample = self.dummy_sample - residual = 0.1 * sample - scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) + if scheduler_class == VQDiffusionScheduler: + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) + residual = model(sample, time_step) + else: + sample = self.dummy_sample + residual = 0.1 * sample + with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) @@ -122,12 +130,18 @@ def check_over_forward(self, time_step=0, **forward_kwargs): if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): time_step = float(time_step) - sample = self.dummy_sample - residual = 0.1 * sample - scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) + if scheduler_class == VQDiffusionScheduler: + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) + residual = model(sample, time_step) + else: + sample = self.dummy_sample + residual = 0.1 * sample + with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) @@ -154,15 +168,21 @@ def test_from_pretrained_save_pretrained(self): num_inference_steps = kwargs.pop("num_inference_steps", None) for scheduler_class in self.scheduler_classes: - sample = self.dummy_sample - residual = 0.1 * sample + timestep = 1 + if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): + timestep = float(timestep) scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - timestep = 1 - if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): - timestep = float(timestep) + if scheduler_class == VQDiffusionScheduler: + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) + residual = model(sample, timestep) + else: + sample = self.dummy_sample + residual = 0.1 * sample with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) @@ -200,8 +220,14 @@ def test_step_shape(self): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - sample = self.dummy_sample - residual = 0.1 * sample + if scheduler_class == VQDiffusionScheduler: + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) + residual = model(sample, timestep_0) + else: + sample = self.dummy_sample + residual = 0.1 * sample if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) @@ -255,8 +281,14 @@ def recursive_check(tuple_object, dict_object): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - sample = self.dummy_sample - residual = 0.1 * sample + if scheduler_class == VQDiffusionScheduler: + num_vec_classes = scheduler_config["num_vec_classes"] + sample = self.dummy_sample(num_vec_classes) + model = self.dummy_model(num_vec_classes) + residual = model(sample, timestep) + else: + sample = self.dummy_sample + residual = 0.1 * sample if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) @@ -284,22 +316,26 @@ def test_scheduler_public_api(self): for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - self.assertTrue( - hasattr(scheduler, "init_noise_sigma"), - f"{scheduler_class} does not implement a required attribute `init_noise_sigma`", - ) - self.assertTrue( - hasattr(scheduler, "scale_model_input"), - f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`", - ) + + if scheduler_class != VQDiffusionScheduler: + self.assertTrue( + hasattr(scheduler, "init_noise_sigma"), + f"{scheduler_class} does not implement a required attribute `init_noise_sigma`", + ) + self.assertTrue( + hasattr(scheduler, "scale_model_input"), + f"{scheduler_class} does not implement a required class method `scale_model_input(sample," + " timestep)`", + ) self.assertTrue( hasattr(scheduler, "step"), f"{scheduler_class} does not implement a required class method `step(...)`", ) - sample = self.dummy_sample - scaled_sample = scheduler.scale_model_input(sample, 0.0) - self.assertEqual(sample.shape, scaled_sample.shape) + if scheduler_class != VQDiffusionScheduler: + sample = self.dummy_sample + scaled_sample = scheduler.scale_model_input(sample, 0.0) + self.assertEqual(sample.shape, scaled_sample.shape) def test_add_noise_device(self): for scheduler_class in self.scheduler_classes: @@ -1238,3 +1274,53 @@ def test_full_loop_no_noise(self): result_mean = torch.mean(torch.abs(sample)) assert abs(result_mean.item() - 2540529) < 10 + + +class VQDiffusionSchedulerTest(SchedulerCommonTest): + scheduler_classes = (VQDiffusionScheduler,) + + def get_scheduler_config(self, **kwargs): + config = { + "num_vec_classes": 4097, + "num_train_timesteps": 100, + } + + config.update(**kwargs) + return config + + def dummy_sample(self, num_vec_classes): + batch_size = 4 + height = 8 + width = 8 + + sample = torch.randint(0, num_vec_classes, (batch_size, height * width)) + + return sample + + @property + def dummy_sample_deter(self): + assert False + + def dummy_model(self, num_vec_classes): + def model(sample, t, *args): + batch_size, num_latent_pixels = sample.shape + logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels)) + return_value = F.log_softmax(logits.double(), dim=1).float() + return return_value + + return model + + def test_timesteps(self): + for timesteps in [2, 5, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_num_vec_classes(self): + for num_vec_classes in [5, 100, 1000, 4000]: + self.check_over_configs(num_vec_classes=num_vec_classes) + + def test_time_indices(self): + for t in [0, 50, 99]: + self.check_over_forward(time_step=t) + + def test_add_noise_device(self): + pass