diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py
new file mode 100644
index 000000000000..3ba95f738ef7
--- /dev/null
+++ b/scripts/convert_vq_diffusion_to_diffusers.py
@@ -0,0 +1,508 @@
+"""
+This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers.
+
+It currently only supports porting the VQVAE for the ITHQ dataset.
+
+VQVAE for the ITHQ dataset:
+```sh
+# From the root directory in diffusers.
+
+# Download the checkpoint
+$ wget 'https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D' -O ithq_vqvae.pth
+
+# Download the config
+# NOTE that in VQ-diffusion the documented config file is `configs/ithq.yaml`, but the target class
+# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE`
+# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml`
+$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml
+
+# Run the conversion script
+$ python ./scripts/convert_vq_diffusion_to_diffusers.py \
+    --checkpoint_path ./ithq_vqvae.pth \
+    --original_config_file ./ithq_vqvae.yaml \
+    --dump_path <path to save the converted pipeline> \
+    --only-vqvae
+```
+"""
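Once the conversion has run, the saved pipeline can be reloaded through the standard `from_pretrained` API. A minimal sketch of the round trip, not part of the patch itself; `./ithq-vqvae-diffusers` and `input.png` are hypothetical paths standing in for the dump path and a test image:

```py
# Reload the converted pipeline and round-trip an image through the ported VQVAE.
import PIL.Image

from diffusers.pipelines import VQDiffusionPipeline

pipe = VQDiffusionPipeline.from_pretrained("./ithq-vqvae-diffusers")

image = PIL.Image.open("input.png").convert("RGB")
latents = pipe.encode(image)            # latents from the VQVAE encoder + quant_conv
reconstructions = pipe.decode(latents)  # list of PIL images
reconstructions[0].save("reconstruction.png")
```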
+
+import argparse
+import sys
+
+import torch
+import yaml
+from yaml.loader import FullLoader
+
+from diffusers import VQModel
+from diffusers.pipelines import VQDiffusionPipeline
+
+
+try:
+    from omegaconf import OmegaConf
+except ImportError:
+    raise ImportError(
+        "OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install OmegaConf`."
+    )
+
+
+############### vqvae model ###############
+
+PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]
+
+
+def vqvae_model_from_original_config(original_config):
+    assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers."
+
+    original_config = original_config.params
+
+    original_encoder_config = original_config.encoder_config.params
+    original_decoder_config = original_config.decoder_config.params
+
+    in_channels = original_encoder_config.in_channels
+    out_channels = original_decoder_config.out_ch
+
+    down_block_types = get_down_block_types(original_encoder_config)
+    up_block_types = get_up_block_types(original_decoder_config)
+
+    assert original_encoder_config.ch == original_decoder_config.ch
+    assert original_encoder_config.ch_mult == original_decoder_config.ch_mult
+    block_out_channels = tuple(original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult)
+
+    assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks
+    layers_per_block = original_encoder_config.num_res_blocks
+
+    assert original_encoder_config.z_channels == original_decoder_config.z_channels
+    latent_channels = original_encoder_config.z_channels
+
+    num_vq_embeddings = original_config.n_embed
+
+    # Hard coded value for ResnetBlock.GroupNorm(num_groups) in VQ-diffusion
+    norm_num_groups = 32
+
+    e_dim = original_config.embed_dim
+    conv_attention_block = True
+
+    model = VQModel(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        down_block_types=down_block_types,
+        up_block_types=up_block_types,
+        block_out_channels=block_out_channels,
+        layers_per_block=layers_per_block,
+        latent_channels=latent_channels,
+        num_vq_embeddings=num_vq_embeddings,
+        norm_num_groups=norm_num_groups,
+        e_dim=e_dim,
+        conv_attention_block=conv_attention_block,
+    )
+
+    return model
+
+
+def get_down_block_types(original_encoder_config):
+    attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions)
+    num_resolutions = len(original_encoder_config.ch_mult)
+    resolution = coerce_resolution(original_encoder_config.resolution)
+
+    curr_res = resolution
+    down_block_types = []
+
+    for _ in range(num_resolutions):
+        if curr_res in attn_resolutions:
+            down_block_type = "AttnDownEncoderBlock2D"
+        else:
+            down_block_type = "DownEncoderBlock2D"
+
+        down_block_types.append(down_block_type)
+
+        curr_res = [r // 2 for r in curr_res]
+
+    return down_block_types
+
+
+def get_up_block_types(original_decoder_config):
+    attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions)
+    num_resolutions = len(original_decoder_config.ch_mult)
+    resolution = coerce_resolution(original_decoder_config.resolution)
+
+    curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
+    up_block_types = []
+
+    for _ in reversed(range(num_resolutions)):
+        if curr_res in attn_resolutions:
+            up_block_type = "AttnUpDecoderBlock2D"
+        else:
+            up_block_type = "UpDecoderBlock2D"
+
+        up_block_types.append(up_block_type)
+
+        curr_res = [r * 2 for r in curr_res]
+
+    return up_block_types
+
+
+def coerce_attn_resolutions(attn_resolutions):
+    attn_resolutions = OmegaConf.to_object(attn_resolutions)
+    attn_resolutions_ = []
+    for ar in attn_resolutions:
+        if isinstance(ar, (list, tuple)):
+            attn_resolutions_.append(list(ar))
+        else:
+            attn_resolutions_.append([ar, ar])
+    return attn_resolutions_
+
+
+def coerce_resolution(resolution):
+    resolution = OmegaConf.to_object(resolution)
+    if isinstance(resolution, int):
+        resolution = [resolution, resolution]  # H, W
+    elif isinstance(resolution, (tuple, list)):
+        resolution = list(resolution)
+    else:
+        raise ValueError(f"Unknown type of resolution: {resolution}")
+    return resolution
+
+
+###############################################
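The two `get_*_block_types` helpers above choose between plain and attention blocks purely from the per-level spatial resolution: the encoder halves the tracked resolution one level at a time, the decoder starts from the smallest resolution and doubles it, and a level becomes an attention block whenever its resolution appears in `attn_resolutions`. A quick worked example with hypothetical config values (not the ITHQ config), assuming the helpers above are in scope:

```py
# Hypothetical encoder/decoder config: 4 levels, 256x256 inputs, attention at 32x32.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"ch_mult": [1, 1, 2, 4], "resolution": [256, 256], "attn_resolutions": [32]})

print(get_down_block_types(cfg))
# ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'AttnDownEncoderBlock2D']

print(get_up_block_types(cfg))
# ['AttnUpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D']
```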
+
+
+############# vqvae checkpoint ###############
+
+
+def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
+    diffusers_checkpoint = {}
+
+    # encoder
+    diffusers_checkpoint.update(encoder_to_diffusers_checkpoint(model, checkpoint))
+
+    # quant_conv
+    diffusers_checkpoint.update(
+        {
+            "quant_conv.weight": checkpoint["quant_conv.weight"],
+            "quant_conv.bias": checkpoint["quant_conv.bias"],
+        }
+    )
+
+    # quantize
+    diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]})
+
+    # post_quant_conv
+    diffusers_checkpoint.update(
+        {
+            "post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
+            "post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
+        }
+    )
+
+    # decoder
+    diffusers_checkpoint.update(decoder_to_diffusers_checkpoint(model, checkpoint))
+
+    return diffusers_checkpoint
+
+
+def encoder_to_diffusers_checkpoint(model, checkpoint):
+    diffusers_checkpoint = {}
+
+    # conv_in
+    diffusers_checkpoint.update(
+        {
+            "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
+            "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
+        }
+    )
+
+    # down_blocks
+    for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
+        diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
+        down_block_prefix = f"encoder.down.{down_block_idx}"
+
+        # resnets
+        for resnet_idx, resnet in enumerate(down_block.resnets):
+            diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
+            resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"
+
+            diffusers_checkpoint.update(
+                resnet_to_diffusers_checkpoint(
+                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
+                )
+            )
+
+        # downsample
+
+        # There is no downsample on the last down block, so skip it there.
+        if down_block_idx != len(model.encoder.down_blocks) - 1:
+            # There's a single downsample in the original checkpoint but a list of downsamplers
+            # in the diffusers model.
+            diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
+            downsample_prefix = f"{down_block_prefix}.downsample.conv"
+            diffusers_checkpoint.update(
+                {
+                    f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
+                    f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
+                }
+            )
+
+        # attentions
+
+        if hasattr(down_block, "attentions"):
+            for attention_idx, _ in enumerate(down_block.attentions):
+                diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
+                attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
+                diffusers_checkpoint.update(
+                    attention_to_diffusers_checkpoint(
+                        checkpoint,
+                        diffusers_attention_prefix=diffusers_attention_prefix,
+                        attention_prefix=attention_prefix,
+                    )
+                )
+
+    # mid block
+
+    ## attentions
+
+    # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder
+    diffusers_attention_prefix = "encoder.mid_block.attentions.0"
+    attention_prefix = "encoder.mid.attn_1"
+    diffusers_checkpoint.update(
+        attention_to_diffusers_checkpoint(
+            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
+        )
+    )
+
+    ## resnets
+
+    for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
+        diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"
+
+        # the hardcoded prefixes to `block_` are 1 and 2
+        orig_resnet_idx = diffusers_resnet_idx + 1
+        # There are two hardcoded resnets in the middle of the VQ-diffusion encoder
+        resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"
+
+        diffusers_checkpoint.update(
+            resnet_to_diffusers_checkpoint(
+                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
+            )
+        )
+
+    diffusers_checkpoint.update(
+        {
+            # conv_norm_out
+            "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
+            "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
+            # conv_out
+            "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
+            "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
+def decoder_to_diffusers_checkpoint(model, checkpoint):
+    diffusers_checkpoint = {}
+
+    # conv_in
+    diffusers_checkpoint.update(
+        {
+            "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
+            "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
+        }
+    )
+
+    # up_blocks
+    for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
+        # up_blocks are stored in reverse order in the VQ-diffusion checkpoint
+        orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx
+
+        diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
+        up_block_prefix = f"decoder.up.{orig_up_block_idx}"
+
+        # resnets
+        for resnet_idx, resnet in enumerate(up_block.resnets):
+            diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
+            resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"
+
+            diffusers_checkpoint.update(
+                resnet_to_diffusers_checkpoint(
+                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
+                )
+            )
+
+        # upsample
+
+        # There is no upsample on the last up block, so skip it there.
+        if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
+            # There's a single upsample in the VQ-diffusion checkpoint but a list of upsamplers
+            # in the diffusers model.
+            diffusers_upsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
+            upsample_prefix = f"{up_block_prefix}.upsample.conv"
+            diffusers_checkpoint.update(
+                {
+                    f"{diffusers_upsample_prefix}.weight": checkpoint[f"{upsample_prefix}.weight"],
+                    f"{diffusers_upsample_prefix}.bias": checkpoint[f"{upsample_prefix}.bias"],
+                }
+            )
+
+        # attentions
+
+        if hasattr(up_block, "attentions"):
+            for attention_idx, _ in enumerate(up_block.attentions):
+                diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
+                attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
+                diffusers_checkpoint.update(
+                    attention_to_diffusers_checkpoint(
+                        checkpoint,
+                        diffusers_attention_prefix=diffusers_attention_prefix,
+                        attention_prefix=attention_prefix,
+                    )
+                )
+
+    # mid block
+
+    ## attentions
+
+    # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder
+    diffusers_attention_prefix = "decoder.mid_block.attentions.0"
+    attention_prefix = "decoder.mid.attn_1"
+    diffusers_checkpoint.update(
+        attention_to_diffusers_checkpoint(
+            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
+        )
+    )
+
+    ## resnets
+
+    for diffusers_resnet_idx, resnet in enumerate(model.decoder.mid_block.resnets):
+        diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"
+
+        # the hardcoded prefixes to `block_` are 1 and 2
+        orig_resnet_idx = diffusers_resnet_idx + 1
+        # There are two hardcoded resnets in the middle of the VQ-diffusion decoder
+        resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"
+
+        diffusers_checkpoint.update(
+            resnet_to_diffusers_checkpoint(
+                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
+            )
+        )
+
+    diffusers_checkpoint.update(
+        {
+            # conv_norm_out
+            "decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"],
+            "decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"],
+            # conv_out
+            "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
+            "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
+def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
+    rv = {
+        # norm1
+        f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
+        f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
+        # conv1
+        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
+        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
+        # norm2
+        f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
+        f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
+        # conv2
+        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
+        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
+    }
+
+    if resnet.conv_shortcut is not None:
+        rv.update(
+            {
+                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
+                f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
+            }
+        )
+
+    return rv
+
+
+def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
+    return {
+        # group_norm
+        f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + + # query + f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"], + f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"], + + # key + f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"], + f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"], + + # value + f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"], + f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"], + + # proj_attn + f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"], + f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + +############################################### + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--only-vqvae", + action="store_true", + help="Set flag to only convert the VQVAE model. checkpoint_path and original_config_file must point to checkpoints and configs only for the VQVAE, not the whole diffusion model. Note that the entirety of the diffusers VQDiffusionPipeline will be saved." + ) + + parser.add_argument( + "--checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the checkpoint to convert." + ) + + parser.add_argument( + "--original_config_file", + default=None, + type=str, + required=True, + help="The YAML config file corresponding to the original architecture.", + ) + + parser.add_argument( + "--dump_path", + default=None, + type=str, + required=True, + help="Path to the output model." + ) + + args = parser.parse_args() + + # The yaml file contains annotations that certain values should + # loaded as tuples. By default, OmegaConf will panic when reading + # these. Instead, we can manually read the yaml with the FullLoader and then + # construct the OmegaConf object. + with open(args.original_config_file) as f: + original_config = yaml.load(f, FullLoader) + + original_config = OmegaConf.create(original_config).model + + checkpoint = torch.load(args.checkpoint_path)['model'] + + if args.only_vqvae: + vqvae_model = vqvae_model_from_original_config(original_config) + vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, checkpoint) + vqvae_model.load_state_dict(vqvae_diffusers_checkpoint) + else: + print("This script currently only supports porting the VQVAE to diffusers. 
+        sys.exit(1)
+
+    pipe = VQDiffusionPipeline(vqvae=vqvae_model)
+
+    pipe.save_pretrained(args.dump_path)
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index f963310f12eb..dd86dc65880a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -91,6 +91,64 @@ def forward(self, hidden_states):
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
         return hidden_states
 
+
+class ConvAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        num_head_channels: Optional[int] = None,
+        num_groups: int = 32,
+        rescale_output_factor: float = 1.0,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+
+        self.channels = channels
+
+        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
+        self.num_head_size = num_head_channels
+        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+        self.query = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
+        self.key = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
+        self.value = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
+
+        self.proj_attn = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
+
+        self.rescale_output_factor = rescale_output_factor
+
+    def forward(self, hidden_states):
+        residual = hidden_states
+
+        hidden_states = self.group_norm(hidden_states)
+
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        batch, channel, height, width = query_proj.shape
+
+        query_states = query_proj.reshape(batch, channel, height * width).permute(0, 2, 1)
+        key_states = key_proj.reshape(batch, channel, height * width)
+        value_states = value_proj.reshape(batch, channel, height * width)
+
+        scale = 1 / math.sqrt(self.channels / self.num_heads)
+        attention_scores = torch.bmm(query_states, key_states)
+        attention_scores = attention_scores * scale
+
+        attention_probs = torch.softmax(attention_scores.float(), dim=2).type(attention_scores.dtype)
+        attention_probs = attention_probs.permute(0, 2, 1)
+
+        hidden_states = torch.bmm(value_states, attention_probs)
+        hidden_states = hidden_states.reshape(batch, channel, height, width)
+        hidden_states = self.proj_attn(hidden_states)
+
+        # residual connection and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+
+        return hidden_states
+
+
 class SpatialTransformer(nn.Module):
     """
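`ConvAttentionBlock` mirrors the existing `AttentionBlock`, but keeps the checkpoint's kernel-size-1 convolution projections instead of linear layers so the VQ-diffusion attention weights load without reshaping; since a 1x1 convolution is a per-pixel linear map, the block still attends over the flattened `height * width` positions. A small smoke test, assuming a branch with this patch is installed (the class and its import path are new here):

```py
import torch

from diffusers.models.attention import ConvAttentionBlock

# num_head_channels=None falls back to single-head attention, as in the VQVAE checkpoints.
block = ConvAttentionBlock(channels=64, num_head_channels=None, num_groups=32)

sample = torch.randn(2, 64, 16, 16)
output = block(sample)

# Spatial self-attention over the flattened 16*16 positions preserves the feature map shape.
assert output.shape == sample.shape
```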
diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py
index f42389b98562..e8dab4f1acba 100644
--- a/src/diffusers/models/unet_blocks.py
+++ b/src/diffusers/models/unet_blocks.py
@@ -17,7 +17,7 @@
 import torch
 from torch import nn
 
-from .attention import AttentionBlock, SpatialTransformer
+from .attention import AttentionBlock, SpatialTransformer, ConvAttentionBlock
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
 
 
@@ -34,6 +34,7 @@ def get_down_block(
     resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
+    conv_attention_block=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -111,6 +112,20 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
         )
+    elif down_block_type == "AttnDownEncoderBlock2D":
+        return AttnDownEncoderBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            downsample_padding=downsample_padding,
+            attn_num_head_channels=attn_num_head_channels,
+            conv_attention_block=conv_attention_block,
+        )
     raise ValueError(f"{down_block_type} does not exist.")
 
 
@@ -126,6 +141,7 @@ def get_up_block(
     attn_num_head_channels,
     resnet_groups=None,
     cross_attention_dim=None,
+    conv_attention_block=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -202,6 +218,18 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
         )
+    elif up_block_type == "AttnUpDecoderBlock2D":
+        return AttnUpDecoderBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            attn_num_head_channels=attn_num_head_channels,
+            conv_attention_block=conv_attention_block,
+        )
     raise ValueError(f"{up_block_type} does not exist.")
 
 
@@ -220,6 +248,7 @@ def __init__(
         attn_num_head_channels=1,
         attention_type="default",
         output_scale_factor=1.0,
+        conv_attention_block: bool = False,
         **kwargs,
     ):
         super().__init__()
@@ -244,9 +273,11 @@ def __init__(
         ]
         attentions = []
 
+        attention_block_class = ConvAttentionBlock if conv_attention_block else AttentionBlock
+
         for _ in range(num_layers):
             attentions.append(
-                AttentionBlock(
+                attention_block_class(
                     in_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -730,11 +761,14 @@ def __init__(
         output_scale_factor=1.0,
         add_downsample=True,
         downsample_padding=1,
+        conv_attention_block: bool = False,
     ):
         super().__init__()
         resnets = []
         attentions = []
 
+        attention_block_class = ConvAttentionBlock if conv_attention_block else AttentionBlock
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -752,7 +786,7 @@ def __init__(
                 )
             )
             attentions.append(
-                AttentionBlock(
+                attention_block_class(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -1299,11 +1333,14 @@ def __init__(
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         add_upsample=True,
+        conv_attention_block: bool = False,
     ):
         super().__init__()
         resnets = []
         attentions = []
 
+        attention_block_class = ConvAttentionBlock if conv_attention_block else AttentionBlock
+
        for i in range(num_layers):
             input_channels = in_channels if i == 0 else out_channels
 
@@ -1322,7 +1359,7 @@ def __init__(
                 )
             )
             attentions.append(
-                AttentionBlock(
+                attention_block_class(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
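With the flag threaded through `get_down_block`/`get_up_block` and the VAE modules below, a VQ-diffusion style VQVAE can be instantiated directly from `VQModel`. A sketch with hypothetical hyperparameters; the real ITHQ values are derived from the original YAML config by `vqvae_model_from_original_config` in the conversion script above:

```py
from diffusers import VQModel

vqvae = VQModel(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "AttnDownEncoderBlock2D"),
    up_block_types=("AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
    block_out_channels=(128, 256, 512),
    layers_per_block=2,
    latent_channels=256,        # z_channels of the original encoder/decoder
    num_vq_embeddings=1024,     # codebook size (n_embed in the original config)
    norm_num_groups=32,
    e_dim=128,                  # codebook vector dimension; quant_conv maps 256 -> 128 channels
    conv_attention_block=True,  # use ConvAttentionBlock instead of AttentionBlock
)
```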
diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index fe89b41c074e..35491854774d 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -62,6 +62,7 @@ def __init__(
         norm_num_groups=32,
         act_fn="silu",
         double_z=True,
+        conv_attention_block: bool = False,
     ):
         super().__init__()
         self.layers_per_block = layers_per_block
@@ -90,6 +91,7 @@ def __init__(
                 resnet_groups=norm_num_groups,
                 attn_num_head_channels=None,
                 temb_channels=None,
+                conv_attention_block=conv_attention_block,
             )
             self.down_blocks.append(down_block)
 
@@ -103,6 +105,7 @@ def __init__(
             attn_num_head_channels=None,
             resnet_groups=norm_num_groups,
             temb_channels=None,
+            conv_attention_block=conv_attention_block,
         )
 
         # out
@@ -141,6 +144,7 @@ def __init__(
         layers_per_block=2,
         norm_num_groups=32,
         act_fn="silu",
+        conv_attention_block: bool = False,
     ):
         super().__init__()
         self.layers_per_block = layers_per_block
@@ -160,6 +164,7 @@ def __init__(
             attn_num_head_channels=None,
             resnet_groups=norm_num_groups,
             temb_channels=None,
+            conv_attention_block=conv_attention_block,
         )
 
         # up
@@ -183,6 +188,7 @@ def __init__(
                 resnet_groups=norm_num_groups,
                 attn_num_head_channels=None,
                 temb_channels=None,
+                conv_attention_block=conv_attention_block,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -408,6 +414,8 @@ def __init__(
         sample_size: int = 32,
         num_vq_embeddings: int = 256,
         norm_num_groups: int = 32,
+        e_dim: Optional[int] = None,
+        conv_attention_block: bool = False,
     ):
         super().__init__()
 
@@ -421,13 +429,16 @@ def __init__(
             act_fn=act_fn,
             norm_num_groups=norm_num_groups,
             double_z=False,
+            conv_attention_block=conv_attention_block,
         )
 
-        self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+        e_dim = e_dim if e_dim is not None else latent_channels
+
+        self.quant_conv = torch.nn.Conv2d(latent_channels, e_dim, 1)
         self.quantize = VectorQuantizer(
-            num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
+            num_vq_embeddings, e_dim, beta=0.25, remap=None, sane_index_shape=False
         )
-        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+        self.post_quant_conv = torch.nn.Conv2d(e_dim, latent_channels, 1)
 
         # pass init params to Decoder
         self.decoder = Decoder(
@@ -438,6 +449,7 @@ def __init__(
             layers_per_block=layers_per_block,
             act_fn=act_fn,
             norm_num_groups=norm_num_groups,
+            conv_attention_block=conv_attention_block,
         )
 
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 8e3c8592a258..4c6707cace73 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -14,6 +14,7 @@
         StableDiffusionInpaintPipeline,
         StableDiffusionPipeline,
     )
+    from .vq_diffusion import VQDiffusionPipeline
 
 if is_transformers_available() and is_onnx_available():
     from .stable_diffusion import StableDiffusionOnnxPipeline
diff --git a/src/diffusers/pipelines/vq_diffusion/__init__.py b/src/diffusers/pipelines/vq_diffusion/__init__.py
new file mode 100644
index 000000000000..edf6f570f5bf
--- /dev/null
+++ b/src/diffusers/pipelines/vq_diffusion/__init__.py
@@ -0,0 +1 @@
+from .pipeline_vq_diffusion import VQDiffusionPipeline
diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
new file mode 100644
index 000000000000..7ff36d4e2ef4
--- /dev/null
+++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
@@ -0,0 +1,44 @@
+import numpy as np
+import PIL
+import torch
+
+from diffusers import VQModel
+
+from ...pipeline_utils import DiffusionPipeline
+
+
+def preprocess_image(image):
+    w, h = image.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # round the size down to an integer multiple of 32
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
+
+# This class is a placeholder and does not have the full VQ-diffusion pipeline built out yet.
+#
+# NOTE: In VQ-diffusion, the VQVAE trained on the ITHQ dataset uses an EMA variant of the vector quantizer in
+# diffusers. The EMA variant uses exponential moving averages to update the codebook during training but acts
+# the same as the usual vector quantizer during inference, so this pipeline uses the non-EMA vector quantizer
+# at inference time. If diffusers is to support training, the EMA vector quantizer could be implemented. For
+# more information on EMA vector quantizers, see https://arxiv.org/abs/1711.00937.
+class VQDiffusionPipeline(DiffusionPipeline):
+    vqvae: VQModel
+
+    def __init__(self, vqvae: VQModel):
+        super().__init__()
+        self.register_modules(vqvae=vqvae)
+
+    @torch.no_grad()
+    def encode(self, image):
+        image = preprocess_image(image)
+        encoded = self.vqvae.encode(image)
+        return encoded.latents
+
+    @torch.no_grad()
+    def decode(self, encoded_image):
+        image = self.vqvae.decode(encoded_image).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        image = self.numpy_to_pil(image)
+        return image
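The pair of methods above fixes the pipeline's value-range contract: `preprocess_image` turns a PIL image into a batched NCHW tensor in [-1, 1], and `decode` maps the VQVAE output back to [0, 1] before converting to PIL. A small check of the preprocessing half, assuming a branch with this patch is installed so the new module is importable:

```py
import numpy as np
import PIL.Image

from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import preprocess_image

# A solid white 64x64 image: 64 is already a multiple of 32, so no cropping happens.
white = PIL.Image.fromarray(np.full((64, 64, 3), 255, dtype=np.uint8))
tensor = preprocess_image(white)

assert tuple(tensor.shape) == (1, 3, 64, 64)  # batched NCHW
assert float(tensor.min()) == float(tensor.max()) == 1.0  # pixel value 255 maps to 1.0 in [-1, 1]
```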