From 2dcf64b72a44dc748610c3eb51534b325df3bca9 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Sun, 18 Dec 2022 15:15:30 -0800 Subject: [PATCH] kakaobrain unCLIP (#1428) * [wip] attention block updates * [wip] unCLIP unet decoder and super res * [wip] unCLIP prior transformer * [wip] scheduler changes * [wip] text proj utility class * [wip] UnCLIPPipeline * [wip] kakaobrain unCLIP convert script * [unCLIP pipeline] fixes re: @patrickvonplaten remove callbacks move denoising loops into call function * UNCLIPScheduler re: @patrickvonplaten Revert changes to DDPMScheduler. Make UNCLIPScheduler, a modified DDPM scheduler with changes to support karlo * mask -> attention_mask re: @patrickvonplaten * [DDPMScheduler] remove leftover change * [docs] PriorTransformer * [docs] UNet2DConditionModel and UNet2DModel * [nit] UNCLIPScheduler -> UnCLIPScheduler matches existing unclip naming better * [docs] SchedulingUnCLIP * [docs] UnCLIPTextProjModel * refactor * finish licenses * rename all to attention_mask and prep in models * more renaming * don't expose unused configs * final renaming fixes * remove x attn mask when not necessary * configure kakao script to use new class embedding config * fix copies * [tests] UnCLIPScheduler * finish x attn * finish * remove more * rename condition blocks * clean more * Apply suggestions from code review * up * fix * [tests] UnCLIPPipelineFastTests * remove unused imports * [tests] UnCLIPPipelineIntegrationTests * correct * make style Co-authored-by: Patrick von Platen --- ...convert_kakao_brain_unclip_to_diffusers.py | 1159 +++++++++++++++++ src/diffusers/__init__.py | 12 +- src/diffusers/models/__init__.py | 1 + src/diffusers/models/attention.py | 124 +- src/diffusers/models/prior_transformer.py | 194 +++ src/diffusers/models/resnet.py | 16 +- src/diffusers/models/unet_2d.py | 11 +- src/diffusers/models/unet_2d_blocks.py | 662 +++++++++- src/diffusers/models/unet_2d_condition.py | 81 +- src/diffusers/pipelines/__init__.py | 1 + src/diffusers/pipelines/unclip/__init__.py | 6 + .../pipelines/unclip/pipeline_unclip.py | 370 ++++++ src/diffusers/pipelines/unclip/text_proj.py | 87 ++ .../versatile_diffusion/modeling_text_unet.py | 218 +++- src/diffusers/schedulers/__init__.py | 1 + src/diffusers/schedulers/scheduling_unclip.py | 309 +++++ src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 15 + tests/pipelines/unclip/__init__.py | 0 tests/pipelines/unclip/test_unclip.py | 282 ++++ tests/test_scheduler.py | 133 ++ 21 files changed, 3594 insertions(+), 118 deletions(-) create mode 100644 scripts/convert_kakao_brain_unclip_to_diffusers.py create mode 100644 src/diffusers/models/prior_transformer.py create mode 100644 src/diffusers/pipelines/unclip/__init__.py create mode 100644 src/diffusers/pipelines/unclip/pipeline_unclip.py create mode 100644 src/diffusers/pipelines/unclip/text_proj.py create mode 100644 src/diffusers/schedulers/scheduling_unclip.py create mode 100644 tests/pipelines/unclip/__init__.py create mode 100644 tests/pipelines/unclip/test_unclip.py diff --git a/scripts/convert_kakao_brain_unclip_to_diffusers.py b/scripts/convert_kakao_brain_unclip_to_diffusers.py new file mode 100644 index 0000000000..ebb98eb66a --- /dev/null +++ b/scripts/convert_kakao_brain_unclip_to_diffusers.py @@ -0,0 +1,1159 @@ +import argparse +import tempfile + +import torch + +from accelerate import load_checkpoint_and_dispatch +from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel +from diffusers.models.prior_transformer import PriorTransformer +from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel +from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + + +""" +Example - From the diffusers root directory: + +Download weights: +```sh +$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt +$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt +$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt +$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th +``` + +Convert the model: +```sh +$ python scripts/convert_kakao_brain_unclip_to_diffusers.py \ + --decoder_checkpoint_path ./decoder-ckpt-step\=01000000-of-01000000.ckpt \ + --super_res_unet_checkpoint_path ./improved-sr-ckpt-step\=1.2M.ckpt \ + --prior_checkpoint_path ./prior-ckpt-step\=01000000-of-01000000.ckpt \ + --clip_stat_path ./ViT-L-14_stats.th \ + --dump_path +``` +""" + + +# prior + +PRIOR_ORIGINAL_PREFIX = "model" + +# Uses default arguments +PRIOR_CONFIG = {} + + +def prior_model_from_original_config(): + model = PriorTransformer(**PRIOR_CONFIG) + + return model + + +def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint): + diffusers_checkpoint = {} + + # .time_embed.0 -> .time_embedding.linear_1 + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.weight"], + "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.bias"], + } + ) + + # .clip_img_proj -> .proj_in + diffusers_checkpoint.update( + { + "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.weight"], + "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.bias"], + } + ) + + # .text_emb_proj -> .embedding_proj + diffusers_checkpoint.update( + { + "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.weight"], + "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.bias"], + } + ) + + # .text_enc_proj -> .encoder_hidden_states_proj + diffusers_checkpoint.update( + { + "encoder_hidden_states_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.weight"], + "encoder_hidden_states_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.bias"], + } + ) + + # .positional_embedding -> .positional_embedding + diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.positional_embedding"]}) + + # .prd_emb -> .prd_embedding + diffusers_checkpoint.update({"prd_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.prd_emb"]}) + + # .time_embed.2 -> .time_embedding.linear_2 + diffusers_checkpoint.update( + { + "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.weight"], + "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.bias"], + } + ) + + # .resblocks. -> .transformer_blocks. + for idx in range(len(model.transformer_blocks)): + diffusers_transformer_prefix = f"transformer_blocks.{idx}" + original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{idx}" + + # .attn -> .attn1 + diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1" + original_attention_prefix = f"{original_transformer_prefix}.attn" + diffusers_checkpoint.update( + prior_attention_to_diffusers( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + original_attention_prefix=original_attention_prefix, + attention_head_dim=model.attention_head_dim, + ) + ) + + # .mlp -> .ff + diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff" + original_ff_prefix = f"{original_transformer_prefix}.mlp" + diffusers_checkpoint.update( + prior_ff_to_diffusers( + checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix + ) + ) + + # .ln_1 -> .norm1 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[ + f"{original_transformer_prefix}.ln_1.weight" + ], + f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"], + } + ) + + # .ln_2 -> .norm3 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[ + f"{original_transformer_prefix}.ln_2.weight" + ], + f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"], + } + ) + + # .final_ln -> .norm_out + diffusers_checkpoint.update( + { + "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.weight"], + "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.bias"], + } + ) + + # .out_proj -> .proj_to_clip_embeddings + diffusers_checkpoint.update( + { + "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.weight"], + "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.bias"], + } + ) + + # clip stats + clip_mean, clip_std = clip_stats_checkpoint + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + diffusers_checkpoint.update({"clip_mean": clip_mean, "clip_std": clip_std}) + + return diffusers_checkpoint + + +def prior_attention_to_diffusers( + checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim +): + diffusers_checkpoint = {} + + # .c_qkv -> .{to_q, to_k, to_v} + [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions( + weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"], + bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"], + split=3, + chunk_size=attention_head_dim, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_q.weight": q_weight, + f"{diffusers_attention_prefix}.to_q.bias": q_bias, + f"{diffusers_attention_prefix}.to_k.weight": k_weight, + f"{diffusers_attention_prefix}.to_k.bias": k_bias, + f"{diffusers_attention_prefix}.to_v.weight": v_weight, + f"{diffusers_attention_prefix}.to_v.bias": v_bias, + } + ) + + # .c_proj -> .to_out.0 + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"], + } + ) + + return diffusers_checkpoint + + +def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix): + diffusers_checkpoint = { + # .c_fc -> .net.0.proj + f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"], + f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"], + # .c_proj -> .net.2 + f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"], + f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"], + } + + return diffusers_checkpoint + + +# done prior + + +# decoder + +DECODER_ORIGINAL_PREFIX = "model" + +# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can +# update then. +DECODER_CONFIG = { + "sample_size": 64, + "layers_per_block": 3, + "down_block_types": ( + "ResnetDownsampleBlock2D", + "SimpleCrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + ), + "up_block_types": ( + "SimpleCrossAttnUpBlock2D", + "SimpleCrossAttnUpBlock2D", + "SimpleCrossAttnUpBlock2D", + "ResnetUpsampleBlock2D", + ), + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "block_out_channels": (320, 640, 960, 1280), + "in_channels": 3, + "out_channels": 6, + "cross_attention_dim": 1536, + "class_embed_type": "identity", + "attention_head_dim": 64, + "resnet_time_scale_shift": "scale_shift", + "class_embed_type": "identity", +} + + +def decoder_model_from_original_config(): + model = UNet2DConditionModel(**DECODER_CONFIG) + + return model + + +def decoder_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + original_unet_prefix = DECODER_ORIGINAL_PREFIX + num_head_channels = DECODER_CONFIG["attention_head_dim"] + + diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix)) + diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix)) + + # .input_blocks -> .down_blocks + + original_down_block_idx = 1 + + for diffusers_down_block_idx in range(len(model.down_blocks)): + checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_down_block_idx=diffusers_down_block_idx, + original_down_block_idx=original_down_block_idx, + original_unet_prefix=original_unet_prefix, + num_head_channels=num_head_channels, + ) + + original_down_block_idx += num_original_down_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .input_blocks -> .down_blocks + + diffusers_checkpoint.update( + unet_midblock_to_diffusers_checkpoint( + model, + checkpoint, + original_unet_prefix=original_unet_prefix, + num_head_channels=num_head_channels, + ) + ) + + # .output_blocks -> .up_blocks + + original_up_block_idx = 0 + + for diffusers_up_block_idx in range(len(model.up_blocks)): + checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_up_block_idx=diffusers_up_block_idx, + original_up_block_idx=original_up_block_idx, + original_unet_prefix=original_unet_prefix, + num_head_channels=num_head_channels, + ) + + original_up_block_idx += num_original_up_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .output_blocks -> .up_blocks + + diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix)) + diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix)) + + return diffusers_checkpoint + + +# done decoder + +# text proj + + +def text_proj_from_original_config(): + # From the conditional unet constructor where the dimension of the projected time embeddings is + # constructed + time_embed_dim = DECODER_CONFIG["block_out_channels"][0] * 4 + + cross_attention_dim = DECODER_CONFIG["cross_attention_dim"] + + model = UnCLIPTextProjModel(time_embed_dim=time_embed_dim, cross_attention_dim=cross_attention_dim) + + return model + + +# Note that the input checkpoint is the original decoder checkpoint +def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint): + diffusers_checkpoint = { + # .text_seq_proj.0 -> .encoder_hidden_states_proj + "encoder_hidden_states_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.weight"], + "encoder_hidden_states_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.bias"], + # .text_seq_proj.1 -> .text_encoder_hidden_states_norm + "text_encoder_hidden_states_norm.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.weight"], + "text_encoder_hidden_states_norm.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.bias"], + # .clip_tok_proj -> .clip_extra_context_tokens_proj + "clip_extra_context_tokens_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.weight"], + "clip_extra_context_tokens_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.bias"], + # .text_feat_proj -> .embedding_proj + "embedding_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.weight"], + "embedding_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.bias"], + # .cf_param -> .learned_classifier_free_guidance_embeddings + "learned_classifier_free_guidance_embeddings": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.cf_param"], + # .clip_emb -> .clip_image_embeddings_project_to_time_embeddings + "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint[ + f"{DECODER_ORIGINAL_PREFIX}.clip_emb.weight" + ], + "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint[ + f"{DECODER_ORIGINAL_PREFIX}.clip_emb.bias" + ], + } + + return diffusers_checkpoint + + +# done text proj + +# super res unet first steps + +SUPER_RES_UNET_FIRST_STEPS_PREFIX = "model_first_steps" + +SUPER_RES_UNET_FIRST_STEPS_CONFIG = { + "sample_size": 256, + "layers_per_block": 3, + "down_block_types": ( + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ), + "up_block_types": ( + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ), + "block_out_channels": (320, 640, 960, 1280), + "in_channels": 6, + "out_channels": 3, + "add_attention": False, +} + + +def super_res_unet_first_steps_model_from_original_config(): + model = UNet2DModel(**SUPER_RES_UNET_FIRST_STEPS_CONFIG) + + return model + + +def super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + original_unet_prefix = SUPER_RES_UNET_FIRST_STEPS_PREFIX + + diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix)) + diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix)) + + # .input_blocks -> .down_blocks + + original_down_block_idx = 1 + + for diffusers_down_block_idx in range(len(model.down_blocks)): + checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_down_block_idx=diffusers_down_block_idx, + original_down_block_idx=original_down_block_idx, + original_unet_prefix=original_unet_prefix, + num_head_channels=None, + ) + + original_down_block_idx += num_original_down_blocks + + diffusers_checkpoint.update(checkpoint_update) + + diffusers_checkpoint.update( + unet_midblock_to_diffusers_checkpoint( + model, + checkpoint, + original_unet_prefix=original_unet_prefix, + num_head_channels=None, + ) + ) + + # .output_blocks -> .up_blocks + + original_up_block_idx = 0 + + for diffusers_up_block_idx in range(len(model.up_blocks)): + checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_up_block_idx=diffusers_up_block_idx, + original_up_block_idx=original_up_block_idx, + original_unet_prefix=original_unet_prefix, + num_head_channels=None, + ) + + original_up_block_idx += num_original_up_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .output_blocks -> .up_blocks + + diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix)) + diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix)) + + return diffusers_checkpoint + + +# done super res unet first steps + +# super res unet last step + +SUPER_RES_UNET_LAST_STEP_PREFIX = "model_last_step" + +SUPER_RES_UNET_LAST_STEP_CONFIG = { + "sample_size": 256, + "layers_per_block": 3, + "down_block_types": ( + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ), + "up_block_types": ( + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ), + "block_out_channels": (320, 640, 960, 1280), + "in_channels": 6, + "out_channels": 3, + "add_attention": False, +} + + +def super_res_unet_last_step_model_from_original_config(): + model = UNet2DModel(**SUPER_RES_UNET_LAST_STEP_CONFIG) + + return model + + +def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + original_unet_prefix = SUPER_RES_UNET_LAST_STEP_PREFIX + + diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix)) + diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix)) + + # .input_blocks -> .down_blocks + + original_down_block_idx = 1 + + for diffusers_down_block_idx in range(len(model.down_blocks)): + checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_down_block_idx=diffusers_down_block_idx, + original_down_block_idx=original_down_block_idx, + original_unet_prefix=original_unet_prefix, + num_head_channels=None, + ) + + original_down_block_idx += num_original_down_blocks + + diffusers_checkpoint.update(checkpoint_update) + + diffusers_checkpoint.update( + unet_midblock_to_diffusers_checkpoint( + model, + checkpoint, + original_unet_prefix=original_unet_prefix, + num_head_channels=None, + ) + ) + + # .output_blocks -> .up_blocks + + original_up_block_idx = 0 + + for diffusers_up_block_idx in range(len(model.up_blocks)): + checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_up_block_idx=diffusers_up_block_idx, + original_up_block_idx=original_up_block_idx, + original_unet_prefix=original_unet_prefix, + num_head_channels=None, + ) + + original_up_block_idx += num_original_up_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .output_blocks -> .up_blocks + + diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix)) + diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix)) + + return diffusers_checkpoint + + +# done super res unet last step + + +# unet utils + +# .time_embed -> .time_embedding +def unet_time_embeddings(checkpoint, original_unet_prefix): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint[f"{original_unet_prefix}.time_embed.0.weight"], + "time_embedding.linear_1.bias": checkpoint[f"{original_unet_prefix}.time_embed.0.bias"], + "time_embedding.linear_2.weight": checkpoint[f"{original_unet_prefix}.time_embed.2.weight"], + "time_embedding.linear_2.bias": checkpoint[f"{original_unet_prefix}.time_embed.2.bias"], + } + ) + + return diffusers_checkpoint + + +# .input_blocks.0 -> .conv_in +def unet_conv_in(checkpoint, original_unet_prefix): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "conv_in.weight": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.weight"], + "conv_in.bias": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.bias"], + } + ) + + return diffusers_checkpoint + + +# .out.0 -> .conv_norm_out +def unet_conv_norm_out(checkpoint, original_unet_prefix): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "conv_norm_out.weight": checkpoint[f"{original_unet_prefix}.out.0.weight"], + "conv_norm_out.bias": checkpoint[f"{original_unet_prefix}.out.0.bias"], + } + ) + + return diffusers_checkpoint + + +# .out.2 -> .conv_out +def unet_conv_out(checkpoint, original_unet_prefix): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "conv_out.weight": checkpoint[f"{original_unet_prefix}.out.2.weight"], + "conv_out.bias": checkpoint[f"{original_unet_prefix}.out.2.bias"], + } + ) + + return diffusers_checkpoint + + +# .input_blocks -> .down_blocks +def unet_downblock_to_diffusers_checkpoint( + model, checkpoint, *, diffusers_down_block_idx, original_down_block_idx, original_unet_prefix, num_head_channels +): + diffusers_checkpoint = {} + + diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.resnets" + original_down_block_prefix = f"{original_unet_prefix}.input_blocks" + + down_block = model.down_blocks[diffusers_down_block_idx] + + num_resnets = len(down_block.resnets) + + if down_block.downsamplers is None: + downsampler = False + else: + assert len(down_block.downsamplers) == 1 + downsampler = True + # The downsample block is also a resnet + num_resnets += 1 + + for resnet_idx_inc in range(num_resnets): + full_resnet_prefix = f"{original_down_block_prefix}.{original_down_block_idx + resnet_idx_inc}.0" + + if downsampler and resnet_idx_inc == num_resnets - 1: + # this is a downsample block + full_diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.downsamplers.0" + else: + # this is a regular resnet block + full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}" + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix + ) + ) + + if hasattr(down_block, "attentions"): + num_attentions = len(down_block.attentions) + diffusers_attention_prefix = f"down_blocks.{diffusers_down_block_idx}.attentions" + + for attention_idx_inc in range(num_attentions): + full_attention_prefix = f"{original_down_block_prefix}.{original_down_block_idx + attention_idx_inc}.1" + full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}" + + diffusers_checkpoint.update( + attention_to_diffusers_checkpoint( + checkpoint, + attention_prefix=full_attention_prefix, + diffusers_attention_prefix=full_diffusers_attention_prefix, + num_head_channels=num_head_channels, + ) + ) + + num_original_down_blocks = num_resnets + + return diffusers_checkpoint, num_original_down_blocks + + +# .middle_block -> .mid_block +def unet_midblock_to_diffusers_checkpoint(model, checkpoint, *, original_unet_prefix, num_head_channels): + diffusers_checkpoint = {} + + # block 0 + + original_block_idx = 0 + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, + diffusers_resnet_prefix="mid_block.resnets.0", + resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}", + ) + ) + + original_block_idx += 1 + + # optional block 1 + + if hasattr(model.mid_block, "attentions") and model.mid_block.attentions[0] is not None: + diffusers_checkpoint.update( + attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix="mid_block.attentions.0", + attention_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}", + num_head_channels=num_head_channels, + ) + ) + original_block_idx += 1 + + # block 1 or block 2 + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, + diffusers_resnet_prefix="mid_block.resnets.1", + resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}", + ) + ) + + return diffusers_checkpoint + + +# .output_blocks -> .up_blocks +def unet_upblock_to_diffusers_checkpoint( + model, checkpoint, *, diffusers_up_block_idx, original_up_block_idx, original_unet_prefix, num_head_channels +): + diffusers_checkpoint = {} + + diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.resnets" + original_up_block_prefix = f"{original_unet_prefix}.output_blocks" + + up_block = model.up_blocks[diffusers_up_block_idx] + + num_resnets = len(up_block.resnets) + + if up_block.upsamplers is None: + upsampler = False + else: + assert len(up_block.upsamplers) == 1 + upsampler = True + # The upsample block is also a resnet + num_resnets += 1 + + has_attentions = hasattr(up_block, "attentions") + + for resnet_idx_inc in range(num_resnets): + if upsampler and resnet_idx_inc == num_resnets - 1: + # this is an upsample block + if has_attentions: + # There is a middle attention block that we skip + original_resnet_block_idx = 2 + else: + original_resnet_block_idx = 1 + + # we add the `minus 1` because the last two resnets are stuck together in the same output block + full_resnet_prefix = ( + f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc - 1}.{original_resnet_block_idx}" + ) + + full_diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.upsamplers.0" + else: + # this is a regular resnet block + full_resnet_prefix = f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc}.0" + full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}" + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix + ) + ) + + if has_attentions: + num_attentions = len(up_block.attentions) + diffusers_attention_prefix = f"up_blocks.{diffusers_up_block_idx}.attentions" + + for attention_idx_inc in range(num_attentions): + full_attention_prefix = f"{original_up_block_prefix}.{original_up_block_idx + attention_idx_inc}.1" + full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}" + + diffusers_checkpoint.update( + attention_to_diffusers_checkpoint( + checkpoint, + attention_prefix=full_attention_prefix, + diffusers_attention_prefix=full_diffusers_attention_prefix, + num_head_channels=num_head_channels, + ) + ) + + num_original_down_blocks = num_resnets - 1 if upsampler else num_resnets + + return diffusers_checkpoint, num_original_down_blocks + + +def resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + diffusers_checkpoint = { + f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.in_layers.0.weight"], + f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.in_layers.0.bias"], + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.in_layers.2.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.in_layers.2.bias"], + f"{diffusers_resnet_prefix}.time_emb_proj.weight": checkpoint[f"{resnet_prefix}.emb_layers.1.weight"], + f"{diffusers_resnet_prefix}.time_emb_proj.bias": checkpoint[f"{resnet_prefix}.emb_layers.1.bias"], + f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.out_layers.0.weight"], + f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.out_layers.0.bias"], + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.out_layers.3.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.out_layers.3.bias"], + } + + skip_connection_prefix = f"{resnet_prefix}.skip_connection" + + if f"{skip_connection_prefix}.weight" in checkpoint: + diffusers_checkpoint.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{skip_connection_prefix}.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{skip_connection_prefix}.bias"], + } + ) + + return diffusers_checkpoint + + +def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix, num_head_channels): + diffusers_checkpoint = {} + + # .norm -> .group_norm + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], + f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + } + ) + + # .qkv -> .{query, key, value} + [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions( + weight=checkpoint[f"{attention_prefix}.qkv.weight"][:, :, 0], + bias=checkpoint[f"{attention_prefix}.qkv.bias"], + split=3, + chunk_size=num_head_channels, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_q.weight": q_weight, + f"{diffusers_attention_prefix}.to_q.bias": q_bias, + f"{diffusers_attention_prefix}.to_k.weight": k_weight, + f"{diffusers_attention_prefix}.to_k.bias": k_bias, + f"{diffusers_attention_prefix}.to_v.weight": v_weight, + f"{diffusers_attention_prefix}.to_v.bias": v_bias, + } + ) + + # .encoder_kv -> .{context_key, context_value} + [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions( + weight=checkpoint[f"{attention_prefix}.encoder_kv.weight"][:, :, 0], + bias=checkpoint[f"{attention_prefix}.encoder_kv.bias"], + split=2, + chunk_size=num_head_channels, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.add_k_proj.weight": encoder_k_weight, + f"{diffusers_attention_prefix}.add_k_proj.bias": encoder_k_bias, + f"{diffusers_attention_prefix}.add_v_proj.weight": encoder_v_weight, + f"{diffusers_attention_prefix}.add_v_proj.bias": encoder_v_bias, + } + ) + + # .proj_out (1d conv) -> .proj_attn (linear) + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + :, :, 0 + ], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + ) + + return diffusers_checkpoint + + +# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?) +def split_attentions(*, weight, bias, split, chunk_size): + weights = [None] * split + biases = [None] * split + + weights_biases_idx = 0 + + for starting_row_index in range(0, weight.shape[0], chunk_size): + row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size) + + weight_rows = weight[row_indices, :] + bias_rows = bias[row_indices] + + if weights[weights_biases_idx] is None: + assert weights[weights_biases_idx] is None + weights[weights_biases_idx] = weight_rows + biases[weights_biases_idx] = bias_rows + else: + assert weights[weights_biases_idx] is not None + weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows]) + biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows]) + + weights_biases_idx = (weights_biases_idx + 1) % split + + return weights, biases + + +# done unet utils + + +# Driver functions + + +def text_encoder(): + print("loading CLIP text encoder") + + clip_name = "openai/clip-vit-large-patch14" + + # sets pad_value to 0 + pad_token = "!" + + tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto") + + assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0 + + text_encoder_model = CLIPTextModelWithProjection.from_pretrained( + clip_name, + # `CLIPTextModel` does not support device_map="auto" + # device_map="auto" + ) + + print("done loading CLIP text encoder") + + return text_encoder_model, tokenizer_model + + +def prior(*, args, checkpoint_map_location): + print("loading prior") + + prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location) + prior_checkpoint = prior_checkpoint["state_dict"] + + clip_stats_checkpoint = torch.load(args.clip_stat_path, map_location=checkpoint_map_location) + + prior_model = prior_model_from_original_config() + + prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint( + prior_model, prior_checkpoint, clip_stats_checkpoint + ) + + del prior_checkpoint + del clip_stats_checkpoint + + load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True) + + print("done loading prior") + + return prior_model + + +def decoder(*, args, checkpoint_map_location): + print("loading decoder") + + decoder_checkpoint = torch.load(args.decoder_checkpoint_path, map_location=checkpoint_map_location) + decoder_checkpoint = decoder_checkpoint["state_dict"] + + decoder_model = decoder_model_from_original_config() + + decoder_diffusers_checkpoint = decoder_original_checkpoint_to_diffusers_checkpoint( + decoder_model, decoder_checkpoint + ) + + # text proj interlude + + # The original decoder implementation includes a set of parameters that are used + # for creating the `encoder_hidden_states` which are what the U-net is conditioned + # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull + # the parameters into the UnCLIPTextProjModel class + text_proj_model = text_proj_from_original_config() + + text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(decoder_checkpoint) + + load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True) + + # done text proj interlude + + del decoder_checkpoint + + load_checkpoint_to_model(decoder_diffusers_checkpoint, decoder_model, strict=True) + + print("done loading decoder") + + return decoder_model, text_proj_model + + +def super_res_unet(*, args, checkpoint_map_location): + print("loading super resolution unet") + + super_res_checkpoint = torch.load(args.super_res_unet_checkpoint_path, map_location=checkpoint_map_location) + super_res_checkpoint = super_res_checkpoint["state_dict"] + + # model_first_steps + + super_res_first_model = super_res_unet_first_steps_model_from_original_config() + + super_res_first_steps_checkpoint = super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint( + super_res_first_model, super_res_checkpoint + ) + + # model_last_step + super_res_last_model = super_res_unet_last_step_model_from_original_config() + + super_res_last_step_checkpoint = super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint( + super_res_last_model, super_res_checkpoint + ) + + del super_res_checkpoint + + load_checkpoint_to_model(super_res_first_steps_checkpoint, super_res_first_model, strict=True) + + load_checkpoint_to_model(super_res_last_step_checkpoint, super_res_last_model, strict=True) + + print("done loading super resolution unet") + + return super_res_first_model, super_res_last_model + + +def load_checkpoint_to_model(checkpoint, model, strict=False): + with tempfile.NamedTemporaryFile() as file: + torch.save(checkpoint, file.name) + del checkpoint + if strict: + model.load_state_dict(torch.load(file.name), strict=True) + else: + load_checkpoint_and_dispatch(model, file.name, device_map="auto") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--prior_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the prior checkpoint to convert.", + ) + + parser.add_argument( + "--decoder_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the decoder checkpoint to convert.", + ) + + parser.add_argument( + "--super_res_unet_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the super resolution checkpoint to convert.", + ) + + parser.add_argument( + "--clip_stat_path", default=None, type=str, required=True, help="Path to the clip stats checkpoint to convert." + ) + + parser.add_argument( + "--checkpoint_load_device", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading checkpoints.", + ) + + parser.add_argument( + "--debug", + default=None, + type=str, + required=False, + help="Only run a specific stage of the convert script. Used for debugging", + ) + + args = parser.parse_args() + + print(f"loading checkpoints to {args.checkpoint_load_device}") + + checkpoint_map_location = torch.device(args.checkpoint_load_device) + + if args.debug is not None: + print(f"debug: only executing {args.debug}") + + if args.debug is None: + text_encoder_model, tokenizer_model = text_encoder() + + prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location) + + decoder_model, text_proj_model = decoder(args=args, checkpoint_map_location=checkpoint_map_location) + + super_res_first_model, super_res_last_model = super_res_unet( + args=args, checkpoint_map_location=checkpoint_map_location + ) + + prior_scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="sample", + num_train_timesteps=1000, + clip_sample_range=5.0, + ) + + decoder_scheduler = UnCLIPScheduler( + variance_type="learned_range", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + super_res_scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + print(f"saving Kakao Brain unCLIP to {args.dump_path}") + + pipe = UnCLIPPipeline( + prior=prior_model, + decoder=decoder_model, + text_proj=text_proj_model, + tokenizer=tokenizer_model, + text_encoder=text_encoder_model, + super_res_first=super_res_first_model, + super_res_last=super_res_last_model, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + pipe.save_pretrained(args.dump_path) + + print("done writing Kakao Brain unCLIP") + elif args.debug == "text_encoder": + text_encoder_model, tokenizer_model = text_encoder() + elif args.debug == "prior": + prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location) + elif args.debug == "decoder": + decoder_model, text_proj_model = decoder(args=args, checkpoint_map_location=checkpoint_map_location) + elif args.debug == "super_res_unet": + super_res_first_model, super_res_last_model = super_res_unet( + args=args, checkpoint_map_location=checkpoint_map_location + ) + else: + raise ValueError(f"unknown debug value : {args.debug}") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a61024d9aa..68064af4a7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -25,7 +25,15 @@ from .utils.dummy_pt_objects import * # noqa F403 else: from .modeling_utils import ModelMixin - from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel + from .models import ( + AutoencoderKL, + PriorTransformer, + Transformer2DModel, + UNet1DModel, + UNet2DConditionModel, + UNet2DModel, + VQModel, + ) from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, @@ -63,6 +71,7 @@ RePaintScheduler, SchedulerMixin, ScoreSdeVeScheduler, + UnCLIPScheduler, VQDiffusionScheduler, ) from .training_utils import EMAModel @@ -96,6 +105,7 @@ StableDiffusionPipeline, StableDiffusionPipelineSafe, StableDiffusionUpscalePipeline, + UnCLIPPipeline, VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 5b101d1691..d0ee290fbe 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -17,6 +17,7 @@ if is_torch_available(): from .attention import Transformer2DModel + from .prior_transformer import PriorTransformer from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 99dd5d8d51..b29fe02c4a 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -66,8 +66,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin): in_channels (`int`, *optional*): Pass if the input is continuous. The number of channels in the input and output. num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of context dimensions to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. @@ -181,7 +181,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input hidden_states - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.long`, *optional*): @@ -213,7 +213,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu # 2. Blocks for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep) # 3. Output if self.is_input_continuous: @@ -260,6 +260,8 @@ class AttentionBlock(nn.Module): eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. """ + # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore + def __init__( self, channels: int, @@ -369,6 +371,7 @@ def forward(self, hidden_states): # compute next hidden_states hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) # res connect and rescale @@ -385,7 +388,7 @@ class BasicTransformerBlock(nn.Module): num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. @@ -432,7 +435,7 @@ def __init__( dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, - ) # is self-attn if context is none + ) # is self-attn if encoder_hidden_states is none else: self.attn2 = None @@ -472,23 +475,30 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - def forward(self, hidden_states, context=None, timestep=None): + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None): # 1. Self-Attention norm_hidden_states = ( self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) ) if self.only_cross_attention: - hidden_states = self.attn1(norm_hidden_states, context) + hidden_states + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) else: - hidden_states = self.attn1(norm_hidden_states) + hidden_states + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states if self.attn2 is not None: # 2. Cross-Attention norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) # 3. Feed-forward hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states @@ -503,7 +513,7 @@ class CrossAttention(nn.Module): Parameters: query_dim (`int`): The number of channels in the query. cross_attention_dim (`int`, *optional*): - The number of channels in the context. If not given, defaults to `query_dim`. + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. @@ -520,13 +530,18 @@ def __init__( dropout: float = 0.0, bias=False, upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, ): super().__init__() inner_dim = dim_head * heads cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax self.scale = dim_head**-0.5 + self.heads = heads # for slice_size > 0 the attention score computation # is split across the batch axis to save memory @@ -534,11 +549,21 @@ def __init__( self.sliceable_head_dim = heads self._slice_size = None self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(inner_dim, query_dim)) self.to_out.append(nn.Dropout(dropout)) @@ -563,40 +588,58 @@ def set_attention_slice(self, slice_size): self._slice_size = slice_size - def forward(self, hidden_states, context=None, mask=None): + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape - query = self.to_q(hidden_states) - context = context if context is not None else hidden_states - key = self.to_k(context) - value = self.to_v(context) + encoder_hidden_states = encoder_hidden_states - dim = query.shape[-1] + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = self.to_q(hidden_states) + dim = query.shape[-1] query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - # TODO(PVP) - mask is currently never used. Remember to re-implement when used + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) # attention, what we cannot get enough of if self._use_memory_efficient_attention_xformers: - hidden_states = self._memory_efficient_attention_xformers(query, key, value) + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) # Some versions of xformers return output in fp32, cast it back to the dtype of the input hidden_states = hidden_states.to(query.dtype) else: if self._slice_size is None or query.shape[0] // self._slice_size == 1: - hidden_states = self._attention(query, key, value) + hidden_states = self._attention(query, key, value, attention_mask) else: hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) # linear proj hidden_states = self.to_out[0](hidden_states) + # dropout hidden_states = self.to_out[1](hidden_states) return hidden_states - def _attention(self, query, key, value): + def _attention(self, query, key, value, attention_mask=None): if self.upcast_attention: query = query.float() key = key.float() @@ -608,6 +651,18 @@ def _attention(self, query, key, value): beta=0, alpha=self.scale, ) + + if attention_mask is not None: + if attention_mask.shape != attention_scores.shape: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + attention_probs = attention_scores.softmax(dim=-1) # cast back to the original dtype @@ -656,7 +711,8 @@ def _sliced_attention(self, query, key, value, sequence_length, dim): hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states - def _memory_efficient_attention_xformers(self, query, key, value): + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask query = query.contiguous() key = key.contiguous() value = value.contiguous() @@ -802,7 +858,7 @@ class DualTransformer2DModel(nn.Module): Pass if the input is continuous. The number of channels in the input and output. num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of context dimensions to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. @@ -867,17 +923,21 @@ def __init__( # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` self.transformer_index_for_condition = [1, 0] - def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True): + def forward( + self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True + ): """ Args: hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input hidden_states - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.long`, *optional*): Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + attention_mask (`torch.FloatTensor`, *optional*): + Optional attention mask to be applied in CrossAttention return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. @@ -894,9 +954,13 @@ def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_di # for each of the two transformers, pass the corresponding condition tokens condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] transformer_index = self.transformer_index_for_condition[i] - encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[ - 0 - ] + encoded_state = self.transformers[transformer_index]( + input_states, + encoder_hidden_states=condition_state, + timestep=timestep, + attention_mask=attention_mask, + return_dict=False, + )[0] encoded_states.append(encoded_state - input_states) tokens_start += self.condition_lengths[i] diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py new file mode 100644 index 0000000000..714d5d52cf --- /dev/null +++ b/src/diffusers/models/prior_transformer.py @@ -0,0 +1,194 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..utils import BaseOutput +from .attention import BasicTransformerBlock +from .embeddings import TimestepEmbedding, Timesteps + + +@dataclass +class PriorTransformerOutput(BaseOutput): + """ + Args: + predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + The predicted CLIP image embedding conditioned on the CLIP text embedding input. + """ + + predicted_image_embedding: torch.FloatTensor + + +class PriorTransformer(ModelMixin, ConfigMixin): + """ + The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the + transformer predicts the image embeddings through a denoising diffusion process. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + For more details, see the original paper: https://arxiv.org/abs/2204.06125 + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. + embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP + image embeddings and text embeddings are both the same dimension. + num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the + length of the prompt after it has been tokenized. + additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the + projected hidden_states. The actual length of the used hidden_states is `num_embeddings + + additional_embeddings`. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 32, + attention_head_dim: int = 64, + num_layers: int = 20, + embedding_dim: int = 768, + num_embeddings=77, + additional_embeddings=4, + dropout: float = 0.0, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.additional_embeddings = additional_embeddings + + self.time_proj = Timesteps(inner_dim, True, 0) + self.time_embedding = TimestepEmbedding(inner_dim, inner_dim) + + self.proj_in = nn.Linear(embedding_dim, inner_dim) + + self.embedding_proj = nn.Linear(embedding_dim, inner_dim) + self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) + + self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) + + self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + activation_fn="gelu", + attention_bias=True, + ) + for d in range(num_layers) + ] + ) + + self.norm_out = nn.LayerNorm(inner_dim) + self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim) + + causal_attention_mask = torch.full( + [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf") + ) + causal_attention_mask.triu_(1) + causal_attention_mask = causal_attention_mask[None, ...] + self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) + + self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim)) + self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim)) + + def forward( + self, + hidden_states, + timestep: Union[torch.Tensor, float, int], + proj_embedding: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.BoolTensor] = None, + return_dict: bool = True, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + x_t, the currently predicted image embeddings. + timestep (`torch.long`): + Current denoising step. + proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): + Projected embedding vector the denoising process is conditioned on. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`): + Hidden states of the text embeddings the denoising process is conditioned on. + attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): + Text mask for the text embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain + tuple. + + Returns: + [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: + [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + batch_size = hidden_states.shape[0] + + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device) + + timesteps_projected = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timesteps_projected = timesteps_projected.to(dtype=self.dtype) + time_embeddings = self.time_embedding(timesteps_projected) + + proj_embeddings = self.embedding_proj(proj_embedding) + encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) + hidden_states = self.proj_in(hidden_states) + prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) + positional_embeddings = self.positional_embedding.to(hidden_states.dtype) + + hidden_states = torch.cat( + [ + encoder_hidden_states, + proj_embeddings[:, None, :], + time_embeddings[:, None, :], + hidden_states[:, None, :], + prd_embedding, + ], + dim=1, + ) + + hidden_states = hidden_states + positional_embeddings + + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) + attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) + attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask) + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states[:, -1] + predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) + + if not return_dict: + return (predicted_image_embedding,) + + return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding) + + def post_process_latents(self, prior_latents): + prior_latents = (prior_latents * self.clip_std) + self.clip_mean + return prior_latents diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 52d056ae96..7037da5725 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -405,7 +405,14 @@ def __init__( self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: - self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) else: self.time_emb_proj = None @@ -465,9 +472,16 @@ def forward(self, input_tensor, temb): if temb is not None: temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + + if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 5b337f482c..b49931a136 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -55,6 +55,8 @@ class UNet2DModel(ModelMixin, ConfigMixin): down_block_types (`Tuple[str]`, *optional*, defaults to : obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block types. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): + The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. up_block_types (`Tuple[str]`, *optional*, defaults to : obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to : @@ -66,6 +68,8 @@ class UNet2DModel(ModelMixin, ConfigMixin): attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization. norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. """ @register_to_config @@ -88,6 +92,8 @@ def __init__( attention_head_dim: int = 8, norm_num_groups: int = 32, norm_eps: float = 1e-5, + resnet_time_scale_shift: str = "default", + add_attention: bool = True, ): super().__init__() @@ -130,6 +136,7 @@ def __init__( resnet_groups=norm_num_groups, attn_num_head_channels=attention_head_dim, downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, ) self.down_blocks.append(down_block) @@ -140,9 +147,10 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift="default", + resnet_time_scale_shift=resnet_time_scale_shift, attn_num_head_channels=attention_head_dim, resnet_groups=norm_num_groups, + add_attention=add_attention, ) # up @@ -167,6 +175,7 @@ def __init__( resnet_act_fn=act_fn, resnet_groups=norm_num_groups, attn_num_head_channels=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, ) self.up_blocks.append(up_block) prev_output_channel = output_channel diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index aa8d4c9849..1ee6f65500 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -15,7 +15,7 @@ import torch from torch import nn -from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel +from .attention import AttentionBlock, CrossAttention, DualTransformer2DModel, Transformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D @@ -36,6 +36,7 @@ def get_down_block( use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + resnet_time_scale_shift="default", ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": @@ -49,6 +50,19 @@ def get_down_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "AttnDownBlock2D": return AttnDownBlock2D( @@ -62,6 +76,7 @@ def get_down_block( resnet_groups=resnet_groups, downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: @@ -82,6 +97,23 @@ def get_down_block( use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( @@ -93,6 +125,7 @@ def get_down_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "AttnSkipDownBlock2D": return AttnSkipDownBlock2D( @@ -105,6 +138,7 @@ def get_down_block( resnet_act_fn=resnet_act_fn, downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "DownEncoderBlock2D": return DownEncoderBlock2D( @@ -116,6 +150,7 @@ def get_down_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "AttnDownEncoderBlock2D": return AttnDownEncoderBlock2D( @@ -128,6 +163,7 @@ def get_down_block( resnet_groups=resnet_groups, downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, ) raise ValueError(f"{down_block_type} does not exist.") @@ -149,6 +185,7 @@ def get_up_block( use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + resnet_time_scale_shift="default", ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": @@ -162,6 +199,20 @@ def get_up_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "CrossAttnUpBlock2D": if cross_attention_dim is None: @@ -182,6 +233,24 @@ def get_up_block( use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "AttnUpBlock2D": return AttnUpBlock2D( @@ -195,6 +264,7 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "SkipUpBlock2D": return SkipUpBlock2D( @@ -206,6 +276,7 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "AttnSkipUpBlock2D": return AttnSkipUpBlock2D( @@ -218,6 +289,7 @@ def get_up_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "UpDecoderBlock2D": return UpDecoderBlock2D( @@ -228,6 +300,7 @@ def get_up_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "AttnUpDecoderBlock2D": return AttnUpDecoderBlock2D( @@ -239,6 +312,7 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, attn_num_head_channels=attn_num_head_channels, + resnet_time_scale_shift=resnet_time_scale_shift, ) raise ValueError(f"{up_block_type} does not exist.") @@ -255,14 +329,13 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, + add_attention: bool = True, attn_num_head_channels=1, - attention_type="default", output_scale_factor=1.0, ): super().__init__() - - self.attention_type = attention_type resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention # there is always at least one resnet resnets = [ @@ -282,15 +355,19 @@ def __init__( attentions = [] for _ in range(num_layers): - attentions.append( - AttentionBlock( - in_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, + if self.add_attention: + attentions.append( + AttentionBlock( + in_channels, + num_head_channels=attn_num_head_channels, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + ) ) - ) + else: + attentions.append(None) + resnets.append( ResnetBlock2D( in_channels=in_channels, @@ -309,13 +386,11 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_states=None): + def forward(self, hidden_states, temb=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - if self.attention_type == "default": + if attn is not None: hidden_states = attn(hidden_states) - else: - hidden_states = attn(hidden_states, encoder_states) hidden_states = resnet(hidden_states, temb) return hidden_states @@ -334,7 +409,6 @@ def __init__( resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, - attention_type="default", output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, @@ -344,7 +418,6 @@ def __init__( super().__init__() self.has_cross_attention = True - self.attention_type = attention_type self.attn_num_head_channels = attn_num_head_channels resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) @@ -408,10 +481,121 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + # TODO(Patrick, William) - attention_mask is currently not used. Implement once used hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + ): + super().__init__() + + self.has_cross_attention = True + + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attn_num_head_channels + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + CrossAttention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=attn_num_head_channels, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def set_attention_slice(self, slice_size): + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + raise ValueError( + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" + ) + if slice_size is not None and slice_size > min(head_dims): + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states.transpose(1, 2), + attention_mask=attention_mask, + ) + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + # resnet hidden_states = resnet(hidden_states, temb) return hidden_states @@ -431,7 +615,6 @@ def __init__( resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, - attention_type="default", output_scale_factor=1.0, downsample_padding=1, add_downsample=True, @@ -440,8 +623,6 @@ def __init__( resnets = [] attentions = [] - self.attention_type = attention_type - for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( @@ -514,7 +695,6 @@ def __init__( resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, - attention_type="default", output_scale_factor=1.0, downsample_padding=1, add_downsample=True, @@ -528,7 +708,6 @@ def __init__( attentions = [] self.has_cross_attention = True - self.attention_type = attention_type self.attn_num_head_channels = attn_num_head_channels for i in range(num_layers): @@ -588,7 +767,8 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + # TODO(Patrick, William) - attention mask is not used output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -605,7 +785,9 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, )[0] else: hidden_states = resnet(hidden_states, temb) @@ -847,7 +1029,6 @@ def __init__( resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, attn_num_head_channels=1, - attention_type="default", output_scale_factor=np.sqrt(2.0), downsample_padding=1, add_downsample=True, @@ -856,8 +1037,6 @@ def __init__( self.attentions = nn.ModuleList([]) self.resnets = nn.ModuleList([]) - self.attention_type = attention_type - for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels self.resnets.append( @@ -1006,6 +1185,205 @@ def forward(self, hidden_states, temb=None, skip_sample=None): return hidden_states, output_states, skip_sample +class ResnetDownsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + + self.has_cross_attention = True + + resnets = [] + attentions = [] + + self.attn_num_head_channels = attn_num_head_channels + self.num_heads = out_channels // self.attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + CrossAttention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attn_num_head_channels, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + # resnet + hidden_states = resnet(hidden_states, temb) + + # attn + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states.transpose(1, 2), + attention_mask=attention_mask, + ) + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states += (hidden_states,) + + return hidden_states, output_states + + class AttnUpBlock2D(nn.Module): def __init__( self, @@ -1020,7 +1398,6 @@ def __init__( resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attention_type="default", attn_num_head_channels=1, output_scale_factor=1.0, add_upsample=True, @@ -1029,8 +1406,6 @@ def __init__( resnets = [] attentions = [] - self.attention_type = attention_type - for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels @@ -1100,7 +1475,6 @@ def __init__( resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, - attention_type="default", output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, @@ -1113,7 +1487,6 @@ def __init__( attentions = [] self.has_cross_attention = True - self.attention_type = attention_type self.attn_num_head_channels = attn_num_head_channels for i in range(num_layers): @@ -1176,7 +1549,9 @@ def forward( temb=None, encoder_hidden_states=None, upsample_size=None, + attention_mask=None, ): + # TODO(Patrick, William) - attention mask is not used for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1196,7 +1571,9 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, )[0] else: hidden_states = resnet(hidden_states, temb) @@ -1418,7 +1795,6 @@ def __init__( resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, attn_num_head_channels=1, - attention_type="default", output_scale_factor=np.sqrt(2.0), upsample_padding=1, add_upsample=True, @@ -1427,8 +1803,6 @@ def __init__( self.attentions = nn.ModuleList([]) self.resnets = nn.ModuleList([]) - self.attention_type = attention_type - for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels @@ -1608,3 +1982,213 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample hidden_states = self.resnet_up(hidden_states, temb) return hidden_states, skip_sample + + +class ResnetUpsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class SimpleCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + self.num_heads = out_channels // self.attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + CrossAttention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attn_num_head_channels, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # resnet + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + # attn + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states.transpose(1, 2), + attention_mask=attention_mask, + ) + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 0cfb152249..4612e650e4 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -27,6 +27,7 @@ CrossAttnUpBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, UpBlock2D, get_down_block, get_up_block, @@ -66,6 +67,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): The tuple of upsample blocks to use. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): @@ -78,6 +81,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately + summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. """ _supports_gradient_checkpointing = True @@ -97,6 +104,7 @@ def __init__( "CrossAttnDownBlock2D", "DownBlock2D", ), + mid_block_type: str = "UNetMidBlock2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), @@ -110,8 +118,10 @@ def __init__( attention_head_dim: Union[int, Tuple[int]] = 8, dual_cross_attention: bool = False, use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", ): super().__init__() @@ -128,8 +138,14 @@ def __init__( self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) # class embedding - if num_class_embeds is not None: + if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None self.down_blocks = nn.ModuleList([]) self.mid_block = None @@ -165,24 +181,40 @@ def __init__( use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, ) self.down_blocks.append(down_block) # mid - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift="default", - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the images self.num_upsamplers = 0 @@ -223,6 +255,7 @@ def __init__( use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -307,6 +340,7 @@ def forward( timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -336,6 +370,11 @@ def forward( logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -365,9 +404,13 @@ def forward( t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) - if self.config.num_class_embeds is not None: + if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb @@ -382,6 +425,7 @@ def forward( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -389,7 +433,9 @@ def forward( down_block_res_samples += res_samples # 4. mid - sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) # 5. up for i, upsample_block in enumerate(self.up_blocks): @@ -410,6 +456,7 @@ def forward( res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, upsample_size=upsample_size, + attention_mask=attention_mask, ) else: sample = upsample_block( diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 9f7d1a05db..ea4199a8b5 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -53,6 +53,7 @@ StableDiffusionUpscalePipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .unclip import UnCLIPPipeline from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, diff --git a/src/diffusers/pipelines/unclip/__init__.py b/src/diffusers/pipelines/unclip/__init__.py new file mode 100644 index 0000000000..75f526fb5b --- /dev/null +++ b/src/diffusers/pipelines/unclip/__init__.py @@ -0,0 +1,6 @@ +from ...utils import is_torch_available, is_transformers_available + + +if is_transformers_available() and is_torch_available(): + from .pipeline_unclip import UnCLIPPipeline + from .text_proj import UnCLIPTextProjModel diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py new file mode 100644 index 0000000000..c9edcfd7aa --- /dev/null +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -0,0 +1,370 @@ +# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Union + +import torch +from torch.nn import functional as F + +from diffusers import PriorTransformer, UNet2DConditionModel, UNet2DModel +from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.schedulers import UnCLIPScheduler +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...utils import logging +from .text_proj import UnCLIPTextProjModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPPipeline(DiffusionPipeline): + prior: PriorTransformer + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + prior_scheduler: UnCLIPScheduler + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + + def __init__( + self, + prior: PriorTransformer, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + prior_scheduler: UnCLIPScheduler, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + prior=prior, + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + super_res_first=super_res_first, + super_res_last=super_res_last, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(self.device) + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(self.device)) + + text_embeddings = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(self.device) + uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(self.device)) + + uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return text_embeddings, text_encoder_hidden_states, text_mask + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + prior_num_inference_steps: int = 25, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[torch.Generator] = None, + prior_latents: Optional[torch.FloatTensor] = None, + decoder_latents: Optional[torch.FloatTensor] = None, + super_res_latents: Optional[torch.FloatTensor] = None, + prior_guidance_scale: float = 4.0, + decoder_guidance_scale: float = 8.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 + + text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, num_images_per_prompt, do_classifier_free_guidance + ) + + # prior + + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=self.device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + text_embeddings.dtype, + self.device, + generator, + prior_latents, + self.prior_scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=text_embeddings, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeddings = prior_latents + + # done prior + + # decoder + + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + text_embeddings=text_embeddings, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=self.device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.in_channels + height = self.decoder.sample_size + width = self.decoder.sample_size + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + self.device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=self.device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.in_channels // 2 + height = self.super_res_first.sample_size + width = self.super_res_first.sample_size + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + self.device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep + ).prev_sample + + image = super_res_latents + + # done super res + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/unclip/text_proj.py b/src/diffusers/pipelines/unclip/text_proj.py new file mode 100644 index 0000000000..31e7e0644e --- /dev/null +++ b/src/diffusers/pipelines/unclip/text_proj.py @@ -0,0 +1,87 @@ +# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from diffusers.modeling_utils import ModelMixin + +from ...configuration_utils import ConfigMixin, register_to_config + + +class UnCLIPTextProjModel(ModelMixin, ConfigMixin): + """ + Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the + decoder. + + For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1 + """ + + @register_to_config + def __init__( + self, + *, + clip_extra_context_tokens: int = 4, + clip_embeddings_dim: int = 768, + time_embed_dim: int, + cross_attention_dim, + ): + super().__init__() + + self.learned_classifier_free_guidance_embeddings = nn.Parameter(torch.zeros(clip_embeddings_dim)) + + # parameters for additional clip time embeddings + self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim) + self.clip_image_embeddings_project_to_time_embeddings = nn.Linear(clip_embeddings_dim, time_embed_dim) + + # parameters for encoder hidden states + self.clip_extra_context_tokens = clip_extra_context_tokens + self.clip_extra_context_tokens_proj = nn.Linear( + clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim + ) + self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim) + self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, *, image_embeddings, text_embeddings, text_encoder_hidden_states, do_classifier_free_guidance): + if do_classifier_free_guidance: + # Add the classifier free guidance embeddings to the image embeddings + image_embeddings_batch_size = image_embeddings.shape[0] + classifier_free_guidance_embeddings = self.learned_classifier_free_guidance_embeddings.unsqueeze(0) + classifier_free_guidance_embeddings = classifier_free_guidance_embeddings.expand( + image_embeddings_batch_size, -1 + ) + image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0) + + # The image embeddings batch size and the text embeddings batch size are equal + assert image_embeddings.shape[0] == text_embeddings.shape[0] + + batch_size = text_embeddings.shape[0] + + # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and + # adding CLIP embeddings to the existing timestep embedding, ... + time_projected_text_embeddings = self.embedding_proj(text_embeddings) + time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings) + additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_text_embeddings + + # ... and by projecting CLIP embeddings into four + # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder" + clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings) + clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens) + + text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states) + text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states) + text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1) + text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=2) + + return text_encoder_hidden_states, additive_clip_time_embeddings diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 6a9e83f134..2d35bced7d 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -6,8 +6,9 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...modeling_utils import ModelMixin -from ...models.attention import DualTransformer2DModel, Transformer2DModel +from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel from ...models.embeddings import TimestepEmbedding, Timesteps +from ...models.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn as UNetMidBlockFlatSimpleCrossAttn from ...models.unet_2d_condition import UNet2DConditionOutput from ...utils import logging @@ -141,6 +142,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): + The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`): The tuple of upsample blocks to use. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): @@ -153,6 +156,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately + summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. """ _supports_gradient_checkpointing = True @@ -172,6 +179,7 @@ def __init__( "CrossAttnDownBlockFlat", "DownBlockFlat", ), + mid_block_type: str = "UNetMidBlockFlatCrossAttn", up_block_types: Tuple[str] = ( "UpBlockFlat", "CrossAttnUpBlockFlat", @@ -190,8 +198,10 @@ def __init__( attention_head_dim: Union[int, Tuple[int]] = 8, dual_cross_attention: bool = False, use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", ): super().__init__() @@ -208,8 +218,14 @@ def __init__( self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) # class embedding - if num_class_embeds is not None: + if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None self.down_blocks = nn.ModuleList([]) self.mid_block = None @@ -245,24 +261,40 @@ def __init__( use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, ) self.down_blocks.append(down_block) # mid - self.mid_block = UNetMidBlockFlatCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift="default", - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) + if mid_block_type == "UNetMidBlockFlatCrossAttn": + self.mid_block = UNetMidBlockFlatCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": + self.mid_block = UNetMidBlockFlatSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the images self.num_upsamplers = 0 @@ -303,6 +335,7 @@ def __init__( use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -387,6 +420,7 @@ def forward( timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -416,6 +450,11 @@ def forward( logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -445,9 +484,13 @@ def forward( t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) - if self.config.num_class_embeds is not None: + if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb @@ -462,6 +505,7 @@ def forward( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -469,7 +513,9 @@ def forward( down_block_res_samples += res_samples # 4. mid - sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) # 5. up for i, upsample_block in enumerate(self.up_blocks): @@ -490,6 +536,7 @@ def forward( res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, upsample_size=upsample_size, + attention_mask=attention_mask, ) else: sample = upsample_block( @@ -715,7 +762,6 @@ def __init__( resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, - attention_type="default", output_scale_factor=1.0, downsample_padding=1, add_downsample=True, @@ -729,7 +775,6 @@ def __init__( attentions = [] self.has_cross_attention = True - self.attention_type = attention_type self.attn_num_head_channels = attn_num_head_channels for i in range(num_layers): @@ -789,7 +834,8 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + # TODO(Patrick, William) - attention mask is not used output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -806,7 +852,9 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, )[0] else: hidden_states = resnet(hidden_states, temb) @@ -915,7 +963,6 @@ def __init__( resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, - attention_type="default", output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, @@ -928,7 +975,6 @@ def __init__( attentions = [] self.has_cross_attention = True - self.attention_type = attention_type self.attn_num_head_channels = attn_num_head_channels for i in range(num_layers): @@ -991,7 +1037,9 @@ def forward( temb=None, encoder_hidden_states=None, upsample_size=None, + attention_mask=None, ): + # TODO(Patrick, William) - attention mask is not used for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1011,7 +1059,9 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, )[0] else: hidden_states = resnet(hidden_states, temb) @@ -1038,7 +1088,6 @@ def __init__( resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, - attention_type="default", output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, @@ -1048,7 +1097,6 @@ def __init__( super().__init__() self.has_cross_attention = True - self.attention_type = attention_type self.attn_num_head_channels = attn_num_head_channels resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) @@ -1112,10 +1160,122 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + # TODO(Patrick, William) - attention_mask is currently not used. Implement once used hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + ): + super().__init__() + + self.has_cross_attention = True + + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attn_num_head_channels + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + CrossAttention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=attn_num_head_channels, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def set_attention_slice(self, slice_size): + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + raise ValueError( + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" + ) + if slice_size is not None and slice_size > min(head_dims): + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states.transpose(1, 2), + attention_mask=attention_mask, + ) + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + # resnet hidden_states = resnet(hidden_states, temb) return hidden_states diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index b2af345782..6b90fe79f6 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -37,6 +37,7 @@ from .scheduling_repaint import RePaintScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler + from .scheduling_unclip import UnCLIPScheduler from .scheduling_utils import SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py new file mode 100644 index 0000000000..243ca2f009 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_unclip.py @@ -0,0 +1,309 @@ +# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UnCLIP +class UnCLIPSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class UnCLIPScheduler(SchedulerMixin, ConfigMixin): + """ + This is a modified DDPM Scheduler specifically for the karlo unCLIP model. + + This scheduler has some minor variations in how it calculates the learned range variance and dynamically + re-calculates betas based off the timesteps it is skipping. + + The scheduler also uses a slightly different step ratio when computing timesteps to use for inference. + + See [`~DDPMScheduler`] for more information on DDPM scheduling + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small_log` + or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample between `-clip_sample_range` and `clip_sample_range` for numerical + stability. + clip_sample_range (`float`, default `1.0`): + The range to clip the sample between. See `clip_sample`. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) + or `sample` (directly predicting the noisy sample`) + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + variance_type: str = "fixed_small_log", + clip_sample: bool = True, + clip_sample_range: Optional[float] = 1.0, + prediction_type: str = "epsilon", + ): + # beta scheduler is "squaredcos_cap_v2" + self.betas = betas_for_alpha_bar(num_train_timesteps) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + self.variance_type = variance_type + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Note that this scheduler uses a slightly different step ratio than the other diffusers schedulers. The + different step ratio is to mimic the original karlo implementation and does not affect the quality or accuracy + of the results. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + step_ratio = (self.config.num_train_timesteps - 1) / (self.num_inference_steps - 1) + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_variance(self, t, prev_timestep=None, predicted_variance=None, variance_type=None): + if prev_timestep is None: + prev_timestep = t - 1 + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if prev_timestep == t - 1: + beta = self.betas[t] + else: + beta = 1 - alpha_prod_t / alpha_prod_t_prev + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = beta_prod_t_prev / beta_prod_t * beta + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small_log": + variance = torch.log(torch.clamp(variance, min=1e-20)) + variance = torch.exp(0.5 * variance) + elif variance_type == "learned_range": + # NOTE difference with DDPM scheduler + min_log = variance.log() + max_log = beta.log() + + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + prev_timestep: Optional[int] = None, + generator=None, + return_dict: bool = True, + ) -> Union[UnCLIPSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at. + Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than UnCLIPSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + + t = timestep + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type == "learned_range": + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + if prev_timestep is None: + prev_timestep = t - 1 + + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if prev_timestep == t - 1: + beta = self.betas[t] + alpha = self.alphas[t] + else: + beta = 1 - alpha_prod_t / alpha_prod_t_prev + alpha = 1 - beta + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `sample`" + " for the UnCLIPScheduler." + ) + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = torch.clamp( + pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t + current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if t > 0: + device = model_output.device + if device.type == "mps": + # randn does not work reproducibly on mps + variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator) + variance_noise = variance_noise.to(device) + else: + variance_noise = torch.randn( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) + + variance = self._get_variance( + t, + predicted_variance=predicted_variance, + prev_timestep=prev_timestep, + ) + + if self.variance_type == "fixed_small_log": + variance = variance + elif self.variance_type == "learned_range": + variance = (0.5 * variance).exp() + else: + raise ValueError( + f"variance_type given as {self.variance_type} must be one of `fixed_small_log` or `learned_range`" + " for the UnCLIPScheduler." + ) + + variance = variance * variance_noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index dc036c82c6..615f84d115 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -34,6 +34,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class PriorTransformer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class Transformer2DModel(metaclass=DummyObject): _backends = ["torch"] @@ -512,6 +527,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class UnCLIPScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class VQDiffusionScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 160b83a7e6..ba2798c784 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -199,6 +199,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class UnCLIPPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/unclip/__init__.py b/tests/pipelines/unclip/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py new file mode 100644 index 0000000000..ba49625511 --- /dev/null +++ b/tests/pipelines/unclip/test_unclip.py @@ -0,0 +1,282 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch + +from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel +from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel +from diffusers.utils import load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class UnCLIPPipelineFastTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def block_out_channels_0(self): + return self.time_input_dim + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def cross_attention_dim(self): + return 100 + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=self.text_embedder_hidden_size, + projection_dim=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModelWithProjection(config) + + @property + def dummy_prior(self): + torch.manual_seed(0) + + model_kwargs = { + "num_attention_heads": 2, + "attention_head_dim": 12, + "embedding_dim": self.text_embedder_hidden_size, + "num_layers": 1, + } + + model = PriorTransformer(**model_kwargs) + return model + + @property + def dummy_text_proj(self): + torch.manual_seed(0) + + model_kwargs = { + "clip_embeddings_dim": self.text_embedder_hidden_size, + "time_embed_dim": self.time_embed_dim, + "cross_attention_dim": self.cross_attention_dim, + } + + model = UnCLIPTextProjModel(**model_kwargs) + return model + + @property + def dummy_decoder(self): + torch.manual_seed(0) + + model_kwargs = { + "sample_size": 64, + # RGB in channels + "in_channels": 3, + # Out channels is double in channels because predicts mean and variance + "out_channels": 6, + "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"), + "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"), + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2), + "layers_per_block": 1, + "cross_attention_dim": self.cross_attention_dim, + "attention_head_dim": 4, + "resnet_time_scale_shift": "scale_shift", + "class_embed_type": "identity", + } + + model = UNet2DConditionModel(**model_kwargs) + return model + + @property + def dummy_super_res_kwargs(self): + return { + "sample_size": 128, + "layers_per_block": 1, + "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"), + "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"), + "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2), + "in_channels": 6, + "out_channels": 3, + } + + @property + def dummy_super_res_first(self): + torch.manual_seed(0) + + model = UNet2DModel(**self.dummy_super_res_kwargs) + return model + + @property + def dummy_super_res_last(self): + # seeded differently to get different unet than `self.dummy_super_res_first` + torch.manual_seed(1) + + model = UNet2DModel(**self.dummy_super_res_kwargs) + return model + + def test_unclip(self): + device = "cpu" + + prior = self.dummy_prior + decoder = self.dummy_decoder + text_proj = self.dummy_text_proj + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + super_res_first = self.dummy_super_res_first + super_res_last = self.dummy_super_res_last + + prior_scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="sample", + num_train_timesteps=1000, + clip_sample_range=5.0, + ) + + decoder_scheduler = UnCLIPScheduler( + variance_type="learned_range", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + super_res_scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + pipe = UnCLIPPipeline( + prior=prior, + decoder=decoder, + text_proj=text_proj, + text_encoder=text_encoder, + tokenizer=tokenizer, + super_res_first=super_res_first, + super_res_last=super_res_last, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + prompt = "horse" + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe( + [prompt], + generator=generator, + prior_num_inference_steps=2, + decoder_num_inference_steps=2, + super_res_num_inference_steps=2, + output_type="np", + ) + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = pipe( + [prompt], + generator=generator, + prior_num_inference_steps=2, + decoder_num_inference_steps=2, + super_res_num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 128, 128, 3) + + expected_slice = np.array( + [ + 0.0009, + 0.9997, + 0.0003, + 0.9991, + 0.9967, + 0.0003, + 0.9997, + 0.0003, + 0.0004, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch_gpu +class UnCLIPPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_unclip_karlo(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/karlo_v1_alpha/horse.npy" + ) + + pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha") + pipeline = pipeline.to(torch_device) + pipeline.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipeline( + "horse", + num_images_per_prompt=1, + generator=generator, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (256, 256, 3) + assert np.abs(expected_image - image).max() < 1e-2 diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index ba109bdf47..8b27ab52e0 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -38,6 +38,7 @@ LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler, + UnCLIPScheduler, VQDiffusionScheduler, logging, ) @@ -2643,3 +2644,135 @@ def test_full_loop_device(self): # CUDA assert abs(result_sum.item() - 13913.0332) < 1e-1 assert abs(result_mean.item() - 18.1159) < 1e-3 + + +# UnCLIPScheduler is a modified DDPMScheduler with a subset of the configuration. +class UnCLIPSchedulerTest(SchedulerCommonTest): + scheduler_classes = (UnCLIPScheduler,) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "variance_type": "fixed_small_log", + "clip_sample": True, + "clip_sample_range": 1.0, + "prediction_type": "epsilon", + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [1, 5, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_variance_type(self): + for variance in ["fixed_small_log", "learned_range"]: + self.check_over_configs(variance_type=variance) + + def test_clip_sample(self): + for clip_sample in [True, False]: + self.check_over_configs(clip_sample=clip_sample) + + def test_clip_sample_range(self): + for clip_sample_range in [1, 5, 10, 20]: + self.check_over_configs(clip_sample_range=clip_sample_range) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_time_indices(self): + for time_step in [0, 500, 999]: + for prev_timestep in [None, 5, 100, 250, 500, 750]: + if prev_timestep is not None and prev_timestep >= time_step: + continue + + self.check_over_forward(time_step=time_step, prev_timestep=prev_timestep) + + def test_variance_fixed_small_log(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(variance_type="fixed_small_log") + scheduler = scheduler_class(**scheduler_config) + + assert torch.sum(torch.abs(scheduler._get_variance(0) - 1.0000e-10)) < 1e-5 + assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.0549625)) < 1e-5 + assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.9994987)) < 1e-5 + + def test_variance_learned_range(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(variance_type="learned_range") + scheduler = scheduler_class(**scheduler_config) + + predicted_variance = 0.5 + + assert scheduler._get_variance(1, predicted_variance=predicted_variance) - -10.1712790 < 1e-5 + assert scheduler._get_variance(487, predicted_variance=predicted_variance) - -5.7998052 < 1e-5 + assert scheduler._get_variance(999, predicted_variance=predicted_variance) - -0.0010011 < 1e-5 + + def test_full_loop(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = scheduler.timesteps + + model = self.dummy_model() + sample = self.dummy_sample_deter + generator = torch.manual_seed(0) + + for i, t in enumerate(timesteps): + # 1. predict noise residual + residual = model(sample, t) + + # 2. predict previous mean of sample x_t-1 + pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample + + sample = pred_prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 252.2682495) < 1e-2 + assert abs(result_mean.item() - 0.3284743) < 1e-3 + + def test_full_loop_skip_timesteps(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(25) + + timesteps = scheduler.timesteps + + model = self.dummy_model() + sample = self.dummy_sample_deter + generator = torch.manual_seed(0) + + for i, t in enumerate(timesteps): + # 1. predict noise residual + residual = model(sample, t) + + if i + 1 == timesteps.shape[0]: + prev_timestep = None + else: + prev_timestep = timesteps[i + 1] + + # 2. predict previous mean of sample x_t-1 + pred_prev_sample = scheduler.step( + residual, t, sample, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + sample = pred_prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 258.2044983) < 1e-2 + assert abs(result_mean.item() - 0.3362038) < 1e-3 + + def test_trained_betas(self): + pass + + def test_add_noise_device(self): + pass