diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..0c70c5e90
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+.pt2
+.pt2_2
+.pt13
+*.egg-info
+build
+/outputs
+/checkpoints
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 000000000..b01c5a683
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,75 @@
+SDXL 0.9 RESEARCH LICENSE AGREEMENT
+Copyright (c) Stability AI Ltd.
+This License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and Stability AI Ltd. (“Stability AI” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by Stability AI under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software (“Documentation”).
+By clicking “I Accept” below or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Stability AI that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity.
+1. LICENSE GRANT
+
+a. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s copyright interests to reproduce, distribute, and create derivative works of the Software solely for your non-commercial research purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Stability AI’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License.
+
+b. You may make a reasonable number of copies of the Documentation solely for use in connection with the license to the Software granted above.
+
+c. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Stability AI and its licensors reserve all rights not expressly granted by this License.
+
+
+2. RESTRICTIONS
+
+You will not, and will not permit, assist or cause any third party to:
+
+a. use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;
+
+b. alter or remove copyright and other proprietary notices which appear on or in the Software Products;
+
+c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Stability AI in connection with the Software, or to circumvent or remove any usage restrictions, or to enable functionality disabled by Stability AI; or
+
+d. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent with the terms of this License.
+
+e. 1) violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”); 2) directly or indirectly export, re-export, provide, or otherwise transfer Software Products: (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download Software Products if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.
+
+
+3. ATTRIBUTION
+
+Together with any copies of the Software Products (as well as derivative works thereof or works incorporating the Software Products) that you distribute, you must provide (i) a copy of this License, and (ii) the following attribution notice: “SDXL 0.9 is licensed under the SDXL Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.”
+
+
+4. DISCLAIMERS
+
+THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. STABILITY AIEXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. STABILITY AI MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
+
+
+5. LIMITATION OF LIABILITY
+
+TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL STABILITY AI BE LIABLE TO YOU (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF STABILITY AI HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
+
+
+6. INDEMNIFICATION
+
+You will indemnify, defend and hold harmless Stability AI and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Stability AI Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Stability AI Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Stability AI Parties of any such Claims, and cooperate with Stability AI Parties in defending such Claims. You will also grant the Stability AI Parties sole control of the defense or settlement, at Stability AI’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Stability AI or the other Stability AI Parties.
+
+
+7. TERMINATION; SURVIVAL
+
+a. This License will automatically terminate upon any breach by you of the terms of this License.
+
+b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
+
+c. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification) 7 (Termination; Survival), 8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous).
+
+
+8. THIRD PARTY MATERIALS
+
+The Software Products may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Stability AI does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
+
+
+9. TRADEMARKS
+
+Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with Stability AI without the prior written permission of Stability AI, except to the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement.
+
+
+10. APPLICABLE LAW; DISPUTE RESOLUTION
+
+This License will be governed and construed under the laws of the State of California without regard to conflicts of law provisions. Any suit or proceeding arising out of or relating to this License will be brought in the federal or state courts, as applicable, in San Mateo County, California, and each party irrevocably submits to the jurisdiction and venue of such courts.
+
+
+11. MISCELLANEOUS
+
+If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Stability AI to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Stability AI regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Stability AI regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Stability AI.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..242ec4a19
--- /dev/null
+++ b/README.md
@@ -0,0 +1,187 @@
+# Generative Models by Stability AI
+
+
+
+## News
+
+**June 22, 2023**
+
+
+- We are releasing two new diffusion models:
+ - `SD-XL 0.9-base`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses the OpenCLIP model.
+ - `SD-XL 0.9-refiner`: The refiner has been trained to denoise small noise levels of high quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
+
+**We plan to do a full release soon (July).**
+
+## The codebase
+
+### General Philosophy
+
+Modularity is king. This repo implements a config-driven approach where we build and combine submodules by calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
+
+### Changelog from the old `ldm` codebase
+
+For training, we use [pytorch-lightning](https://www.pytorchlightning.ai/index.html), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
+
+- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`.
+- We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
+ samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
+- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (most notable change is probably now the option to train continuous time models):
+ * Discrete times models (denoisers) are simply a special case of continuous time models (denoisers); see `sgm/modules/diffusionmodules/denoiser.py`.
+ * The following features are now independent: weighting of the diffusion loss function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
+- Autoencoding models have also been cleaned up.
+
+## Installation:
+
+
+#### 1. Clone the repo
+
+```shell
+git clone git@github.com:Stability-AI/generative-models.git
+cd generative-models
+```
+
+#### 2. Setting up the virtualenv
+
+This is assuming you have navigated to the `generative-models` root after cloning it.
+
+**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.
+
+
+**PyTorch 1.13**
+
+```shell
+# install required packages from pypi
+python3 -m venv .pt1
+source .pt1/bin/activate
+pip3 install wheel
+pip3 install -r requirements_pt13.txt
+```
+
+**PyTorch 2.0**
+
+
+```shell
+# install required packages from pypi
+python3 -m venv .pt2
+source .pt2/bin/activate
+pip3 install wheel
+pip3 install -r requirements_pt2.txt
+```
+
+## Inference:
+
+We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`. The following models are currently supported:
+- [SD-XL 0.9-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
+- [SD-XL 0.9-refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
+- [SD 2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
+- [SD 2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
+
+**Weights for SDXL**:
+If you would like to access these models for your research, please apply using one of the following links:
+[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
+This means that you can apply for any of the two links - and if you are granted - you can access both.
+Please log in to your HuggingFace Account with your organization email to request access.
+
+After obtaining the weights, place them into `checkpoints/`.
+Next, start the demo using
+
+```
+streamlit run scripts/demo/sampling.py --server.port
+```
+
+### Invisible Watermark Detection
+
+Images generated with our code use the
+[invisible-watermark](https://github.com/ShieldMnt/invisible-watermark/)
+library to embed an invisible watermark into the model output. We also provide
+a script to easily detect that watermark. Please note that this watermark is
+not the same as in previous Stable Diffusion 1.x/2.x versions.
+
+To run the script you need to either have a working installation as above or
+try an _experimental_ import using only a minimal amount of packages:
+```bash
+python -m venv .detect
+source .detect/bin/activate
+
+pip install "numpy>=1.17" "PyWavelets>=1.1.1" "opencv-python>=4.1.0.25"
+pip install --no-deps invisible-watermark
+```
+
+To run the script you need to have a working installation as above. The script
+is then useable in the following ways (don't forget to activate your
+virtual environment beforehand, e.g. `source .pt1/bin/activate`):
+```bash
+# test a single file
+python scripts/demo/detect.py
+# test multiple files at once
+python scripts/demo/detect.py ...
+# test all files in a specific folder
+python scripts/demo/detect.py /*
+```
+
+## Training:
+
+We are providing example training configs in `configs/example_training`. To launch a training, run
+
+```
+python main.py --base configs/ configs/
+```
+
+where configs are merged from left to right (later configs overwrite the same values).
+This can be used to combine model, training and data configs. However, all of them can also be
+defined in a single config. For example, to run a class-conditional pixel-based diffusion model training on MNIST,
+run
+
+```bash
+python main.py --base configs/example_training/toy/mnist_cond.yaml
+```
+
+**NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depdending on the used dataset (which is expected to stored in tar-file in the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.
+
+**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2`for training generative models. However for autoencoder training as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported.
+
+**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done for the provided text-to-image configs.
+
+### Building New Diffusion Models
+
+#### Conditioner
+
+The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of
+different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
+All embedders should define whether or not they are trainable (`is_trainable`, default `False`), a classifier-free
+guidance dropout rate is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for text-conditioning or `cls` for class-conditioning.
+When computing conditionings, the embedder will get `batch[input_key]` as input.
+We currently support two to four dimensional conditionings and conditionings of different embedders are concatenated
+appropriately.
+Note that the order of the embedders in the `conditioner_config` is important.
+
+#### Network
+
+The neural network is set through the `network_config`. This used to be called `unet_config`, which is not general
+enough as we plan to experiment with transformer-based diffusion backbones.
+
+#### Loss
+
+The loss is configured through `loss_config`. For standard diffusion model training, you will have to set `sigma_sampler_config`.
+
+#### Sampler config
+
+As discussed above, the sampler is independent of the model. In the `sampler_config`, we set the type of numerical
+solver, number of steps, type of discretization, as well as, for example, guidance wrappers for classifier-free
+guidance.
+
+### Dataset Handling
+
+
+For large scale training we recommend using the datapipelines from our [datapipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement and automatically included when following the steps from the [Installation section](#installation).
+Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
+data keys/values,
+e.g.,
+
+```python
+example = {"jpg": x, # this is a tensor -1...1 chw
+ "txt": "a beautiful image"}
+```
+
+where we expect images in -1...1, channel-first format.
diff --git a/assets/000.jpg b/assets/000.jpg
new file mode 100644
index 000000000..e93d6c1b6
Binary files /dev/null and b/assets/000.jpg differ
diff --git a/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml b/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml
new file mode 100644
index 000000000..482b25901
--- /dev/null
+++ b/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml
@@ -0,0 +1,115 @@
+model:
+ base_learning_rate: 4.5e-6
+ target: sgm.models.autoencoder.AutoencodingEngine
+ params:
+ input_key: jpg
+ monitor: val/rec_loss
+
+ loss_config:
+ target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
+ params:
+ perceptual_weight: 0.25
+ disc_start: 20001
+ disc_weight: 0.5
+ learn_logvar: True
+
+ regularization_weights:
+ kl_loss: 1.0
+
+ regularizer_config:
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+
+ encoder_config:
+ target: sgm.modules.diffusionmodules.model.Encoder
+ params:
+ attn_type: none
+ double_z: True
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1, 2, 4 ]
+ num_res_blocks: 4
+ attn_resolutions: [ ]
+ dropout: 0.0
+
+ decoder_config:
+ target: sgm.modules.diffusionmodules.model.Decoder
+ params:
+ attn_type: none
+ double_z: False
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1, 2, 4 ]
+ num_res_blocks: 4
+ attn_resolutions: [ ]
+ dropout: 0.0
+
+data:
+ target: sgm.data.dataset.StableDataModuleFromConfig
+ params:
+ train:
+ datapipeline:
+ urls:
+ - "DATA-PATH"
+ pipeline_config:
+ shardshuffle: 10000
+ sample_shuffle: 10000
+
+ decoders:
+ - "pil"
+
+ postprocessors:
+ - target: sdata.mappers.TorchVisionImageTransforms
+ params:
+ key: 'jpg'
+ transforms:
+ - target: torchvision.transforms.Resize
+ params:
+ size: 256
+ interpolation: 3
+ - target: torchvision.transforms.ToTensor
+ - target: sdata.mappers.Rescaler
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+ params:
+ h_key: height
+ w_key: width
+
+ loader:
+ batch_size: 8
+ num_workers: 4
+
+
+lightning:
+ strategy:
+ target: pytorch_lightning.strategies.DDPStrategy
+ params:
+ find_unused_parameters: True
+
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 5000
+
+ callbacks:
+ metrics_over_trainsteps_checkpoint:
+ params:
+ every_n_train_steps: 50000
+
+ image_logger:
+ target: main.ImageLogger
+ params:
+ enable_autocast: False
+ batch_frequency: 1000
+ max_images: 8
+ increase_log_steps: True
+
+ trainer:
+ devices: 0,
+ limit_val_batches: 50
+ benchmark: True
+ accumulate_grad_batches: 1
+ val_check_interval: 10000
\ No newline at end of file
diff --git a/configs/example_training/imagenet-f8_cond.yaml b/configs/example_training/imagenet-f8_cond.yaml
new file mode 100644
index 000000000..60627331b
--- /dev/null
+++ b/configs/example_training/imagenet-f8_cond.yaml
@@ -0,0 +1,188 @@
+model:
+ base_learning_rate: 1.0e-4
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ scale_factor: 0.13025
+ disable_first_stage_autocast: True
+ log_keys:
+ - cls
+
+ scheduler_config:
+ target: sgm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [10000]
+ cycle_lengths: [10000000000000]
+ f_start: [1.e-6]
+ f_max: [1.]
+ f_min: [1.]
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ use_fp16: True
+ in_channels: 4
+ out_channels: 4
+ model_channels: 256
+ attention_resolutions: [1, 2, 4]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4]
+ num_head_channels: 64
+ num_classes: sequential
+ adm_in_channels: 1024
+ use_spatial_transformer: true
+ transformer_depth: 1
+ context_dim: 1024
+ spatial_transformer_attn_type: softmax-xformers
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ # crossattn cond
+ - is_trainable: True
+ input_key: cls
+ ucg_rate: 0.2
+ target: sgm.modules.encoders.modules.ClassEmbedder
+ params:
+ add_sequence_dim: True # will be used through crossattn then
+ embed_dim: 1024
+ n_classes: 1000
+ # vector cond
+ - is_trainable: False
+ ucg_rate: 0.2
+ input_key: original_size_as_tuple
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+ # vector cond
+ - is_trainable: False
+ input_key: crop_coords_top_left
+ ucg_rate: 0.2
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+
+ first_stage_config:
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+ params:
+ ckpt_path: CKPT_PATH
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ attn_type: vanilla-xformers
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+ params:
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+ params:
+ num_idx: 1000
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+ params:
+ num_steps: 50
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+ params:
+ scale: 5.0
+
+data:
+ target: sgm.data.dataset.StableDataModuleFromConfig
+ params:
+ train:
+ datapipeline:
+ urls:
+ # USER: adapt this path the root of your custom dataset
+ - "DATA_PATH"
+ pipeline_config:
+ shardshuffle: 10000
+ sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
+
+ decoders:
+ - "pil"
+
+ postprocessors:
+ - target: sdata.mappers.TorchVisionImageTransforms
+ params:
+ key: 'jpg' # USER: you might wanna adapt this for your custom dataset
+ transforms:
+ - target: torchvision.transforms.Resize
+ params:
+ size: 256
+ interpolation: 3
+ - target: torchvision.transforms.ToTensor
+ - target: sdata.mappers.Rescaler
+
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+ params:
+ h_key: height # USER: you might wanna adapt this for your custom dataset
+ w_key: width # USER: you might wanna adapt this for your custom dataset
+
+ loader:
+ batch_size: 64
+ num_workers: 6
+
+lightning:
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 5000
+
+ callbacks:
+ metrics_over_trainsteps_checkpoint:
+ params:
+ every_n_train_steps: 25000
+
+ image_logger:
+ target: main.ImageLogger
+ params:
+ disabled: False
+ enable_autocast: False
+ batch_frequency: 1000
+ max_images: 8
+ increase_log_steps: True
+ log_first_step: False
+ log_images_kwargs:
+ use_ema_scope: False
+ N: 8
+ n_rows: 2
+
+ trainer:
+ devices: 0,
+ benchmark: True
+ num_sanity_val_steps: 0
+ accumulate_grad_batches: 1
+ max_epochs: 1000
\ No newline at end of file
diff --git a/configs/example_training/toy/cifar10_cond.yaml b/configs/example_training/toy/cifar10_cond.yaml
new file mode 100644
index 000000000..36ba2527a
--- /dev/null
+++ b/configs/example_training/toy/cifar10_cond.yaml
@@ -0,0 +1,99 @@
+model:
+ base_learning_rate: 1.0e-4
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
+ params:
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
+ params:
+ sigma_data: 1.0
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+ params:
+ sigma_data: 1.0
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ in_channels: 3
+ out_channels: 3
+ model_channels: 32
+ attention_resolutions: []
+ num_res_blocks: 4
+ channel_mult: [1, 2, 2]
+ num_head_channels: 32
+ num_classes: sequential
+ adm_in_channels: 128
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ - is_trainable: True
+ input_key: cls
+ ucg_rate: 0.2
+ target: sgm.modules.encoders.modules.ClassEmbedder
+ params:
+ embed_dim: 128
+ n_classes: 10
+
+ first_stage_config:
+ target: sgm.models.autoencoder.IdentityFirstStage
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+ params:
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+ params:
+ num_steps: 50
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+ params:
+ scale: 3.0
+
+data:
+ target: sgm.data.cifar10.CIFAR10Loader
+ params:
+ batch_size: 512
+ num_workers: 1
+
+lightning:
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 5000
+
+ callbacks:
+ metrics_over_trainsteps_checkpoint:
+ params:
+ every_n_train_steps: 25000
+
+ image_logger:
+ target: main.ImageLogger
+ params:
+ disabled: False
+ batch_frequency: 1000
+ max_images: 64
+ increase_log_steps: True
+ log_first_step: False
+ log_images_kwargs:
+ use_ema_scope: False
+ N: 64
+ n_rows: 8
+
+ trainer:
+ devices: 0,
+ benchmark: True
+ num_sanity_val_steps: 0
+ accumulate_grad_batches: 1
+ max_epochs: 20
\ No newline at end of file
diff --git a/configs/example_training/toy/mnist.yaml b/configs/example_training/toy/mnist.yaml
new file mode 100644
index 000000000..44d8e6fea
--- /dev/null
+++ b/configs/example_training/toy/mnist.yaml
@@ -0,0 +1,80 @@
+model:
+ base_learning_rate: 1.0e-4
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
+ params:
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
+ params:
+ sigma_data: 1.0
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+ params:
+ sigma_data: 1.0
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ in_channels: 1
+ out_channels: 1
+ model_channels: 32
+ attention_resolutions: []
+ num_res_blocks: 4
+ channel_mult: [1, 2, 2]
+ num_head_channels: 32
+
+ first_stage_config:
+ target: sgm.models.autoencoder.IdentityFirstStage
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+ params:
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+ params:
+ num_steps: 50
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+data:
+ target: sgm.data.mnist.MNISTLoader
+ params:
+ batch_size: 512
+ num_workers: 1
+
+lightning:
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 5000
+
+ callbacks:
+ metrics_over_trainsteps_checkpoint:
+ params:
+ every_n_train_steps: 25000
+
+ image_logger:
+ target: main.ImageLogger
+ params:
+ disabled: False
+ batch_frequency: 1000
+ max_images: 64
+ increase_log_steps: False
+ log_first_step: False
+ log_images_kwargs:
+ use_ema_scope: False
+ N: 64
+ n_rows: 8
+
+ trainer:
+ devices: 0,
+ benchmark: True
+ num_sanity_val_steps: 0
+ accumulate_grad_batches: 1
+ max_epochs: 10
\ No newline at end of file
diff --git a/configs/example_training/toy/mnist_cond.yaml b/configs/example_training/toy/mnist_cond.yaml
new file mode 100644
index 000000000..557be128b
--- /dev/null
+++ b/configs/example_training/toy/mnist_cond.yaml
@@ -0,0 +1,99 @@
+model:
+ base_learning_rate: 1.0e-4
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
+ params:
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
+ params:
+ sigma_data: 1.0
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+ params:
+ sigma_data: 1.0
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ in_channels: 1
+ out_channels: 1
+ model_channels: 32
+ attention_resolutions: [ ]
+ num_res_blocks: 4
+ channel_mult: [ 1, 2, 2 ]
+ num_head_channels: 32
+ num_classes: sequential
+ adm_in_channels: 128
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ - is_trainable: True
+ input_key: "cls"
+ ucg_rate: 0.2
+ target: sgm.modules.encoders.modules.ClassEmbedder
+ params:
+ embed_dim: 128
+ n_classes: 10
+
+ first_stage_config:
+ target: sgm.models.autoencoder.IdentityFirstStage
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+ params:
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+ params:
+ num_steps: 50
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+ params:
+ scale: 3.0
+
+data:
+ target: sgm.data.mnist.MNISTLoader
+ params:
+ batch_size: 512
+ num_workers: 1
+
+lightning:
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 5000
+
+ callbacks:
+ metrics_over_trainsteps_checkpoint:
+ params:
+ every_n_train_steps: 25000
+
+ image_logger:
+ target: main.ImageLogger
+ params:
+ disabled: False
+ batch_frequency: 1000
+ max_images: 16
+ increase_log_steps: True
+ log_first_step: False
+ log_images_kwargs:
+ use_ema_scope: False
+ N: 16
+ n_rows: 4
+
+ trainer:
+ devices: 0,
+ benchmark: True
+ num_sanity_val_steps: 0
+ accumulate_grad_batches: 1
+ max_epochs: 20
\ No newline at end of file
diff --git a/configs/example_training/toy/mnist_cond_discrete_eps.yaml b/configs/example_training/toy/mnist_cond_discrete_eps.yaml
new file mode 100644
index 000000000..f92b4cdf0
--- /dev/null
+++ b/configs/example_training/toy/mnist_cond_discrete_eps.yaml
@@ -0,0 +1,104 @@
+model:
+ base_learning_rate: 1.0e-4
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ in_channels: 1
+ out_channels: 1
+ model_channels: 32
+ attention_resolutions: [ ]
+ num_res_blocks: 4
+ channel_mult: [ 1, 2, 2 ]
+ num_head_channels: 32
+ num_classes: sequential
+ adm_in_channels: 128
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ - is_trainable: True
+ input_key: "cls"
+ ucg_rate: 0.2
+ target: sgm.modules.encoders.modules.ClassEmbedder
+ params:
+ embed_dim: 128
+ n_classes: 10
+
+ first_stage_config:
+ target: sgm.models.autoencoder.IdentityFirstStage
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+ params:
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+ params:
+ num_idx: 1000
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+ params:
+ num_steps: 50
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+ params:
+ scale: 5.0
+
+data:
+ target: sgm.data.mnist.MNISTLoader
+ params:
+ batch_size: 512
+ num_workers: 1
+
+lightning:
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 5000
+
+ callbacks:
+ metrics_over_trainsteps_checkpoint:
+ params:
+ every_n_train_steps: 25000
+
+ image_logger:
+ target: main.ImageLogger
+ params:
+ disabled: False
+ batch_frequency: 1000
+ max_images: 16
+ increase_log_steps: True
+ log_first_step: False
+ log_images_kwargs:
+ use_ema_scope: False
+ N: 16
+ n_rows: 4
+
+ trainer:
+ devices: 0,
+ benchmark: True
+ num_sanity_val_steps: 0
+ accumulate_grad_batches: 1
+ max_epochs: 20
\ No newline at end of file
diff --git a/configs/example_training/toy/mnist_cond_l1_loss.yaml b/configs/example_training/toy/mnist_cond_l1_loss.yaml
new file mode 100644
index 000000000..42b153004
--- /dev/null
+++ b/configs/example_training/toy/mnist_cond_l1_loss.yaml
@@ -0,0 +1,104 @@
+model:
+ base_learning_rate: 1.0e-4
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
+ params:
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
+ params:
+ sigma_data: 1.0
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+ params:
+ sigma_data: 1.0
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ in_channels: 1
+ out_channels: 1
+ model_channels: 32
+ attention_resolutions: []
+ num_res_blocks: 4
+ channel_mult: [1, 2, 2]
+ num_head_channels: 32
+ num_classes: "sequential"
+ adm_in_channels: 128
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ - is_trainable: True
+ input_key: "cls"
+ ucg_rate: 0.2
+ target: sgm.modules.encoders.modules.ClassEmbedder
+ params:
+ embed_dim: 128
+ n_classes: 10
+
+ first_stage_config:
+ target: sgm.models.autoencoder.IdentityFirstStage
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+ params:
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+ params:
+ num_steps: 50
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+ params:
+ scale: 3.0
+
+ loss_config:
+ target: sgm.modules.diffusionmodules.StandardDiffusionLoss
+ params:
+ type: l1
+
+data:
+ target: sgm.data.mnist.MNISTLoader
+ params:
+ batch_size: 512
+ num_workers: 1
+
+lightning:
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 5000
+
+ callbacks:
+ metrics_over_trainsteps_checkpoint:
+ params:
+ every_n_train_steps: 25000
+
+ image_logger:
+ target: main.ImageLogger
+ params:
+ disabled: False
+ batch_frequency: 1000
+ max_images: 64
+ increase_log_steps: True
+ log_first_step: False
+ log_images_kwargs:
+ use_ema_scope: False
+ N: 64
+ n_rows: 8
+
+ trainer:
+ devices: 0,
+ benchmark: True
+ num_sanity_val_steps: 0
+ accumulate_grad_batches: 1
+ max_epochs: 20
\ No newline at end of file
diff --git a/configs/example_training/toy/mnist_cond_with_ema.yaml b/configs/example_training/toy/mnist_cond_with_ema.yaml
new file mode 100644
index 000000000..632e8b420
--- /dev/null
+++ b/configs/example_training/toy/mnist_cond_with_ema.yaml
@@ -0,0 +1,101 @@
+model:
+ base_learning_rate: 1.0e-4
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ use_ema: True
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
+ params:
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
+ params:
+ sigma_data: 1.0
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+ params:
+ sigma_data: 1.0
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ in_channels: 1
+ out_channels: 1
+ model_channels: 32
+ attention_resolutions: []
+ num_res_blocks: 4
+ channel_mult: [1, 2, 2]
+ num_head_channels: 32
+ num_classes: sequential
+ adm_in_channels: 128
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ - is_trainable: True
+ input_key: cls
+ ucg_rate: 0.2
+ target: sgm.modules.encoders.modules.ClassEmbedder
+ params:
+ embed_dim: 128
+ n_classes: 10
+
+ first_stage_config:
+ target: sgm.models.autoencoder.IdentityFirstStage
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+ params:
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+ params:
+ num_steps: 50
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+ params:
+ scale: 3.0
+
+data:
+ target: sgm.data.mnist.MNISTLoader
+ params:
+ batch_size: 512
+ num_workers: 1
+
+lightning:
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 5000
+
+ callbacks:
+ metrics_over_trainsteps_checkpoint:
+ params:
+ every_n_train_steps: 25000
+
+ image_logger:
+ target: main.ImageLogger
+ params:
+ disabled: False
+ batch_frequency: 1000
+ max_images: 64
+ increase_log_steps: True
+ log_first_step: False
+ log_images_kwargs:
+ use_ema_scope: False
+ N: 64
+ n_rows: 8
+
+ trainer:
+ devices: 0,
+ benchmark: True
+ num_sanity_val_steps: 0
+ accumulate_grad_batches: 1
+ max_epochs: 20
\ No newline at end of file
diff --git a/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml b/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml
new file mode 100644
index 000000000..4c92ccf04
--- /dev/null
+++ b/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml
@@ -0,0 +1,185 @@
+model:
+ base_learning_rate: 1.0e-4
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ scale_factor: 0.13025
+ disable_first_stage_autocast: True
+ log_keys:
+ - txt
+
+ scheduler_config:
+ target: sgm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ]
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ use_fp16: True
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 1, 2, 4 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64
+ num_classes: sequential
+ adm_in_channels: 1792
+ num_heads: 1
+ use_spatial_transformer: true
+ transformer_depth: 1
+ context_dim: 768
+ spatial_transformer_attn_type: softmax-xformers
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ # crossattn cond
+ - is_trainable: True
+ input_key: txt
+ ucg_rate: 0.1
+ legacy_ucg_value: ""
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
+ params:
+ always_return_pooled: True
+ # vector cond
+ - is_trainable: False
+ ucg_rate: 0.1
+ input_key: original_size_as_tuple
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+ # vector cond
+ - is_trainable: False
+ input_key: crop_coords_top_left
+ ucg_rate: 0.1
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+
+ first_stage_config:
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+ params:
+ ckpt_path: CKPT_PATH
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ attn_type: vanilla-xformers
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1, 2, 4, 4 ]
+ num_res_blocks: 2
+ attn_resolutions: [ ]
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+ params:
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+ params:
+ num_idx: 1000
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+ params:
+ num_steps: 50
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+ params:
+ scale: 7.5
+
+data:
+ target: sgm.data.dataset.StableDataModuleFromConfig
+ params:
+ train:
+ datapipeline:
+ urls:
+ # USER: adapt this path the root of your custom dataset
+ - "DATA_PATH"
+ pipeline_config:
+ shardshuffle: 10000
+ sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
+
+ decoders:
+ - "pil"
+
+ postprocessors:
+ - target: sdata.mappers.TorchVisionImageTransforms
+ params:
+ key: 'jpg' # USER: you might wanna adapt this for your custom dataset
+ transforms:
+ - target: torchvision.transforms.Resize
+ params:
+ size: 256
+ interpolation: 3
+ - target: torchvision.transforms.ToTensor
+ - target: sdata.mappers.Rescaler
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+ # USER: you might wanna use non-default parameters due to your custom dataset
+
+ loader:
+ batch_size: 64
+ num_workers: 6
+
+lightning:
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 5000
+
+ callbacks:
+ metrics_over_trainsteps_checkpoint:
+ params:
+ every_n_train_steps: 25000
+
+ image_logger:
+ target: main.ImageLogger
+ params:
+ disabled: False
+ enable_autocast: False
+ batch_frequency: 1000
+ max_images: 8
+ increase_log_steps: True
+ log_first_step: False
+ log_images_kwargs:
+ use_ema_scope: False
+ N: 8
+ n_rows: 2
+
+ trainer:
+ devices: 0,
+ benchmark: True
+ num_sanity_val_steps: 0
+ accumulate_grad_batches: 1
+ max_epochs: 1000
\ No newline at end of file
diff --git a/configs/example_training/txt2img-clipl.yaml b/configs/example_training/txt2img-clipl.yaml
new file mode 100644
index 000000000..1676fef7b
--- /dev/null
+++ b/configs/example_training/txt2img-clipl.yaml
@@ -0,0 +1,186 @@
+model:
+ base_learning_rate: 1.0e-4
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ scale_factor: 0.13025
+ disable_first_stage_autocast: True
+ log_keys:
+ - txt
+
+ scheduler_config:
+ target: sgm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ]
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ use_fp16: True
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 1, 2, 4 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64
+ num_classes: sequential
+ adm_in_channels: 1792
+ num_heads: 1
+ use_spatial_transformer: true
+ transformer_depth: 1
+ context_dim: 768
+ spatial_transformer_attn_type: softmax-xformers
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ # crossattn cond
+ - is_trainable: True
+ input_key: txt
+ ucg_rate: 0.1
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
+ params:
+ always_return_pooled: True
+ # vector cond
+ - is_trainable: False
+ ucg_rate: 0.1
+ input_key: original_size_as_tuple
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+ # vector cond
+ - is_trainable: False
+ input_key: crop_coords_top_left
+ ucg_rate: 0.1
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+
+ first_stage_config:
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+ params:
+ ckpt_path: CKPT_PATH
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ attn_type: vanilla-xformers
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1, 2, 4, 4 ]
+ num_res_blocks: 2
+ attn_resolutions: [ ]
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+ params:
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+ params:
+ num_idx: 1000
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+ params:
+ num_steps: 50
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+ params:
+ scale: 7.5
+
+data:
+ target: sgm.data.dataset.StableDataModuleFromConfig
+ params:
+ train:
+ datapipeline:
+ urls:
+ # USER: adapt this path the root of your custom dataset
+ - "DATA_PATH"
+ pipeline_config:
+ shardshuffle: 10000
+ sample_shuffle: 10000
+
+
+ decoders:
+ - "pil"
+
+ postprocessors:
+ - target: sdata.mappers.TorchVisionImageTransforms
+ params:
+ key: 'jpg' # USER: you might wanna adapt this for your custom dataset
+ transforms:
+ - target: torchvision.transforms.Resize
+ params:
+ size: 256
+ interpolation: 3
+ - target: torchvision.transforms.ToTensor
+ - target: sdata.mappers.Rescaler
+ # USER: you might wanna use non-default parameters due to your custom dataset
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+ # USER: you might wanna use non-default parameters due to your custom dataset
+
+ loader:
+ batch_size: 64
+ num_workers: 6
+
+lightning:
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 5000
+
+ callbacks:
+ metrics_over_trainsteps_checkpoint:
+ params:
+ every_n_train_steps: 25000
+
+ image_logger:
+ target: main.ImageLogger
+ params:
+ disabled: False
+ enable_autocast: False
+ batch_frequency: 1000
+ max_images: 8
+ increase_log_steps: True
+ log_first_step: False
+ log_images_kwargs:
+ use_ema_scope: False
+ N: 8
+ n_rows: 2
+
+ trainer:
+ devices: 0,
+ benchmark: True
+ num_sanity_val_steps: 0
+ accumulate_grad_batches: 1
+ max_epochs: 1000
\ No newline at end of file
diff --git a/configs/inference/sd_2_1.yaml b/configs/inference/sd_2_1.yaml
new file mode 100644
index 000000000..22bb63d19
--- /dev/null
+++ b/configs/inference/sd_2_1.yaml
@@ -0,0 +1,66 @@
+model:
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ scale_factor: 0.18215
+ disable_first_stage_autocast: True
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ use_fp16: True
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [4, 2, 1]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4, 4]
+ num_head_channels: 64
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ # crossattn cond
+ - is_trainable: False
+ input_key: txt
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ freeze: true
+ layer: penultimate
+
+ first_stage_config:
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
\ No newline at end of file
diff --git a/configs/inference/sd_2_1_768.yaml b/configs/inference/sd_2_1_768.yaml
new file mode 100644
index 000000000..71a0a121f
--- /dev/null
+++ b/configs/inference/sd_2_1_768.yaml
@@ -0,0 +1,66 @@
+model:
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ scale_factor: 0.18215
+ disable_first_stage_autocast: True
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.VWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ use_fp16: True
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [4, 2, 1]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4, 4]
+ num_head_channels: 64
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ # crossattn cond
+ - is_trainable: False
+ input_key: txt
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ freeze: true
+ layer: penultimate
+
+ first_stage_config:
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
\ No newline at end of file
diff --git a/configs/inference/sd_xl_base.yaml b/configs/inference/sd_xl_base.yaml
new file mode 100644
index 000000000..8aaf5b6ec
--- /dev/null
+++ b/configs/inference/sd_xl_base.yaml
@@ -0,0 +1,98 @@
+model:
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ scale_factor: 0.13025
+ disable_first_stage_autocast: True
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ adm_in_channels: 2816
+ num_classes: sequential
+ use_checkpoint: True
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [4, 2]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4]
+ num_head_channels: 64
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
+ context_dim: 2048
+ spatial_transformer_attn_type: softmax-xformers
+ legacy: False
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ # crossattn cond
+ - is_trainable: False
+ input_key: txt
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
+ params:
+ layer: hidden
+ layer_idx: 11
+ # crossattn and vector cond
+ - is_trainable: False
+ input_key: txt
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
+ params:
+ arch: ViT-bigG-14
+ version: laion2b_s39b_b160k
+ freeze: True
+ layer: penultimate
+ always_return_pooled: True
+ legacy: False
+ # vector cond
+ - is_trainable: False
+ input_key: original_size_as_tuple
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+ # vector cond
+ - is_trainable: False
+ input_key: crop_coords_top_left
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+ # vector cond
+ - is_trainable: False
+ input_key: target_size_as_tuple
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+
+ first_stage_config:
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ attn_type: vanilla-xformers
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
diff --git a/configs/inference/sd_xl_refiner.yaml b/configs/inference/sd_xl_refiner.yaml
new file mode 100644
index 000000000..cab5fe283
--- /dev/null
+++ b/configs/inference/sd_xl_refiner.yaml
@@ -0,0 +1,91 @@
+model:
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ scale_factor: 0.13025
+ disable_first_stage_autocast: True
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ adm_in_channels: 2560
+ num_classes: sequential
+ use_checkpoint: True
+ in_channels: 4
+ out_channels: 4
+ model_channels: 384
+ attention_resolutions: [4, 2]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4, 4]
+ num_head_channels: 64
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 4
+ context_dim: [1280, 1280, 1280, 1280] # 1280
+ spatial_transformer_attn_type: softmax-xformers
+ legacy: False
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ # crossattn and vector cond
+ - is_trainable: False
+ input_key: txt
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
+ params:
+ arch: ViT-bigG-14
+ version: laion2b_s39b_b160k
+ legacy: False
+ freeze: True
+ layer: penultimate
+ always_return_pooled: True
+ # vector cond
+ - is_trainable: False
+ input_key: original_size_as_tuple
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+ # vector cond
+ - is_trainable: False
+ input_key: crop_coords_top_left
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+ # vector cond
+ - is_trainable: False
+ input_key: aesthetic_score
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by one
+
+ first_stage_config:
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ attn_type: vanilla-xformers
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
diff --git a/data/DejaVuSans.ttf b/data/DejaVuSans.ttf
new file mode 100644
index 000000000..e5f7eecce
Binary files /dev/null and b/data/DejaVuSans.ttf differ
diff --git a/main.py b/main.py
new file mode 100644
index 000000000..66b74d19a
--- /dev/null
+++ b/main.py
@@ -0,0 +1,947 @@
+import argparse
+import datetime
+import glob
+import inspect
+import os
+import sys
+from inspect import Parameter
+from typing import Union
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torchvision
+import wandb
+from PIL import Image
+from matplotlib import pyplot as plt
+from natsort import natsorted
+from omegaconf import OmegaConf
+from packaging import version
+from pytorch_lightning import seed_everything
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.trainer import Trainer
+from pytorch_lightning.utilities import rank_zero_only
+
+from sgm.util import (
+ exists,
+ instantiate_from_config,
+ isheatmap,
+)
+
+MULTINODE_HACKS = True
+
+
+def default_trainer_args():
+ argspec = dict(inspect.signature(Trainer.__init__).parameters)
+ argspec.pop("self")
+ default_args = {
+ param: argspec[param].default
+ for param in argspec
+ if argspec[param] != Parameter.empty
+ }
+ return default_args
+
+
+def get_parser(**parser_kwargs):
+ def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+ parser = argparse.ArgumentParser(**parser_kwargs)
+ parser.add_argument(
+ "-n",
+ "--name",
+ type=str,
+ const=True,
+ default="",
+ nargs="?",
+ help="postfix for logdir",
+ )
+ parser.add_argument(
+ "--no_date",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)",
+ )
+ parser.add_argument(
+ "-r",
+ "--resume",
+ type=str,
+ const=True,
+ default="",
+ nargs="?",
+ help="resume from logdir or checkpoint in logdir",
+ )
+ parser.add_argument(
+ "-b",
+ "--base",
+ nargs="*",
+ metavar="base_config.yaml",
+ help="paths to base configs. Loaded from left-to-right. "
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
+ default=list(),
+ )
+ parser.add_argument(
+ "-t",
+ "--train",
+ type=str2bool,
+ const=True,
+ default=True,
+ nargs="?",
+ help="train",
+ )
+ parser.add_argument(
+ "--no-test",
+ type=str2bool,
+ const=True,
+ default=False,
+ nargs="?",
+ help="disable test",
+ )
+ parser.add_argument(
+ "-p", "--project", help="name of new or path to existing project"
+ )
+ parser.add_argument(
+ "-d",
+ "--debug",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="enable post-mortem debugging",
+ )
+ parser.add_argument(
+ "-s",
+ "--seed",
+ type=int,
+ default=23,
+ help="seed for seed_everything",
+ )
+ parser.add_argument(
+ "-f",
+ "--postfix",
+ type=str,
+ default="",
+ help="post-postfix for default name",
+ )
+ parser.add_argument(
+ "--projectname",
+ type=str,
+ default="stablediffusion",
+ )
+ parser.add_argument(
+ "-l",
+ "--logdir",
+ type=str,
+ default="logs",
+ help="directory for logging dat shit",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="scale base-lr by ngpu * batch_size * n_accumulate",
+ )
+ parser.add_argument(
+ "--legacy_naming",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="name run based on config file name if true, else by whole path",
+ )
+ parser.add_argument(
+ "--enable_tf32",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="enables the TensorFloat32 format both for matmuls and cuDNN for pytorch 1.12",
+ )
+ parser.add_argument(
+ "--startup",
+ type=str,
+ default=None,
+ help="Startuptime from distributed script",
+ )
+ parser.add_argument(
+ "--wandb",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False, # TODO: later default to True
+ help="log to wandb",
+ )
+ parser.add_argument(
+ "--no_base_name",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False, # TODO: later default to True
+ help="log to wandb",
+ )
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help="single checkpoint file to resume from",
+ )
+ default_args = default_trainer_args()
+ for key in default_args:
+ parser.add_argument("--" + key, default=default_args[key])
+ return parser
+
+
+def get_checkpoint_name(logdir):
+ ckpt = os.path.join(logdir, "checkpoints", "last**.ckpt")
+ ckpt = natsorted(glob.glob(ckpt))
+ print('available "last" checkpoints:')
+ print(ckpt)
+ if len(ckpt) > 1:
+ print("got most recent checkpoint")
+ ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1]
+ print(f"Most recent ckpt is {ckpt}")
+ with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f:
+ f.write(ckpt + "\n")
+ try:
+ version = int(ckpt.split("/")[-1].split("-v")[-1].split(".")[0])
+ except Exception as e:
+ print("version confusion but not bad")
+ print(e)
+ version = 1
+ # version = last_version + 1
+ else:
+ # in this case, we only have one "last.ckpt"
+ ckpt = ckpt[0]
+ version = 1
+ melk_ckpt_name = f"last-v{version}.ckpt"
+ print(f"Current melk ckpt name: {melk_ckpt_name}")
+ return ckpt, melk_ckpt_name
+
+
+class SetupCallback(Callback):
+ def __init__(
+ self,
+ resume,
+ now,
+ logdir,
+ ckptdir,
+ cfgdir,
+ config,
+ lightning_config,
+ debug,
+ ckpt_name=None,
+ ):
+ super().__init__()
+ self.resume = resume
+ self.now = now
+ self.logdir = logdir
+ self.ckptdir = ckptdir
+ self.cfgdir = cfgdir
+ self.config = config
+ self.lightning_config = lightning_config
+ self.debug = debug
+ self.ckpt_name = ckpt_name
+
+ def on_exception(self, trainer: pl.Trainer, pl_module, exception):
+ if not self.debug and trainer.global_rank == 0:
+ print("Summoning checkpoint.")
+ if self.ckpt_name is None:
+ ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
+ else:
+ ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
+ trainer.save_checkpoint(ckpt_path)
+
+ def on_fit_start(self, trainer, pl_module):
+ if trainer.global_rank == 0:
+ # Create logdirs and save configs
+ os.makedirs(self.logdir, exist_ok=True)
+ os.makedirs(self.ckptdir, exist_ok=True)
+ os.makedirs(self.cfgdir, exist_ok=True)
+
+ if "callbacks" in self.lightning_config:
+ if (
+ "metrics_over_trainsteps_checkpoint"
+ in self.lightning_config["callbacks"]
+ ):
+ os.makedirs(
+ os.path.join(self.ckptdir, "trainstep_checkpoints"),
+ exist_ok=True,
+ )
+ print("Project config")
+ print(OmegaConf.to_yaml(self.config))
+ if MULTINODE_HACKS:
+ import time
+
+ time.sleep(5)
+ OmegaConf.save(
+ self.config,
+ os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
+ )
+
+ print("Lightning config")
+ print(OmegaConf.to_yaml(self.lightning_config))
+ OmegaConf.save(
+ OmegaConf.create({"lightning": self.lightning_config}),
+ os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
+ )
+
+ else:
+ # ModelCheckpoint callback created log directory --- remove it
+ if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
+ dst, name = os.path.split(self.logdir)
+ dst = os.path.join(dst, "child_runs", name)
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
+ try:
+ os.rename(self.logdir, dst)
+ except FileNotFoundError:
+ pass
+
+
+class ImageLogger(Callback):
+ def __init__(
+ self,
+ batch_frequency,
+ max_images,
+ clamp=True,
+ increase_log_steps=True,
+ rescale=True,
+ disabled=False,
+ log_on_batch_idx=False,
+ log_first_step=False,
+ log_images_kwargs=None,
+ log_before_first_step=False,
+ enable_autocast=True,
+ ):
+ super().__init__()
+ self.enable_autocast = enable_autocast
+ self.rescale = rescale
+ self.batch_freq = batch_frequency
+ self.max_images = max_images
+ self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
+ if not increase_log_steps:
+ self.log_steps = [self.batch_freq]
+ self.clamp = clamp
+ self.disabled = disabled
+ self.log_on_batch_idx = log_on_batch_idx
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+ self.log_first_step = log_first_step
+ self.log_before_first_step = log_before_first_step
+
+ @rank_zero_only
+ def log_local(
+ self,
+ save_dir,
+ split,
+ images,
+ global_step,
+ current_epoch,
+ batch_idx,
+ pl_module: Union[None, pl.LightningModule] = None,
+ ):
+ root = os.path.join(save_dir, "images", split)
+ for k in images:
+ if isheatmap(images[k]):
+ fig, ax = plt.subplots()
+ ax = ax.matshow(
+ images[k].cpu().numpy(), cmap="hot", interpolation="lanczos"
+ )
+ plt.colorbar(ax)
+ plt.axis("off")
+
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
+ k, global_step, current_epoch, batch_idx
+ )
+ os.makedirs(root, exist_ok=True)
+ path = os.path.join(root, filename)
+ plt.savefig(path)
+ plt.close()
+ # TODO: support wandb
+ else:
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
+ if self.rescale:
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ grid = grid.numpy()
+ grid = (grid * 255).astype(np.uint8)
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
+ k, global_step, current_epoch, batch_idx
+ )
+ path = os.path.join(root, filename)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ img = Image.fromarray(grid)
+ img.save(path)
+ if exists(pl_module):
+ assert isinstance(
+ pl_module.logger, WandbLogger
+ ), "logger_log_image only supports WandbLogger currently"
+ pl_module.logger.log_image(
+ key=f"{split}/{k}",
+ images=[
+ img,
+ ],
+ step=pl_module.global_step,
+ )
+
+ @rank_zero_only
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
+ check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
+ if (
+ self.check_frequency(check_idx)
+ and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
+ and callable(pl_module.log_images)
+ and
+ # batch_idx > 5 and
+ self.max_images > 0
+ ):
+ logger = type(pl_module.logger)
+ is_train = pl_module.training
+ if is_train:
+ pl_module.eval()
+
+ gpu_autocast_kwargs = {
+ "enabled": self.enable_autocast, # torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
+ images = pl_module.log_images(
+ batch, split=split, **self.log_images_kwargs
+ )
+
+ for k in images:
+ N = min(images[k].shape[0], self.max_images)
+ if not isheatmap(images[k]):
+ images[k] = images[k][:N]
+ if isinstance(images[k], torch.Tensor):
+ images[k] = images[k].detach().float().cpu()
+ if self.clamp and not isheatmap(images[k]):
+ images[k] = torch.clamp(images[k], -1.0, 1.0)
+
+ self.log_local(
+ pl_module.logger.save_dir,
+ split,
+ images,
+ pl_module.global_step,
+ pl_module.current_epoch,
+ batch_idx,
+ pl_module=pl_module
+ if isinstance(pl_module.logger, WandbLogger)
+ else None,
+ )
+
+ if is_train:
+ pl_module.train()
+
+ def check_frequency(self, check_idx):
+ if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
+ check_idx > 0 or self.log_first_step
+ ):
+ try:
+ self.log_steps.pop(0)
+ except IndexError as e:
+ print(e)
+ pass
+ return True
+ return False
+
+ @rank_zero_only
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+ if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
+ self.log_img(pl_module, batch, batch_idx, split="train")
+
+ @rank_zero_only
+ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+ if self.log_before_first_step and pl_module.global_step == 0:
+ print(f"{self.__class__.__name__}: logging before training")
+ self.log_img(pl_module, batch, batch_idx, split="train")
+
+ @rank_zero_only
+ # def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+ def on_validation_batch_end(
+ self, trainer, pl_module, outputs, batch, batch_idx, **kwargs
+ ):
+ if not self.disabled and pl_module.global_step > 0:
+ self.log_img(pl_module, batch, batch_idx, split="val")
+ if hasattr(pl_module, "calibrate_grad_norm"):
+ if (
+ pl_module.calibrate_grad_norm and batch_idx % 25 == 0
+ ) and batch_idx > 0:
+ self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
+
+
+@rank_zero_only
+def init_wandb(save_dir, opt, config, group_name, name_str):
+ print(f"setting WANDB_DIR to {save_dir}")
+ os.makedirs(save_dir, exist_ok=True)
+
+ os.environ["WANDB_DIR"] = save_dir
+ if opt.debug:
+ wandb.init(project=opt.projectname, mode="offline", group=group_name)
+ else:
+ wandb.init(
+ project=opt.projectname,
+ config=config,
+ settings=wandb.Settings(code_dir="./sgm"),
+ group=group_name,
+ name=name_str,
+ )
+
+
+if __name__ == "__main__":
+ # custom parser to specify config files, train, test and debug mode,
+ # postfix, resume.
+ # `--key value` arguments are interpreted as arguments to the trainer.
+ # `nested.key=value` arguments are interpreted as config parameters.
+ # configs are merged from left-to-right followed by command line parameters.
+
+ # model:
+ # base_learning_rate: float
+ # target: path to lightning module
+ # params:
+ # key: value
+ # data:
+ # target: main.DataModuleFromConfig
+ # params:
+ # batch_size: int
+ # wrap: bool
+ # train:
+ # target: path to train dataset
+ # params:
+ # key: value
+ # validation:
+ # target: path to validation dataset
+ # params:
+ # key: value
+ # test:
+ # target: path to test dataset
+ # params:
+ # key: value
+ # lightning: (optional, has sane defaults and can be specified on cmdline)
+ # trainer:
+ # additional arguments to trainer
+ # logger:
+ # logger to instantiate
+ # modelcheckpoint:
+ # modelcheckpoint to instantiate
+ # callbacks:
+ # callback1:
+ # target: importpath
+ # params:
+ # key: value
+
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+
+ # add cwd for convenience and to make classes in this file available when
+ # running as `python main.py`
+ # (in particular `main.DataModuleFromConfig`)
+ sys.path.append(os.getcwd())
+
+ parser = get_parser()
+
+ opt, unknown = parser.parse_known_args()
+
+ if opt.name and opt.resume:
+ raise ValueError(
+ "-n/--name and -r/--resume cannot be specified both."
+ "If you want to resume training in a new log folder, "
+ "use -n/--name in combination with --resume_from_checkpoint"
+ )
+ melk_ckpt_name = None
+ name = None
+ if opt.resume:
+ if not os.path.exists(opt.resume):
+ raise ValueError("Cannot find {}".format(opt.resume))
+ if os.path.isfile(opt.resume):
+ paths = opt.resume.split("/")
+ # idx = len(paths)-paths[::-1].index("logs")+1
+ # logdir = "/".join(paths[:idx])
+ logdir = "/".join(paths[:-2])
+ ckpt = opt.resume
+ _, melk_ckpt_name = get_checkpoint_name(logdir)
+ else:
+ assert os.path.isdir(opt.resume), opt.resume
+ logdir = opt.resume.rstrip("/")
+ ckpt, melk_ckpt_name = get_checkpoint_name(logdir)
+
+ print("#" * 100)
+ print(f'Resuming from checkpoint "{ckpt}"')
+ print("#" * 100)
+
+ opt.resume_from_checkpoint = ckpt
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
+ opt.base = base_configs + opt.base
+ _tmp = logdir.split("/")
+ nowname = _tmp[-1]
+ else:
+ if opt.name:
+ name = "_" + opt.name
+ elif opt.base:
+ if opt.no_base_name:
+ name = ""
+ else:
+ if opt.legacy_naming:
+ cfg_fname = os.path.split(opt.base[0])[-1]
+ cfg_name = os.path.splitext(cfg_fname)[0]
+ else:
+ assert "configs" in os.path.split(opt.base[0])[0], os.path.split(
+ opt.base[0]
+ )[0]
+ cfg_path = os.path.split(opt.base[0])[0].split(os.sep)[
+ os.path.split(opt.base[0])[0].split(os.sep).index("configs")
+ + 1 :
+ ] # cut away the first one (we assert all configs are in "configs")
+ cfg_name = os.path.splitext(os.path.split(opt.base[0])[-1])[0]
+ cfg_name = "-".join(cfg_path) + f"-{cfg_name}"
+ name = "_" + cfg_name
+ else:
+ name = ""
+ if not opt.no_date:
+ nowname = now + name + opt.postfix
+ else:
+ nowname = name + opt.postfix
+ if nowname.startswith("_"):
+ nowname = nowname[1:]
+ logdir = os.path.join(opt.logdir, nowname)
+ print(f"LOGDIR: {logdir}")
+
+ ckptdir = os.path.join(logdir, "checkpoints")
+ cfgdir = os.path.join(logdir, "configs")
+ seed_everything(opt.seed, workers=True)
+
+ # move before model init, in case a torch.compile(...) is called somewhere
+ if opt.enable_tf32:
+ # pt_version = version.parse(torch.__version__)
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ print(f"Enabling TF32 for PyTorch {torch.__version__}")
+ else:
+ print(f"Using default TF32 settings for PyTorch {torch.__version__}:")
+ print(
+ f"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}"
+ )
+ print(f"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}")
+
+ try:
+ # init and save configs
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
+ cli = OmegaConf.from_dotlist(unknown)
+ config = OmegaConf.merge(*configs, cli)
+ lightning_config = config.pop("lightning", OmegaConf.create())
+ # merge trainer cli with config
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
+
+ # default to gpu
+ trainer_config["accelerator"] = "gpu"
+ #
+ standard_args = default_trainer_args()
+ for k in standard_args:
+ if getattr(opt, k) != standard_args[k]:
+ trainer_config[k] = getattr(opt, k)
+
+ ckpt_resume_path = opt.resume_from_checkpoint
+
+ if not "devices" in trainer_config and trainer_config["accelerator"] != "gpu":
+ del trainer_config["accelerator"]
+ cpu = True
+ else:
+ gpuinfo = trainer_config["devices"]
+ print(f"Running on GPUs {gpuinfo}")
+ cpu = False
+ trainer_opt = argparse.Namespace(**trainer_config)
+ lightning_config.trainer = trainer_config
+
+ # model
+ model = instantiate_from_config(config.model)
+
+ # trainer and callbacks
+ trainer_kwargs = dict()
+
+ # default logger configs
+ default_logger_cfgs = {
+ "wandb": {
+ "target": "pytorch_lightning.loggers.WandbLogger",
+ "params": {
+ "name": nowname,
+ # "save_dir": logdir,
+ "offline": opt.debug,
+ "id": nowname,
+ "project": opt.projectname,
+ "log_model": False,
+ # "dir": logdir,
+ },
+ },
+ "csv": {
+ "target": "pytorch_lightning.loggers.CSVLogger",
+ "params": {
+ "name": "testtube", # hack for sbord fanatics
+ "save_dir": logdir,
+ },
+ },
+ }
+ default_logger_cfg = default_logger_cfgs["wandb" if opt.wandb else "csv"]
+ if opt.wandb:
+ # TODO change once leaving "swiffer" config directory
+ try:
+ group_name = nowname.split(now)[-1].split("-")[1]
+ except:
+ group_name = nowname
+ default_logger_cfg["params"]["group"] = group_name
+ init_wandb(
+ os.path.join(os.getcwd(), logdir),
+ opt=opt,
+ group_name=group_name,
+ config=config,
+ name_str=nowname,
+ )
+ if "logger" in lightning_config:
+ logger_cfg = lightning_config.logger
+ else:
+ logger_cfg = OmegaConf.create()
+ logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
+
+ # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
+ # specify which metric is used to determine best models
+ default_modelckpt_cfg = {
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
+ "params": {
+ "dirpath": ckptdir,
+ "filename": "{epoch:06}",
+ "verbose": True,
+ "save_last": True,
+ },
+ }
+ if hasattr(model, "monitor"):
+ print(f"Monitoring {model.monitor} as checkpoint metric.")
+ default_modelckpt_cfg["params"]["monitor"] = model.monitor
+ default_modelckpt_cfg["params"]["save_top_k"] = 3
+
+ if "modelcheckpoint" in lightning_config:
+ modelckpt_cfg = lightning_config.modelcheckpoint
+ else:
+ modelckpt_cfg = OmegaConf.create()
+ modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
+ print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
+
+ # https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html
+ # default to ddp if not further specified
+ default_strategy_config = {"target": "pytorch_lightning.strategies.DDPStrategy"}
+
+ if "strategy" in lightning_config:
+ strategy_cfg = lightning_config.strategy
+ else:
+ strategy_cfg = OmegaConf.create()
+ default_strategy_config["params"] = {
+ "find_unused_parameters": False,
+ # "static_graph": True,
+ # "ddp_comm_hook": default.fp16_compress_hook # TODO: experiment with this, also for DDPSharded
+ }
+ strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg)
+ print(
+ f"strategy config: \n ++++++++++++++ \n {strategy_cfg} \n ++++++++++++++ "
+ )
+ trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
+
+ # add callback which sets up log directory
+ default_callbacks_cfg = {
+ "setup_callback": {
+ "target": "main.SetupCallback",
+ "params": {
+ "resume": opt.resume,
+ "now": now,
+ "logdir": logdir,
+ "ckptdir": ckptdir,
+ "cfgdir": cfgdir,
+ "config": config,
+ "lightning_config": lightning_config,
+ "debug": opt.debug,
+ "ckpt_name": melk_ckpt_name,
+ },
+ },
+ "image_logger": {
+ "target": "main.ImageLogger",
+ "params": {"batch_frequency": 1000, "max_images": 4, "clamp": True},
+ },
+ "learning_rate_logger": {
+ "target": "pytorch_lightning.callbacks.LearningRateMonitor",
+ "params": {
+ "logging_interval": "step",
+ # "log_momentum": True
+ },
+ },
+ }
+ if version.parse(pl.__version__) >= version.parse("1.4.0"):
+ default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
+
+ if "callbacks" in lightning_config:
+ callbacks_cfg = lightning_config.callbacks
+ else:
+ callbacks_cfg = OmegaConf.create()
+
+ if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
+ print(
+ "Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
+ )
+ default_metrics_over_trainsteps_ckpt_dict = {
+ "metrics_over_trainsteps_checkpoint": {
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
+ "params": {
+ "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
+ "filename": "{epoch:06}-{step:09}",
+ "verbose": True,
+ "save_top_k": -1,
+ "every_n_train_steps": 10000,
+ "save_weights_only": True,
+ },
+ }
+ }
+ default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
+
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
+ if "ignore_keys_callback" in callbacks_cfg and ckpt_resume_path is not None:
+ callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = ckpt_resume_path
+ elif "ignore_keys_callback" in callbacks_cfg:
+ del callbacks_cfg["ignore_keys_callback"]
+
+ trainer_kwargs["callbacks"] = [
+ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
+ ]
+ if not "plugins" in trainer_kwargs:
+ trainer_kwargs["plugins"] = list()
+
+ # cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs)
+ trainer_opt = vars(trainer_opt)
+ trainer_kwargs = {
+ key: val for key, val in trainer_kwargs.items() if key not in trainer_opt
+ }
+ trainer = Trainer(**trainer_opt, **trainer_kwargs)
+
+ trainer.logdir = logdir ###
+
+ # data
+ data = instantiate_from_config(config.data)
+ # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
+ # calling these ourselves should not be necessary but it is.
+ # lightning still takes care of proper multiprocessing though
+ data.prepare_data()
+ # data.setup()
+ print("#### Data #####")
+ try:
+ for k in data.datasets:
+ print(
+ f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
+ )
+ except:
+ print("datasets not yet initialized.")
+
+ # configure learning rate
+ if "batch_size" in config.data.params:
+ bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
+ else:
+ bs, base_lr = (
+ config.data.params.train.loader.batch_size,
+ config.model.base_learning_rate,
+ )
+ if not cpu:
+ ngpu = len(lightning_config.trainer.devices.strip(",").split(","))
+ else:
+ ngpu = 1
+ if "accumulate_grad_batches" in lightning_config.trainer:
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
+ else:
+ accumulate_grad_batches = 1
+ print(f"accumulate_grad_batches = {accumulate_grad_batches}")
+ lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
+ if opt.scale_lr:
+ model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
+ print(
+ "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
+ model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
+ )
+ )
+ else:
+ model.learning_rate = base_lr
+ print("++++ NOT USING LR SCALING ++++")
+ print(f"Setting learning rate to {model.learning_rate:.2e}")
+
+ # allow checkpointing via USR1
+ def melk(*args, **kwargs):
+ # run all checkpoint hooks
+ if trainer.global_rank == 0:
+ print("Summoning checkpoint.")
+ if melk_ckpt_name is None:
+ ckpt_path = os.path.join(ckptdir, "last.ckpt")
+ else:
+ ckpt_path = os.path.join(ckptdir, melk_ckpt_name)
+ trainer.save_checkpoint(ckpt_path)
+
+ def divein(*args, **kwargs):
+ if trainer.global_rank == 0:
+ import pudb
+
+ pudb.set_trace()
+
+ import signal
+
+ signal.signal(signal.SIGUSR1, melk)
+ signal.signal(signal.SIGUSR2, divein)
+
+ # run
+ if opt.train:
+ try:
+ trainer.fit(model, data, ckpt_path=ckpt_resume_path)
+ except Exception:
+ if not opt.debug:
+ melk()
+ raise
+ if not opt.no_test and not trainer.interrupted:
+ trainer.test(model, data)
+ except RuntimeError as err:
+ if MULTINODE_HACKS:
+ import requests
+ import datetime
+ import os
+ import socket
+
+ device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
+ hostname = socket.gethostname()
+ ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
+ resp = requests.get("http://169.254.169.254/latest/meta-data/instance-id")
+ print(
+ f"ERROR at {ts} on {hostname}/{resp.text} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}",
+ flush=True,
+ )
+ raise err
+ except Exception:
+ if opt.debug and trainer.global_rank == 0:
+ try:
+ import pudb as debugger
+ except ImportError:
+ import pdb as debugger
+ debugger.post_mortem()
+ raise
+ finally:
+ # move newly created debug project to debug_runs
+ if opt.debug and not opt.resume and trainer.global_rank == 0:
+ dst, name = os.path.split(logdir)
+ dst = os.path.join(dst, "debug_runs", name)
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
+ os.rename(logdir, dst)
+
+ if opt.wandb:
+ wandb.finish()
+ # if trainer.global_rank == 0:
+ # print(trainer.profiler.summary())
diff --git a/requirements_pt13.txt b/requirements_pt13.txt
new file mode 100644
index 000000000..3d5b117c3
--- /dev/null
+++ b/requirements_pt13.txt
@@ -0,0 +1,41 @@
+omegaconf
+einops
+fire
+tqdm
+pillow
+numpy
+webdataset>=0.2.33
+--extra-index-url https://download.pytorch.org/whl/cu117
+torch==1.13.1+cu117
+xformers==0.0.16
+torchaudio==0.13.1
+torchvision==0.14.1+cu117
+torchmetrics
+opencv-python==4.6.0.66
+fairscale
+pytorch-lightning==1.8.5
+fsspec
+kornia==0.6.9
+matplotlib
+natsort
+tensorboardx==2.5.1
+open-clip-torch
+chardet
+scipy
+pandas
+pudb
+pyyaml
+urllib3<1.27,>=1.25.4
+streamlit>=0.73.1
+timm
+tokenizers==0.12.1
+torchdata==0.5.1
+transformers==4.19.1
+onnx<=1.12.0
+triton
+wandb
+invisible-watermark
+-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+-e git+https://github.com/openai/CLIP.git@main#egg=clip
+-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
+-e .
\ No newline at end of file
diff --git a/requirements_pt2.txt b/requirements_pt2.txt
new file mode 100644
index 000000000..9988b9084
--- /dev/null
+++ b/requirements_pt2.txt
@@ -0,0 +1,41 @@
+omegaconf
+einops
+fire
+tqdm
+pillow
+numpy
+webdataset>=0.2.33
+ninja
+torch
+matplotlib
+torchaudio>=2.0.2
+torchmetrics
+torchvision>=0.15.2
+opencv-python==4.6.0.66
+fairscale
+pytorch-lightning==2.0.1
+fire
+fsspec
+kornia==0.6.9
+natsort
+open-clip-torch
+chardet==5.1.0
+tensorboardx==2.6
+pandas
+pudb
+pyyaml
+urllib3<1.27,>=1.25.4
+scipy
+streamlit>=0.73.1
+timm
+tokenizers==0.12.1
+transformers==4.19.1
+triton==2.0.0
+torchdata==0.6.1
+wandb
+invisible-watermark
+xformers
+-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+-e git+https://github.com/openai/CLIP.git@main#egg=clip
+-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
+-e .
\ No newline at end of file
diff --git a/scripts/demo/detect.py b/scripts/demo/detect.py
new file mode 100644
index 000000000..823ae8d39
--- /dev/null
+++ b/scripts/demo/detect.py
@@ -0,0 +1,157 @@
+import argparse
+
+import cv2
+import numpy as np
+
+try:
+ from imwatermark import WatermarkDecoder
+except ImportError as e:
+ try:
+ # Assume some of the other dependencies such as torch are not fulfilled
+ # import file without loading unnecessary libraries.
+ import importlib.util
+ import sys
+
+ spec = importlib.util.find_spec("imwatermark.maxDct")
+ assert spec is not None
+ maxDct = importlib.util.module_from_spec(spec)
+ sys.modules["maxDct"] = maxDct
+ spec.loader.exec_module(maxDct)
+
+ class WatermarkDecoder(object):
+ """A minimal version of
+ https://github.com/ShieldMnt/invisible-watermark/blob/main/imwatermark/watermark.py
+ to only reconstruct bits using dwtDct"""
+
+ def __init__(self, wm_type="bytes", length=0):
+ assert wm_type == "bits", "Only bits defined in minimal import"
+ self._wmType = wm_type
+ self._wmLen = length
+
+ def reconstruct(self, bits):
+ if len(bits) != self._wmLen:
+ raise RuntimeError("bits are not matched with watermark length")
+
+ return bits
+
+ def decode(self, cv2Image, method="dwtDct", **configs):
+ (r, c, channels) = cv2Image.shape
+ if r * c < 256 * 256:
+ raise RuntimeError("image too small, should be larger than 256x256")
+
+ bits = []
+ assert method == "dwtDct"
+ embed = maxDct.EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
+ bits = embed.decode(cv2Image)
+ return self.reconstruct(bits)
+
+ except:
+ raise e
+
+
+# A fixed 48-bit message that was choosen at random
+# WATERMARK_MESSAGE = 0xB3EC907BB19E
+WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
+# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
+MATCH_VALUES = [
+ [27, "No watermark detected"],
+ [33, "Partial watermark match. Cannot determine with certainty."],
+ [
+ 35,
+ (
+ "Likely watermarked. In our test 0.02% of real images were "
+ 'falsely detected as "Likely watermarked"'
+ ),
+ ],
+ [
+ 49,
+ (
+ "Very likely watermarked. In our test no real images were "
+ 'falsely detected as "Very likely watermarked"'
+ ),
+ ],
+]
+
+
+class GetWatermarkMatch:
+ def __init__(self, watermark):
+ self.watermark = watermark
+ self.num_bits = len(self.watermark)
+ self.decoder = WatermarkDecoder("bits", self.num_bits)
+
+ def __call__(self, x: np.ndarray) -> np.ndarray:
+ """
+ Detects the number of matching bits the predefined watermark with one
+ or multiple images. Images should be in cv2 format, e.g. h x w x c.
+
+ Args:
+ x: ([B], h w, c) in range [0, 255]
+
+ Returns:
+ number of matched bits ([B],)
+ """
+ squeeze = len(x.shape) == 3
+ if squeeze:
+ x = x[None, ...]
+ x = np.flip(x, axis=-1)
+
+ bs = x.shape[0]
+ detected = np.empty((bs, self.num_bits), dtype=bool)
+ for k in range(bs):
+ detected[k] = self.decoder.decode(x[k], "dwtDct")
+ result = np.sum(detected == self.watermark, axis=-1)
+ if squeeze:
+ return result[0]
+ else:
+ return result
+
+
+get_watermark_match = GetWatermarkMatch(WATERMARK_BITS)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "filename",
+ nargs="+",
+ type=str,
+ help="Image files to check for watermarks",
+ )
+ opts = parser.parse_args()
+
+ print(
+ """
+ This script tries to detect watermarked images. Please be aware of
+ the following:
+ - As the watermark is supposed to be invisible, there is the risk that
+ watermarked images may not be detected.
+ - To maximize the chance of detection make sure that the image has the same
+ dimensions as when the watermark was applied (most likely 1024x1024
+ or 512x512).
+ - Specific image manipulation may drastically decrease the chance that
+ watermarks can be detected.
+ - There is also the chance that an image has the characteristics of the
+ watermark by chance.
+ - The watermark script is public, anybody may watermark any images, and
+ could therefore claim it to be generated.
+ - All numbers below are based on a test using 10,000 images without any
+ modifications after applying the watermark.
+ """
+ )
+
+ for fn in opts.filename:
+ image = cv2.imread(fn)
+ if image is None:
+ print(f"Couldn't read {fn}. Skipping")
+ continue
+
+ num_bits = get_watermark_match(image)
+ k = 0
+ while num_bits > MATCH_VALUES[k][0]:
+ k += 1
+ print(
+ f"{fn}: {MATCH_VALUES[k][1]}",
+ f"Bits that matched the watermark {num_bits} from {len(WATERMARK_BITS)}\n",
+ sep="\n\t",
+ )
diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py
new file mode 100644
index 000000000..0aa1a424b
--- /dev/null
+++ b/scripts/demo/sampling.py
@@ -0,0 +1,328 @@
+from pytorch_lightning import seed_everything
+from scripts.demo.streamlit_helpers import *
+from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
+
+SAVE_PATH = "outputs/demo/txt2img/"
+
+SD_XL_BASE_RATIOS = {
+ "0.5": (704, 1408),
+ "0.52": (704, 1344),
+ "0.57": (768, 1344),
+ "0.6": (768, 1280),
+ "0.68": (832, 1216),
+ "0.72": (832, 1152),
+ "0.78": (896, 1152),
+ "0.82": (896, 1088),
+ "0.88": (960, 1088),
+ "0.94": (960, 1024),
+ "1.0": (1024, 1024),
+ "1.07": (1024, 960),
+ "1.13": (1088, 960),
+ "1.21": (1088, 896),
+ "1.29": (1152, 896),
+ "1.38": (1152, 832),
+ "1.46": (1216, 832),
+ "1.67": (1280, 768),
+ "1.75": (1344, 768),
+ "1.91": (1344, 704),
+ "2.0": (1408, 704),
+ "2.09": (1472, 704),
+ "2.4": (1536, 640),
+ "2.5": (1600, 640),
+ "2.89": (1664, 576),
+ "3.0": (1728, 576),
+}
+
+VERSION2SPECS = {
+ "SD-XL base": {
+ "H": 1024,
+ "W": 1024,
+ "C": 4,
+ "f": 8,
+ "is_legacy": False,
+ "config": "configs/inference/sd_xl_base.yaml",
+ "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
+ "is_guided": True,
+ },
+ "sd-2.1": {
+ "H": 512,
+ "W": 512,
+ "C": 4,
+ "f": 8,
+ "is_legacy": True,
+ "config": "configs/inference/sd_2_1.yaml",
+ "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
+ "is_guided": True,
+ },
+ "sd-2.1-768": {
+ "H": 768,
+ "W": 768,
+ "C": 4,
+ "f": 8,
+ "is_legacy": True,
+ "config": "configs/inference/sd_2_1_768.yaml",
+ "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
+ },
+ "SDXL-Refiner": {
+ "H": 1024,
+ "W": 1024,
+ "C": 4,
+ "f": 8,
+ "is_legacy": True,
+ "config": "configs/inference/sd_xl_refiner.yaml",
+ "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
+ "is_guided": True,
+ },
+}
+
+
+def load_img(display=True, key=None, device="cuda"):
+ image = get_interactive_image(key=key)
+ if image is None:
+ return None
+ if display:
+ st.image(image)
+ w, h = image.size
+ print(f"loaded input image of size ({w}, {h})")
+ width, height = map(
+ lambda x: x - x % 64, (w, h)
+ ) # resize to integer multiple of 64
+ image = image.resize((width, height))
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+ return image.to(device)
+
+
+def run_txt2img(
+ state, version, version_dict, is_legacy=False, return_latents=False, filter=None
+):
+ if version == "SD-XL base":
+ ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
+ W, H = SD_XL_BASE_RATIOS[ratio]
+ else:
+ H = st.sidebar.number_input(
+ "H", value=version_dict["H"], min_value=64, max_value=2048
+ )
+ W = st.sidebar.number_input(
+ "W", value=version_dict["W"], min_value=64, max_value=2048
+ )
+ C = version_dict["C"]
+ F = version_dict["f"]
+
+ init_dict = {
+ "orig_width": W,
+ "orig_height": H,
+ "target_width": W,
+ "target_height": H,
+ }
+ value_dict = init_embedder_options(
+ get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
+ init_dict,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ )
+ num_rows, num_cols, sampler = init_sampling(
+ use_identity_guider=not version_dict["is_guided"]
+ )
+
+ num_samples = num_rows * num_cols
+
+ if st.button("Sample"):
+ st.write(f"**Model I:** {version}")
+ out = do_sample(
+ state["model"],
+ sampler,
+ value_dict,
+ num_samples,
+ H,
+ W,
+ C,
+ F,
+ force_uc_zero_embeddings=["txt"] if not is_legacy else [],
+ return_latents=return_latents,
+ filter=filter,
+ )
+ return out
+
+
+def run_img2img(
+ state, version_dict, is_legacy=False, return_latents=False, filter=None
+):
+ img = load_img()
+ if img is None:
+ return None
+ H, W = img.shape[2], img.shape[3]
+
+ init_dict = {
+ "orig_width": W,
+ "orig_height": H,
+ "target_width": W,
+ "target_height": H,
+ }
+ value_dict = init_embedder_options(
+ get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
+ init_dict,
+ )
+ strength = st.number_input(
+ "**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
+ )
+ num_rows, num_cols, sampler = init_sampling(
+ img2img_strength=strength,
+ use_identity_guider=not version_dict["is_guided"],
+ )
+ num_samples = num_rows * num_cols
+
+ if st.button("Sample"):
+ out = do_img2img(
+ repeat(img, "1 ... -> n ...", n=num_samples),
+ state["model"],
+ sampler,
+ value_dict,
+ num_samples,
+ force_uc_zero_embeddings=["txt"] if not is_legacy else [],
+ return_latents=return_latents,
+ filter=filter,
+ )
+ return out
+
+
+def apply_refiner(
+ input,
+ state,
+ sampler,
+ num_samples,
+ prompt,
+ negative_prompt,
+ filter=None,
+):
+ init_dict = {
+ "orig_width": input.shape[3] * 8,
+ "orig_height": input.shape[2] * 8,
+ "target_width": input.shape[3] * 8,
+ "target_height": input.shape[2] * 8,
+ }
+
+ value_dict = init_dict
+ value_dict["prompt"] = prompt
+ value_dict["negative_prompt"] = negative_prompt
+
+ value_dict["crop_coords_top"] = 0
+ value_dict["crop_coords_left"] = 0
+
+ value_dict["aesthetic_score"] = 6.0
+ value_dict["negative_aesthetic_score"] = 2.5
+
+ st.warning(f"refiner input shape: {input.shape}")
+ samples = do_img2img(
+ input,
+ state["model"],
+ sampler,
+ value_dict,
+ num_samples,
+ skip_encode=True,
+ filter=filter,
+ )
+
+ return samples
+
+
+if __name__ == "__main__":
+ st.title("Stable Diffusion")
+ version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
+ version_dict = VERSION2SPECS[version]
+ mode = st.radio("Mode", ("txt2img", "img2img"), 0)
+ st.write("__________________________")
+
+ if version == "SD-XL base":
+ add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
+ st.write("__________________________")
+ else:
+ add_pipeline = False
+
+ filter = DeepFloydDataFiltering(verbose=False)
+
+ seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
+ seed_everything(seed)
+
+ save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
+
+ state = init_st(version_dict)
+ if state["msg"]:
+ st.info(state["msg"])
+ model = state["model"]
+
+ is_legacy = version_dict["is_legacy"]
+
+ prompt = st.text_input(
+ "prompt",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ )
+ if is_legacy:
+ negative_prompt = st.text_input("negative prompt", "")
+ else:
+ negative_prompt = "" # which is unused
+
+ if add_pipeline:
+ st.write("__________________________")
+
+ version2 = "SDXL-Refiner"
+ st.warning(
+ f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
+ )
+ st.write("**Refiner Options:**")
+
+ version_dict2 = VERSION2SPECS[version2]
+ state2 = init_st(version_dict2)
+ st.info(state2["msg"])
+
+ stage2strength = st.number_input(
+ "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
+ )
+
+ sampler2 = init_sampling(
+ key=2,
+ img2img_strength=stage2strength,
+ use_identity_guider=not version_dict["is_guided"],
+ get_num_samples=False,
+ )
+ st.write("__________________________")
+
+ if mode == "txt2img":
+ out = run_txt2img(
+ state,
+ version,
+ version_dict,
+ is_legacy=is_legacy,
+ return_latents=add_pipeline,
+ filter=filter,
+ )
+ elif mode == "img2img":
+ out = run_img2img(
+ state,
+ version_dict,
+ is_legacy=is_legacy,
+ return_latents=add_pipeline,
+ filter=filter,
+ )
+ else:
+ raise ValueError(f"unknown mode {mode}")
+ if isinstance(out, (tuple, list)):
+ samples, samples_z = out
+ else:
+ samples = out
+
+ if add_pipeline:
+ st.write("**Running Refinement Stage**")
+ samples = apply_refiner(
+ samples_z,
+ state2,
+ sampler2,
+ samples_z.shape[0],
+ prompt=prompt,
+ negative_prompt=negative_prompt if is_legacy else "",
+ filter=filter,
+ )
+
+ if save_locally and samples is not None:
+ perform_save_locally(save_path, samples)
diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py
new file mode 100644
index 000000000..ddc9c6ba5
--- /dev/null
+++ b/scripts/demo/streamlit_helpers.py
@@ -0,0 +1,668 @@
+import os
+from typing import Union, List
+
+import math
+import numpy as np
+import streamlit as st
+import torch
+from PIL import Image
+from einops import rearrange, repeat
+from imwatermark import WatermarkEncoder
+from omegaconf import OmegaConf, ListConfig
+from torch import autocast
+from torchvision import transforms
+from torchvision.utils import make_grid
+from safetensors.torch import load_file as load_safetensors
+
+from sgm.modules.diffusionmodules.sampling import (
+ EulerEDMSampler,
+ HeunEDMSampler,
+ EulerAncestralSampler,
+ DPMPP2SAncestralSampler,
+ DPMPP2MSampler,
+ LinearMultistepSampler,
+)
+from sgm.util import append_dims
+from sgm.util import instantiate_from_config
+
+
+class WatermarkEmbedder:
+ def __init__(self, watermark):
+ self.watermark = watermark
+ self.num_bits = len(WATERMARK_BITS)
+ self.encoder = WatermarkEncoder()
+ self.encoder.set_watermark("bits", self.watermark)
+
+ def __call__(self, image: torch.Tensor):
+ """
+ Adds a predefined watermark to the input image
+
+ Args:
+ image: ([N,] B, C, H, W) in range [0, 1]
+
+ Returns:
+ same as input but watermarked
+ """
+ # watermarking libary expects input as cv2 format
+ squeeze = len(image.shape) == 4
+ if squeeze:
+ image = image[None, ...]
+ n = image.shape[0]
+ image_np = rearrange(
+ (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
+ ).numpy()
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
+ for k in range(image_np.shape[0]):
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
+ image = torch.from_numpy(
+ rearrange(image_np, "(n b) h w c -> n b c h w", n=n)
+ ).to(image.device)
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
+ if squeeze:
+ image = image[0]
+ return image
+
+
+# A fixed 48-bit message that was choosen at random
+# WATERMARK_MESSAGE = 0xB3EC907BB19E
+WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
+# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
+embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
+
+
+@st.cache_resource()
+def init_st(version_dict, load_ckpt=True):
+ state = dict()
+ if not "model" in state:
+ config = version_dict["config"]
+ ckpt = version_dict["ckpt"]
+
+ config = OmegaConf.load(config)
+ model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
+
+ state["msg"] = msg
+ state["model"] = model
+ state["ckpt"] = ckpt if load_ckpt else None
+ state["config"] = config
+ return state
+
+
+def load_model_from_config(config, ckpt=None, verbose=True):
+ model = instantiate_from_config(config.model)
+
+ if ckpt is not None:
+ print(f"Loading model from {ckpt}")
+ if ckpt.endswith("ckpt"):
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ global_step = pl_sd["global_step"]
+ st.info(f"loaded ckpt from global step {global_step}")
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ elif ckpt.endswith("safetensors"):
+ sd = load_safetensors(ckpt)
+ else:
+ raise NotImplementedError
+
+ msg = None
+
+ m, u = model.load_state_dict(sd, strict=False)
+
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+ else:
+ msg = None
+
+ model.cuda()
+ model.eval()
+ return model, msg
+
+
+def get_unique_embedder_keys_from_conditioner(conditioner):
+ return list(set([x.input_key for x in conditioner.embedders]))
+
+
+def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
+ # Hardcoded demo settings; might undergo some changes in the future
+
+ value_dict = {}
+ for key in keys:
+ if key == "txt":
+ if prompt is None:
+ prompt = st.text_input(
+ "Prompt", "A professional photograph of an astronaut riding a pig"
+ )
+ if negative_prompt is None:
+ negative_prompt = st.text_input("Negative prompt", "")
+
+ value_dict["prompt"] = prompt
+ value_dict["negative_prompt"] = negative_prompt
+
+ if key == "original_size_as_tuple":
+ orig_width = st.number_input(
+ "orig_width",
+ value=init_dict["orig_width"],
+ min_value=16,
+ )
+ orig_height = st.number_input(
+ "orig_height",
+ value=init_dict["orig_height"],
+ min_value=16,
+ )
+
+ value_dict["orig_width"] = orig_width
+ value_dict["orig_height"] = orig_height
+
+ if key == "crop_coords_top_left":
+ crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
+ crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
+
+ value_dict["crop_coords_top"] = crop_coord_top
+ value_dict["crop_coords_left"] = crop_coord_left
+
+ if key == "aesthetic_score":
+ value_dict["aesthetic_score"] = 6.0
+ value_dict["negative_aesthetic_score"] = 2.5
+
+ if key == "target_size_as_tuple":
+ target_width = st.number_input(
+ "target_width",
+ value=init_dict["target_width"],
+ min_value=16,
+ )
+ target_height = st.number_input(
+ "target_height",
+ value=init_dict["target_height"],
+ min_value=16,
+ )
+
+ value_dict["target_width"] = target_width
+ value_dict["target_height"] = target_height
+
+ return value_dict
+
+
+def perform_save_locally(save_path, samples):
+ os.makedirs(os.path.join(save_path), exist_ok=True)
+ base_count = len(os.listdir(os.path.join(save_path)))
+ samples = embed_watemark(samples)
+ for sample in samples:
+ sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
+ Image.fromarray(sample.astype(np.uint8)).save(
+ os.path.join(save_path, f"{base_count:09}.png")
+ )
+ base_count += 1
+
+
+def init_save_locally(_dir, init_value: bool = False):
+ save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
+ if save_locally:
+ save_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
+ else:
+ save_path = None
+
+ return save_locally, save_path
+
+
+class Img2ImgDiscretizationWrapper:
+ """
+ wraps a discretizer, and prunes the sigmas
+ params:
+ strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
+ """
+
+ def __init__(self, discretization, strength: float = 1.0):
+ self.discretization = discretization
+ self.strength = strength
+ assert 0.0 <= self.strength <= 1.0
+
+ def __call__(self, *args, **kwargs):
+ # sigmas start large first, and decrease then
+ sigmas = self.discretization(*args, **kwargs)
+ print(f"sigmas after discretization, before pruning img2img: ", sigmas)
+ sigmas = torch.flip(sigmas, (0,))
+ sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
+ print("prune index:", max(int(self.strength * len(sigmas)), 1))
+ sigmas = torch.flip(sigmas, (0,))
+ print(f"sigmas after pruning: ", sigmas)
+ return sigmas
+
+
+def get_guider(key):
+ guider = st.sidebar.selectbox(
+ f"Discretization #{key}",
+ [
+ "VanillaCFG",
+ "IdentityGuider",
+ ],
+ )
+
+ if guider == "IdentityGuider":
+ guider_config = {
+ "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
+ }
+ elif guider == "VanillaCFG":
+ scale = st.number_input(
+ f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
+ )
+
+ thresholder = st.sidebar.selectbox(
+ f"Thresholder #{key}",
+ [
+ "None",
+ ],
+ )
+
+ if thresholder == "None":
+ dyn_thresh_config = {
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
+ }
+ else:
+ raise NotImplementedError
+
+ guider_config = {
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
+ "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
+ }
+ else:
+ raise NotImplementedError
+ return guider_config
+
+
+def init_sampling(
+ key=1, img2img_strength=1.0, use_identity_guider=False, get_num_samples=True
+):
+ if get_num_samples:
+ num_rows = 1
+ num_cols = st.number_input(
+ f"num cols #{key}", value=2, min_value=1, max_value=10
+ )
+
+ steps = st.sidebar.number_input(
+ f"steps #{key}", value=50, min_value=1, max_value=1000
+ )
+ sampler = st.sidebar.selectbox(
+ f"Sampler #{key}",
+ [
+ "EulerEDMSampler",
+ "HeunEDMSampler",
+ "EulerAncestralSampler",
+ "DPMPP2SAncestralSampler",
+ "DPMPP2MSampler",
+ "LinearMultistepSampler",
+ ],
+ 0,
+ )
+ discretization = st.sidebar.selectbox(
+ f"Discretization #{key}",
+ [
+ "LegacyDDPMDiscretization",
+ "EDMDiscretization",
+ ],
+ )
+
+ discretization_config = get_discretization(discretization, key=key)
+
+ guider_config = get_guider(key=key)
+
+ sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
+ if img2img_strength < 1.0:
+ st.warning(
+ f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
+ )
+ sampler.discretization = Img2ImgDiscretizationWrapper(
+ sampler.discretization, strength=img2img_strength
+ )
+ if get_num_samples:
+ return num_rows, num_cols, sampler
+ return sampler
+
+
+def get_discretization(discretization, key=1):
+ if discretization == "LegacyDDPMDiscretization":
+ use_new_range = st.checkbox(f"Start from highest noise level? #{key}", False)
+ discretization_config = {
+ "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
+ "params": {"legacy_range": not use_new_range},
+ }
+ elif discretization == "EDMDiscretization":
+ sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292
+ sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146
+ rho = st.number_input(f"rho #{key}", value=3.0)
+ discretization_config = {
+ "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
+ "params": {
+ "sigma_min": sigma_min,
+ "sigma_max": sigma_max,
+ "rho": rho,
+ },
+ }
+
+ return discretization_config
+
+
+def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
+ if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
+ s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
+ s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
+ s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
+ s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
+
+ if sampler_name == "EulerEDMSampler":
+ sampler = EulerEDMSampler(
+ num_steps=steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ s_churn=s_churn,
+ s_tmin=s_tmin,
+ s_tmax=s_tmax,
+ s_noise=s_noise,
+ verbose=True,
+ )
+ elif sampler_name == "HeunEDMSampler":
+ sampler = HeunEDMSampler(
+ num_steps=steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ s_churn=s_churn,
+ s_tmin=s_tmin,
+ s_tmax=s_tmax,
+ s_noise=s_noise,
+ verbose=True,
+ )
+ elif (
+ sampler_name == "EulerAncestralSampler"
+ or sampler_name == "DPMPP2SAncestralSampler"
+ ):
+ s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
+ eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
+
+ if sampler_name == "EulerAncestralSampler":
+ sampler = EulerAncestralSampler(
+ num_steps=steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ eta=eta,
+ s_noise=s_noise,
+ verbose=True,
+ )
+ elif sampler_name == "DPMPP2SAncestralSampler":
+ sampler = DPMPP2SAncestralSampler(
+ num_steps=steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ eta=eta,
+ s_noise=s_noise,
+ verbose=True,
+ )
+ elif sampler_name == "DPMPP2MSampler":
+ sampler = DPMPP2MSampler(
+ num_steps=steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ verbose=True,
+ )
+ elif sampler_name == "LinearMultistepSampler":
+ order = st.sidebar.number_input("order", value=4, min_value=1)
+ sampler = LinearMultistepSampler(
+ num_steps=steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ order=order,
+ verbose=True,
+ )
+ else:
+ raise ValueError(f"unknown sampler {sampler_name}!")
+
+ return sampler
+
+
+def get_interactive_image(key=None) -> Image.Image:
+ image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
+ if image is not None:
+ image = Image.open(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ return image
+
+
+def load_img(display=True, key=None):
+ image = get_interactive_image(key=key)
+ if image is None:
+ return None
+ if display:
+ st.image(image)
+ w, h = image.size
+ print(f"loaded input image of size ({w}, {h})")
+
+ transform = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Lambda(lambda x: x * 2.0 - 1.0),
+ ]
+ )
+ img = transform(image)[None, ...]
+ st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
+ return img
+
+
+def get_init_img(batch_size=1, key=None):
+ init_image = load_img(key=key).cuda()
+ init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
+ return init_image
+
+
+def do_sample(
+ model,
+ sampler,
+ value_dict,
+ num_samples,
+ H,
+ W,
+ C,
+ F,
+ force_uc_zero_embeddings: List = None,
+ batch2model_input: List = None,
+ return_latents=False,
+ filter=None,
+):
+ if force_uc_zero_embeddings is None:
+ force_uc_zero_embeddings = []
+ if batch2model_input is None:
+ batch2model_input = []
+
+ st.text("Sampling")
+
+ outputs = st.empty()
+ precision_scope = autocast
+ with torch.no_grad():
+ with precision_scope("cuda"):
+ with model.ema_scope():
+ num_samples = [num_samples]
+ batch, batch_uc = get_batch(
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
+ value_dict,
+ num_samples,
+ )
+ for key in batch:
+ if isinstance(batch[key], torch.Tensor):
+ print(key, batch[key].shape)
+ elif isinstance(batch[key], list):
+ print(key, [len(l) for l in batch[key]])
+ else:
+ print(key, batch[key])
+ c, uc = model.conditioner.get_unconditional_conditioning(
+ batch,
+ batch_uc=batch_uc,
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
+ )
+
+ for k in c:
+ if not k == "crossattn":
+ c[k], uc[k] = map(
+ lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
+ )
+
+ additional_model_inputs = {}
+ for k in batch2model_input:
+ additional_model_inputs[k] = batch[k]
+
+ shape = (math.prod(num_samples), C, H // F, W // F)
+ randn = torch.randn(shape).to("cuda")
+
+ def denoiser(input, sigma, c):
+ return model.denoiser(
+ model.model, input, sigma, c, **additional_model_inputs
+ )
+
+ samples_z = sampler(denoiser, randn, cond=c, uc=uc)
+ samples_x = model.decode_first_stage(samples_z)
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+ if filter is not None:
+ samples = filter(samples)
+
+ grid = torch.stack([samples])
+ grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
+ outputs.image(grid.cpu().numpy())
+
+ if return_latents:
+ return samples, samples_z
+ return samples
+
+
+def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
+ # Hardcoded demo setups; might undergo some changes in the future
+
+ batch = {}
+ batch_uc = {}
+
+ for key in keys:
+ if key == "txt":
+ batch["txt"] = (
+ np.repeat([value_dict["prompt"]], repeats=math.prod(N))
+ .reshape(N)
+ .tolist()
+ )
+ batch_uc["txt"] = (
+ np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
+ .reshape(N)
+ .tolist()
+ )
+ elif key == "original_size_as_tuple":
+ batch["original_size_as_tuple"] = (
+ torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
+ .to(device)
+ .repeat(*N, 1)
+ )
+ elif key == "crop_coords_top_left":
+ batch["crop_coords_top_left"] = (
+ torch.tensor(
+ [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
+ )
+ .to(device)
+ .repeat(*N, 1)
+ )
+ elif key == "aesthetic_score":
+ batch["aesthetic_score"] = (
+ torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
+ )
+ batch_uc["aesthetic_score"] = (
+ torch.tensor([value_dict["negative_aesthetic_score"]])
+ .to(device)
+ .repeat(*N, 1)
+ )
+
+ elif key == "target_size_as_tuple":
+ batch["target_size_as_tuple"] = (
+ torch.tensor([value_dict["target_height"], value_dict["target_width"]])
+ .to(device)
+ .repeat(*N, 1)
+ )
+ else:
+ batch[key] = value_dict[key]
+
+ for key in batch.keys():
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+ batch_uc[key] = torch.clone(batch[key])
+ return batch, batch_uc
+
+
+@torch.no_grad()
+def do_img2img(
+ img,
+ model,
+ sampler,
+ value_dict,
+ num_samples,
+ force_uc_zero_embeddings=[],
+ additional_kwargs={},
+ offset_noise_level: int = 0.0,
+ return_latents=False,
+ skip_encode=False,
+ filter=None,
+):
+ st.text("Sampling")
+
+ outputs = st.empty()
+ precision_scope = autocast
+ with torch.no_grad():
+ with precision_scope("cuda"):
+ with model.ema_scope():
+ batch, batch_uc = get_batch(
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
+ value_dict,
+ [num_samples],
+ )
+ c, uc = model.conditioner.get_unconditional_conditioning(
+ batch,
+ batch_uc=batch_uc,
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
+ )
+
+ for k in c:
+ c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
+
+ for k in additional_kwargs:
+ c[k] = uc[k] = additional_kwargs[k]
+ if skip_encode:
+ z = img
+ else:
+ z = model.encode_first_stage(img)
+ noise = torch.randn_like(z)
+ sigmas = sampler.discretization(sampler.num_steps)
+ sigma = sigmas[0]
+
+ st.info(f"all sigmas: {sigmas}")
+ st.info(f"noising sigma: {sigma}")
+
+ if offset_noise_level > 0.0:
+ noise = noise + offset_noise_level * append_dims(
+ torch.randn(z.shape[0], device=z.device), z.ndim
+ )
+ noised_z = z + noise * append_dims(sigma, z.ndim)
+ noised_z = noised_z / torch.sqrt(
+ 1.0 + sigmas[0] ** 2.0
+ ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
+
+ def denoiser(x, sigma, c):
+ return model.denoiser(model.model, x, sigma, c)
+
+ samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
+ samples_x = model.decode_first_stage(samples_z)
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+ if filter is not None:
+ samples = filter(samples)
+
+ grid = embed_watemark(torch.stack([samples]))
+ grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
+ outputs.image(grid.cpu().numpy())
+ if return_latents:
+ return samples, samples_z
+ return samples
diff --git a/scripts/util/detection/nsfw_and_watermark_dectection.py b/scripts/util/detection/nsfw_and_watermark_dectection.py
new file mode 100644
index 000000000..af84acf30
--- /dev/null
+++ b/scripts/util/detection/nsfw_and_watermark_dectection.py
@@ -0,0 +1,104 @@
+import os
+import torch
+import numpy as np
+import torchvision.transforms as T
+from PIL import Image
+import clip
+
+RESOURCES_ROOT = "scripts/util/detection/"
+
+
+def predict_proba(X, weights, biases):
+ logits = X @ weights.T + biases
+ proba = np.where(
+ logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))
+ )
+ return proba.T
+
+
+def load_model_weights(path: str):
+ model_weights = np.load(path)
+ return model_weights["weights"], model_weights["biases"]
+
+
+def clip_process_images(images: torch.Tensor) -> torch.Tensor:
+ min_size = min(images.shape[-2:])
+ return T.Compose(
+ [
+ T.CenterCrop(min_size), # TODO: this might affect the watermark, check this
+ T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
+ T.Normalize(
+ (0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711),
+ ),
+ ]
+ )(images)
+
+
+class DeepFloydDataFiltering(object):
+ def __init__(self, verbose: bool = False):
+ super().__init__()
+ self.verbose = verbose
+ self.clip_model, _ = clip.load("ViT-L/14", device="cpu")
+ self.clip_model.eval()
+
+ self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
+ os.path.join(RESOURCES_ROOT, "w_head_v1.npz")
+ )
+ self.cpu_p_weights, self.cpu_p_biases = load_model_weights(
+ os.path.join(RESOURCES_ROOT, "p_head_v1.npz")
+ )
+ self.w_threshold, self.p_threshold = 0.5, 0.5
+
+ @torch.inference_mode()
+ def __call__(self, images: torch.Tensor) -> torch.Tensor:
+ imgs = clip_process_images(images)
+ image_features = self.clip_model.encode_image(imgs.to("cpu"))
+ image_features = image_features.detach().cpu().numpy().astype(np.float16)
+ p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
+ w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
+ print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None
+ query = p_pred > self.p_threshold
+ if query.sum() > 0:
+ print(f"Hit for p_threshold: {p_pred}") if self.verbose else None
+ images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
+ query = w_pred > self.w_threshold
+ if query.sum() > 0:
+ print(f"Hit for w_threshold: {w_pred}") if self.verbose else None
+ images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
+ return images
+
+
+def load_img(path: str) -> torch.Tensor:
+ image = Image.open(path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image_transforms = T.Compose(
+ [
+ T.ToTensor(),
+ ]
+ )
+ return image_transforms(image)[None, ...]
+
+
+def test(root):
+ from einops import rearrange
+
+ filter = DeepFloydDataFiltering(verbose=True)
+ for p in os.listdir((root)):
+ print(f"running on {p}...")
+ img = load_img(os.path.join(root, p))
+ filtered_img = filter(img)
+ filtered_img = rearrange(
+ 255.0 * (filtered_img.numpy())[0], "c h w -> h w c"
+ ).astype(np.uint8)
+ Image.fromarray(filtered_img).save(
+ os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg")
+ )
+
+
+if __name__ == "__main__":
+ import fire
+
+ fire.Fire(test)
+ print("done.")
diff --git a/scripts/util/detection/p_head_v1.npz b/scripts/util/detection/p_head_v1.npz
new file mode 100644
index 000000000..a2c3babe1
Binary files /dev/null and b/scripts/util/detection/p_head_v1.npz differ
diff --git a/scripts/util/detection/w_head_v1.npz b/scripts/util/detection/w_head_v1.npz
new file mode 100644
index 000000000..65030c5ea
Binary files /dev/null and b/scripts/util/detection/w_head_v1.npz differ
diff --git a/setup.py b/setup.py
new file mode 100644
index 000000000..3117b8851
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,13 @@
+from setuptools import find_packages, setup
+
+setup(
+ name="sgm",
+ version="0.0.1",
+ packages=find_packages(),
+ python_requires=">=3.8",
+ py_modules=["sgm"],
+ description="Stability Generative Models",
+ long_description=open("README.md", "r", encoding="utf-8").read(),
+ long_description_content_type="text/markdown",
+ url="https://github.com/Stability-AI/generative-models",
+)
diff --git a/sgm/__init__.py b/sgm/__init__.py
new file mode 100644
index 000000000..cc9c7dc57
--- /dev/null
+++ b/sgm/__init__.py
@@ -0,0 +1,3 @@
+from .data import StableDataModuleFromConfig
+from .models import AutoencodingEngine, DiffusionEngine
+from .util import instantiate_from_config
diff --git a/sgm/data/__init__.py b/sgm/data/__init__.py
new file mode 100644
index 000000000..7664a25c6
--- /dev/null
+++ b/sgm/data/__init__.py
@@ -0,0 +1 @@
+from .dataset import StableDataModuleFromConfig
diff --git a/sgm/data/cifar10.py b/sgm/data/cifar10.py
new file mode 100644
index 000000000..aa3ae6777
--- /dev/null
+++ b/sgm/data/cifar10.py
@@ -0,0 +1,67 @@
+import torchvision
+import pytorch_lightning as pl
+from torchvision import transforms
+from torch.utils.data import DataLoader, Dataset
+
+
+class CIFAR10DataDictWrapper(Dataset):
+ def __init__(self, dset):
+ super().__init__()
+ self.dset = dset
+
+ def __getitem__(self, i):
+ x, y = self.dset[i]
+ return {"jpg": x, "cls": y}
+
+ def __len__(self):
+ return len(self.dset)
+
+
+class CIFAR10Loader(pl.LightningDataModule):
+ def __init__(self, batch_size, num_workers=0, shuffle=True):
+ super().__init__()
+
+ transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
+ )
+
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.shuffle = shuffle
+ self.train_dataset = CIFAR10DataDictWrapper(
+ torchvision.datasets.CIFAR10(
+ root=".data/", train=True, download=True, transform=transform
+ )
+ )
+ self.test_dataset = CIFAR10DataDictWrapper(
+ torchvision.datasets.CIFAR10(
+ root=".data/", train=False, download=True, transform=transform
+ )
+ )
+
+ def prepare_data(self):
+ pass
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ )
+
+ def test_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ )
diff --git a/sgm/data/dataset.py b/sgm/data/dataset.py
new file mode 100644
index 000000000..b72614999
--- /dev/null
+++ b/sgm/data/dataset.py
@@ -0,0 +1,80 @@
+from typing import Optional
+
+import torchdata.datapipes.iter
+import webdataset as wds
+from omegaconf import DictConfig
+from pytorch_lightning import LightningDataModule
+
+try:
+ from sdata import create_dataset, create_dummy_dataset, create_loader
+except ImportError as e:
+ print("#" * 100)
+ print("Datasets not yet available")
+ print("to enable, we need to add stable-datasets as a submodule")
+ print("please use ``git submodule update --init --recursive``")
+ print("and do ``pip install -e stable-datasets/`` from the root of this repo")
+ print("#" * 100)
+ exit(1)
+
+
+class StableDataModuleFromConfig(LightningDataModule):
+ def __init__(
+ self,
+ train: DictConfig,
+ validation: Optional[DictConfig] = None,
+ test: Optional[DictConfig] = None,
+ skip_val_loader: bool = False,
+ dummy: bool = False,
+ ):
+ super().__init__()
+ self.train_config = train
+ assert (
+ "datapipeline" in self.train_config and "loader" in self.train_config
+ ), "train config requires the fields `datapipeline` and `loader`"
+
+ self.val_config = validation
+ if not skip_val_loader:
+ if self.val_config is not None:
+ assert (
+ "datapipeline" in self.val_config and "loader" in self.val_config
+ ), "validation config requires the fields `datapipeline` and `loader`"
+ else:
+ print(
+ "Warning: No Validation datapipeline defined, using that one from training"
+ )
+ self.val_config = train
+
+ self.test_config = test
+ if self.test_config is not None:
+ assert (
+ "datapipeline" in self.test_config and "loader" in self.test_config
+ ), "test config requires the fields `datapipeline` and `loader`"
+
+ self.dummy = dummy
+ if self.dummy:
+ print("#" * 100)
+ print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+ print("#" * 100)
+
+ def setup(self, stage: str) -> None:
+ print("Preparing datasets")
+ if self.dummy:
+ data_fn = create_dummy_dataset
+ else:
+ data_fn = create_dataset
+
+ self.train_datapipeline = data_fn(**self.train_config.datapipeline)
+ if self.val_config:
+ self.val_datapipeline = data_fn(**self.val_config.datapipeline)
+ if self.test_config:
+ self.test_datapipeline = data_fn(**self.test_config.datapipeline)
+
+ def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
+ loader = create_loader(self.train_datapipeline, **self.train_config.loader)
+ return loader
+
+ def val_dataloader(self) -> wds.DataPipeline:
+ return create_loader(self.val_datapipeline, **self.val_config.loader)
+
+ def test_dataloader(self) -> wds.DataPipeline:
+ return create_loader(self.test_datapipeline, **self.test_config.loader)
diff --git a/sgm/data/mnist.py b/sgm/data/mnist.py
new file mode 100644
index 000000000..ab7478f40
--- /dev/null
+++ b/sgm/data/mnist.py
@@ -0,0 +1,85 @@
+import torchvision
+import pytorch_lightning as pl
+from torchvision import transforms
+from torch.utils.data import DataLoader, Dataset
+
+
+class MNISTDataDictWrapper(Dataset):
+ def __init__(self, dset):
+ super().__init__()
+ self.dset = dset
+
+ def __getitem__(self, i):
+ x, y = self.dset[i]
+ return {"jpg": x, "cls": y}
+
+ def __len__(self):
+ return len(self.dset)
+
+
+class MNISTLoader(pl.LightningDataModule):
+ def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
+ super().__init__()
+
+ transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
+ )
+
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
+ self.shuffle = shuffle
+ self.train_dataset = MNISTDataDictWrapper(
+ torchvision.datasets.MNIST(
+ root=".data/", train=True, download=True, transform=transform
+ )
+ )
+ self.test_dataset = MNISTDataDictWrapper(
+ torchvision.datasets.MNIST(
+ root=".data/", train=False, download=True, transform=transform
+ )
+ )
+
+ def prepare_data(self):
+ pass
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ )
+
+ def test_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ )
+
+
+if __name__ == "__main__":
+ dset = MNISTDataDictWrapper(
+ torchvision.datasets.MNIST(
+ root=".data/",
+ train=False,
+ download=True,
+ transform=transforms.Compose(
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
+ ),
+ )
+ )
+ ex = dset[0]
diff --git a/sgm/lr_scheduler.py b/sgm/lr_scheduler.py
new file mode 100644
index 000000000..b2f4d384c
--- /dev/null
+++ b/sgm/lr_scheduler.py
@@ -0,0 +1,135 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+
+ def __init__(
+ self,
+ warm_up_steps,
+ lr_min,
+ lr_max,
+ lr_start,
+ max_decay_steps,
+ verbosity_interval=0,
+ ):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.0
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n, **kwargs):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (
+ self.lr_max - self.lr_start
+ ) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (
+ self.lr_max_decay_steps - self.lr_warm_up_steps
+ )
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi)
+ )
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaWarmUpCosineScheduler2:
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+
+ def __init__(
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
+ ):
+ assert (
+ len(warm_up_steps)
+ == len(f_min)
+ == len(f_max)
+ == len(f_start)
+ == len(cycle_lengths)
+ )
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.0
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}"
+ )
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
+ cycle
+ ] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
+ )
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+ 1 + np.cos(t * np.pi)
+ )
+ self.last_f = f
+ return f
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}"
+ )
+
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
+ cycle
+ ] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
+ self.cycle_lengths[cycle] - n
+ ) / (self.cycle_lengths[cycle])
+ self.last_f = f
+ return f
diff --git a/sgm/models/__init__.py b/sgm/models/__init__.py
new file mode 100644
index 000000000..c410b3747
--- /dev/null
+++ b/sgm/models/__init__.py
@@ -0,0 +1,2 @@
+from .autoencoder import AutoencodingEngine
+from .diffusion import DiffusionEngine
diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py
new file mode 100644
index 000000000..78fb551a2
--- /dev/null
+++ b/sgm/models/autoencoder.py
@@ -0,0 +1,335 @@
+import re
+from abc import abstractmethod
+from contextlib import contextmanager
+from typing import Any, Dict, Tuple, Union
+
+import pytorch_lightning as pl
+import torch
+from omegaconf import ListConfig
+from packaging import version
+from safetensors.torch import load_file as load_safetensors
+
+from ..modules.diffusionmodules.model import Decoder, Encoder
+from ..modules.distributions.distributions import DiagonalGaussianDistribution
+from ..modules.ema import LitEma
+from ..util import default, get_obj_from_str, instantiate_from_config
+
+
+class AbstractAutoencoder(pl.LightningModule):
+ """
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
+ unCLIP models, etc. Hence, it is fairly general, and specific features
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
+ """
+
+ def __init__(
+ self,
+ ema_decay: Union[None, float] = None,
+ monitor: Union[None, str] = None,
+ input_key: str = "jpg",
+ ckpt_path: Union[None, str] = None,
+ ignore_keys: Union[Tuple, list, ListConfig] = (),
+ ):
+ super().__init__()
+ self.input_key = input_key
+ self.use_ema = ema_decay is not None
+ if monitor is not None:
+ self.monitor = monitor
+
+ if self.use_ema:
+ self.model_ema = LitEma(self, decay=ema_decay)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ self.automatic_optimization = False
+
+ def init_from_ckpt(
+ self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
+ ) -> None:
+ if path.endswith("ckpt"):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ elif path.endswith("safetensors"):
+ sd = load_safetensors(path)
+ else:
+ raise NotImplementedError
+
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if re.match(ik, k):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+ )
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ @abstractmethod
+ def get_input(self, batch) -> Any:
+ raise NotImplementedError()
+
+ def on_train_batch_end(self, *args, **kwargs):
+ # for EMA computation
+ if self.use_ema:
+ self.model_ema(self)
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ @abstractmethod
+ def encode(self, *args, **kwargs) -> torch.Tensor:
+ raise NotImplementedError("encode()-method of abstract base class called")
+
+ @abstractmethod
+ def decode(self, *args, **kwargs) -> torch.Tensor:
+ raise NotImplementedError("decode()-method of abstract base class called")
+
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
+ print(f"loading >>> {cfg['target']} <<< optimizer from config")
+ return get_obj_from_str(cfg["target"])(
+ params, lr=lr, **cfg.get("params", dict())
+ )
+
+ def configure_optimizers(self) -> Any:
+ raise NotImplementedError()
+
+
+class AutoencodingEngine(AbstractAutoencoder):
+ """
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
+ (we also restore them explicitly as special cases for legacy reasons).
+ Regularizations such as KL or VQ are moved to the regularizer class.
+ """
+
+ def __init__(
+ self,
+ *args,
+ encoder_config: Dict,
+ decoder_config: Dict,
+ loss_config: Dict,
+ regularizer_config: Dict,
+ optimizer_config: Union[Dict, None] = None,
+ lr_g_factor: float = 1.0,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ # todo: add options to freeze encoder/decoder
+ self.encoder = instantiate_from_config(encoder_config)
+ self.decoder = instantiate_from_config(decoder_config)
+ self.loss = instantiate_from_config(loss_config)
+ self.regularization = instantiate_from_config(regularizer_config)
+ self.optimizer_config = default(
+ optimizer_config, {"target": "torch.optim.Adam"}
+ )
+ self.lr_g_factor = lr_g_factor
+
+ def get_input(self, batch: Dict) -> torch.Tensor:
+ # assuming unified data format, dataloader returns a dict.
+ # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc)
+ return batch[self.input_key]
+
+ def get_autoencoder_params(self) -> list:
+ params = (
+ list(self.encoder.parameters())
+ + list(self.decoder.parameters())
+ + list(self.regularization.get_trainable_parameters())
+ + list(self.loss.get_trainable_autoencoder_parameters())
+ )
+ return params
+
+ def get_discriminator_params(self) -> list:
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
+ return params
+
+ def get_last_layer(self):
+ return self.decoder.get_last_layer()
+
+ def encode(self, x: Any, return_reg_log: bool = False) -> Any:
+ z = self.encoder(x)
+ z, reg_log = self.regularization(z)
+ if return_reg_log:
+ return z, reg_log
+ return z
+
+ def decode(self, z: Any) -> torch.Tensor:
+ x = self.decoder(z)
+ return x
+
+ def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ z, reg_log = self.encode(x, return_reg_log=True)
+ dec = self.decode(z)
+ return z, dec, reg_log
+
+ def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
+ x = self.get_input(batch)
+ z, xrec, regularization_log = self(x)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(
+ regularization_log,
+ x,
+ xrec,
+ optimizer_idx,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="train",
+ )
+
+ self.log_dict(
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
+ )
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(
+ regularization_log,
+ x,
+ xrec,
+ optimizer_idx,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="train",
+ )
+ self.log_dict(
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
+ )
+ return discloss
+
+ def validation_step(self, batch, batch_idx) -> Dict:
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+ log_dict.update(log_dict_ema)
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
+ x = self.get_input(batch)
+
+ z, xrec, regularization_log = self(x)
+ aeloss, log_dict_ae = self.loss(
+ regularization_log,
+ x,
+ xrec,
+ 0,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val" + postfix,
+ )
+
+ discloss, log_dict_disc = self.loss(
+ regularization_log,
+ x,
+ xrec,
+ 1,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val" + postfix,
+ )
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+ log_dict_ae.update(log_dict_disc)
+ self.log_dict(log_dict_ae)
+ return log_dict_ae
+
+ def configure_optimizers(self) -> Any:
+ ae_params = self.get_autoencoder_params()
+ disc_params = self.get_discriminator_params()
+
+ opt_ae = self.instantiate_optimizer_from_config(
+ ae_params,
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
+ self.optimizer_config,
+ )
+ opt_disc = self.instantiate_optimizer_from_config(
+ disc_params, self.learning_rate, self.optimizer_config
+ )
+
+ return [opt_ae, opt_disc], []
+
+ @torch.no_grad()
+ def log_images(self, batch: Dict, **kwargs) -> Dict:
+ log = dict()
+ x = self.get_input(batch)
+ _, xrec, _ = self(x)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ with self.ema_scope():
+ _, xrec_ema, _ = self(x)
+ log["reconstructions_ema"] = xrec_ema
+ return log
+
+
+class AutoencoderKL(AutoencodingEngine):
+ def __init__(self, embed_dim: int, **kwargs):
+ ddconfig = kwargs.pop("ddconfig")
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", ())
+ super().__init__(
+ encoder_config={"target": "torch.nn.Identity"},
+ decoder_config={"target": "torch.nn.Identity"},
+ regularizer_config={"target": "torch.nn.Identity"},
+ loss_config=kwargs.pop("lossconfig"),
+ **kwargs,
+ )
+ assert ddconfig["double_z"]
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def encode(self, x):
+ assert (
+ not self.training
+ ), f"{self.__class__.__name__} only supports inference currently"
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, **decoder_kwargs):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z, **decoder_kwargs)
+ return dec
+
+
+class AutoencoderKLInferenceWrapper(AutoencoderKL):
+ def encode(self, x):
+ return super().encode(x).sample()
+
+
+class IdentityFirstStage(AbstractAutoencoder):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def get_input(self, x: Any) -> Any:
+ return x
+
+ def encode(self, x: Any, *args, **kwargs) -> Any:
+ return x
+
+ def decode(self, x: Any, *args, **kwargs) -> Any:
+ return x
diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py
new file mode 100644
index 000000000..1cdbf541e
--- /dev/null
+++ b/sgm/models/diffusion.py
@@ -0,0 +1,324 @@
+from contextlib import contextmanager
+from typing import Any, Dict, List, Tuple, Union
+
+import pytorch_lightning as pl
+import torch
+from omegaconf import ListConfig, OmegaConf
+from safetensors.torch import load_file as load_safetensors
+from torch.optim.lr_scheduler import LambdaLR
+
+from ..modules import UNCONDITIONAL_CONFIG
+from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
+from ..modules.ema import LitEma
+from ..util import (
+ default,
+ disabled_train,
+ get_obj_from_str,
+ instantiate_from_config,
+ log_txt_as_img,
+)
+
+
+class DiffusionEngine(pl.LightningModule):
+ def __init__(
+ self,
+ network_config,
+ denoiser_config,
+ first_stage_config,
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ network_wrapper: Union[None, str] = None,
+ ckpt_path: Union[None, str] = None,
+ use_ema: bool = False,
+ ema_decay_rate: float = 0.9999,
+ scale_factor: float = 1.0,
+ disable_first_stage_autocast=False,
+ input_key: str = "jpg",
+ log_keys: Union[List, None] = None,
+ no_cond_log: bool = False,
+ compile_model: bool = False,
+ ):
+ super().__init__()
+ self.log_keys = log_keys
+ self.input_key = input_key
+ self.optimizer_config = default(
+ optimizer_config, {"target": "torch.optim.AdamW"}
+ )
+ model = instantiate_from_config(network_config)
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
+ model, compile_model=compile_model
+ )
+
+ self.denoiser = instantiate_from_config(denoiser_config)
+ self.sampler = (
+ instantiate_from_config(sampler_config)
+ if sampler_config is not None
+ else None
+ )
+ self.conditioner = instantiate_from_config(
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
+ )
+ self.scheduler_config = scheduler_config
+ self._init_first_stage(first_stage_config)
+
+ self.loss_fn = (
+ instantiate_from_config(loss_fn_config)
+ if loss_fn_config is not None
+ else None
+ )
+
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.scale_factor = scale_factor
+ self.disable_first_stage_autocast = disable_first_stage_autocast
+ self.no_cond_log = no_cond_log
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path)
+
+ def init_from_ckpt(
+ self,
+ path: str,
+ ) -> None:
+ if path.endswith("ckpt"):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ elif path.endswith("safetensors"):
+ sd = load_safetensors(path)
+ else:
+ raise NotImplementedError
+
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+ )
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def _init_first_stage(self, config):
+ model = instantiate_from_config(config).eval()
+ model.train = disabled_train
+ for param in model.parameters():
+ param.requires_grad = False
+ self.first_stage_model = model
+
+ def get_input(self, batch):
+ # assuming unified data format, dataloader returns a dict.
+ # image tensors should be scaled to -1 ... 1 and in bchw format
+ return batch[self.input_key]
+
+ @torch.no_grad()
+ def decode_first_stage(self, z):
+ z = 1.0 / self.scale_factor * z
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ out = self.first_stage_model.decode(z)
+ return out
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ z = self.first_stage_model.encode(x)
+ z = self.scale_factor * z
+ return z
+
+ def forward(self, x, batch):
+ loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
+ loss_mean = loss.mean()
+ loss_dict = {"loss": loss_mean}
+ return loss_mean, loss_dict
+
+ def shared_step(self, batch: Dict) -> Any:
+ x = self.get_input(batch)
+ x = self.encode_first_stage(x)
+ batch["global_step"] = self.global_step
+ loss, loss_dict = self(x, batch)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ self.log(
+ "global_step",
+ self.global_step,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=False,
+ )
+
+ if self.scheduler_config is not None:
+ lr = self.optimizers().param_groups[0]["lr"]
+ self.log(
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ return loss
+
+ def on_train_start(self, *args, **kwargs):
+ if self.sampler is None or self.loss_fn is None:
+ raise ValueError("Sampler and loss function need to be set for training.")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
+ return get_obj_from_str(cfg["target"])(
+ params, lr=lr, **cfg.get("params", dict())
+ )
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ for embedder in self.conditioner.embedders:
+ if embedder.is_trainable:
+ params = params + list(embedder.parameters())
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
+ if self.scheduler_config is not None:
+ scheduler = instantiate_from_config(self.scheduler_config)
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
+ "interval": "step",
+ "frequency": 1,
+ }
+ ]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def sample(
+ self,
+ cond: Dict,
+ uc: Union[Dict, None] = None,
+ batch_size: int = 16,
+ shape: Union[None, Tuple, List] = None,
+ **kwargs,
+ ):
+ randn = torch.randn(batch_size, *shape).to(self.device)
+
+ denoiser = lambda input, sigma, c: self.denoiser(
+ self.model, input, sigma, c, **kwargs
+ )
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
+ return samples
+
+ @torch.no_grad()
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
+ """
+ Defines heuristics to log different conditionings.
+ These can be lists of strings (text-to-image), tensors, ints, ...
+ """
+ image_h, image_w = batch[self.input_key].shape[2:]
+ log = dict()
+
+ for embedder in self.conditioner.embedders:
+ if (
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
+ ) and not self.no_cond_log:
+ x = batch[embedder.input_key][:n]
+ if isinstance(x, torch.Tensor):
+ if x.dim() == 1:
+ # class-conditional, convert integer to string
+ x = [str(x[i].item()) for i in range(x.shape[0])]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
+ elif x.dim() == 2:
+ # size and crop cond and the like
+ x = [
+ "x".join([str(xx) for xx in x[i].tolist()])
+ for i in range(x.shape[0])
+ ]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ else:
+ raise NotImplementedError()
+ elif isinstance(x, Union[List, ListConfig]):
+ if isinstance(x[0], str):
+ # strings
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ elif isinstance(x[0], Union[ListConfig, List]):
+ # # case: videos processed
+ x = [xx[0] for xx in x]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ else:
+ raise NotImplementedError()
+ else:
+ raise NotImplementedError()
+ log[embedder.input_key] = xc
+ return log
+
+ @torch.no_grad()
+ def log_images(
+ self,
+ batch: Dict,
+ N: int = 8,
+ sample: bool = True,
+ ucg_keys: List[str] = None,
+ **kwargs,
+ ) -> Dict:
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
+ if ucg_keys:
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
+ )
+ else:
+ ucg_keys = conditioner_input_keys
+ log = dict()
+
+ x = self.get_input(batch)
+
+ c, uc = self.conditioner.get_unconditional_conditioning(
+ batch,
+ force_uc_zero_embeddings=ucg_keys
+ if len(self.conditioner.embedders) > 0
+ else [],
+ )
+
+ sampling_kwargs = {}
+
+ N = min(x.shape[0], N)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+ z = self.encode_first_stage(x)
+ log["reconstructions"] = self.decode_first_stage(z)
+ log.update(self.log_conditionings(batch, N))
+
+ for k in c:
+ if isinstance(c[k], torch.Tensor):
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
+
+ if sample:
+ with self.ema_scope("Plotting"):
+ samples = self.sample(
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
+ )
+ samples = self.decode_first_stage(samples)
+ log["samples"] = samples
+ return log
diff --git a/sgm/modules/__init__.py b/sgm/modules/__init__.py
new file mode 100644
index 000000000..0db1d7716
--- /dev/null
+++ b/sgm/modules/__init__.py
@@ -0,0 +1,6 @@
+from .encoders.modules import GeneralConditioner
+
+UNCONDITIONAL_CONFIG = {
+ "target": "sgm.modules.GeneralConditioner",
+ "params": {"emb_models": []},
+}
diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py
new file mode 100644
index 000000000..a17edda72
--- /dev/null
+++ b/sgm/modules/attention.py
@@ -0,0 +1,947 @@
+import math
+from inspect import isfunction
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from packaging import version
+from torch import nn
+
+if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ SDP_IS_AVAILABLE = True
+ from torch.backends.cuda import SDPBackend, sdp_kernel
+
+ BACKEND_MAP = {
+ SDPBackend.MATH: {
+ "enable_math": True,
+ "enable_flash": False,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.FLASH_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": True,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.EFFICIENT_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": False,
+ "enable_mem_efficient": True,
+ },
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+ }
+else:
+ from contextlib import nullcontext
+
+ SDP_IS_AVAILABLE = False
+ sdp_kernel = nullcontext
+ BACKEND_MAP = {}
+ print(
+ f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
+ f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+ )
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except:
+ XFORMERS_IS_AVAILABLE = False
+ print("no module 'xformers'. Processing without...")
+
+from .diffusionmodules.util import checkpoint
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+ )
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+ )
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b (h w) c")
+ k = rearrange(k, "b c h w -> b c (h w)")
+ w_ = torch.einsum("bij,bjk->bik", q, k)
+
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, "b c h w -> b c (h w)")
+ w_ = rearrange(w_, "b i j -> b j i")
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ backend=None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+ self.backend = backend
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ ):
+ h = self.heads
+
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
+ k = repeat(
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
+ )
+ v = repeat(
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
+ )
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+ ## old
+ """
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', sim, v)
+ """
+ ## new
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
+ out = F.scaled_dot_product_attention(
+ q, k, v, attn_mask=mask
+ ) # scale is dim_head ** -0.5 per default
+
+ del q, k, v
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
+
+ if additional_tokens is not None:
+ # remove additional token
+ out = out[:, n_tokens_to_mask:]
+ return self.to_out(out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
+ ):
+ super().__init__()
+ print(
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{heads} heads with a dimension of {dim_head}."
+ )
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+ self.attention_op: Optional[Any] = None
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ ):
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
+ k = repeat(
+ k[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+ v = repeat(
+ v[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=None, op=self.attention_op
+ )
+
+ # TODO: Use this directly in the attention operation, as a bias
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ if additional_tokens is not None:
+ # remove additional token
+ out = out[:, n_tokens_to_mask:]
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ disable_self_attn=False,
+ attn_mode="softmax",
+ sdp_backend=None,
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
+ print(
+ f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
+ f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
+ )
+ attn_mode = "softmax"
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
+ print(
+ "We do not support vanilla attention anymore, as it is too expensive. Sorry."
+ )
+ if not XFORMERS_IS_AVAILABLE:
+ assert (
+ False
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ else:
+ print("Falling back to xformers efficient attention.")
+ attn_mode = "softmax-xformers"
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
+ else:
+ assert sdp_backend is None
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None,
+ backend=sdp_backend,
+ ) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ backend=sdp_backend,
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+ if self.checkpoint:
+ print(f"{self.__class__.__name__} is using checkpointing")
+
+ def forward(
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
+ ):
+ kwargs = {"x": x}
+
+ if context is not None:
+ kwargs.update({"context": context})
+
+ if additional_tokens is not None:
+ kwargs.update({"additional_tokens": additional_tokens})
+
+ if n_times_crossframe_attn_in_self:
+ kwargs.update(
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
+ )
+
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
+ return checkpoint(
+ self._forward, (x, context), self.parameters(), self.checkpoint
+ )
+
+ def _forward(
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
+ ):
+ x = (
+ self.attn1(
+ self.norm1(x),
+ context=context if self.disable_self_attn else None,
+ additional_tokens=additional_tokens,
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
+ if not self.disable_self_attn
+ else 0,
+ )
+ + x
+ )
+ x = (
+ self.attn2(
+ self.norm2(x), context=context, additional_tokens=additional_tokens
+ )
+ + x
+ )
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class BasicTransformerSingleLayerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ attn_mode="softmax",
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ )
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(
+ self._forward, (x, context), self.parameters(), self.checkpoint
+ )
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context) + x
+ x = self.ff(self.norm2(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ disable_self_attn=False,
+ use_linear=False,
+ attn_type="softmax",
+ use_checkpoint=True,
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
+ sdp_backend=None,
+ ):
+ super().__init__()
+ print(
+ f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
+ )
+ from omegaconf import ListConfig
+
+ if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
+ context_dim = [context_dim]
+ if exists(context_dim) and isinstance(context_dim, list):
+ if depth != len(context_dim):
+ print(
+ f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+ f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
+ )
+ # depth does not match context dims.
+ assert all(
+ map(lambda x: x == context_dim[0], context_dim)
+ ), "need homogenous context_dim to match depth automatically"
+ context_dim = depth * [context_dim[0]]
+ elif context_dim is None:
+ context_dim = [None] * depth
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn,
+ attn_mode=attn_type,
+ checkpoint=use_checkpoint,
+ sdp_backend=sdp_backend,
+ )
+ for d in range(depth)
+ ]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+ else:
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ if i > 0 and len(context) == 1:
+ i = 0 # use same context for each block
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
+
+def benchmark_attn():
+ # Lets define a helpful benchmarking function:
+ # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ import torch.nn.functional as F
+ import torch.utils.benchmark as benchmark
+
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+ t0 = benchmark.Timer(
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+ )
+ return t0.blocked_autorange().mean * 1e6
+
+ # Lets define the hyper-parameters of our input
+ batch_size = 32
+ max_sequence_len = 1024
+ num_heads = 32
+ embed_dimension = 32
+
+ dtype = torch.float16
+
+ query = torch.rand(
+ batch_size,
+ num_heads,
+ max_sequence_len,
+ embed_dimension,
+ device=device,
+ dtype=dtype,
+ )
+ key = torch.rand(
+ batch_size,
+ num_heads,
+ max_sequence_len,
+ embed_dimension,
+ device=device,
+ dtype=dtype,
+ )
+ value = torch.rand(
+ batch_size,
+ num_heads,
+ max_sequence_len,
+ embed_dimension,
+ device=device,
+ dtype=dtype,
+ )
+
+ print(f"q/k/v shape:", query.shape, key.shape, value.shape)
+
+ # Lets explore the speed of each of the 3 implementations
+ from torch.backends.cuda import SDPBackend, sdp_kernel
+
+ # Helpful arguments mapper
+ backend_map = {
+ SDPBackend.MATH: {
+ "enable_math": True,
+ "enable_flash": False,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.FLASH_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": True,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.EFFICIENT_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": False,
+ "enable_mem_efficient": True,
+ },
+ }
+
+ from torch.profiler import ProfilerActivity, profile, record_function
+
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
+
+ print(
+ f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+ )
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("Default detailed stats"):
+ for _ in range(25):
+ o = F.scaled_dot_product_attention(query, key, value)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+ print(
+ f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+ )
+ with sdp_kernel(**backend_map[SDPBackend.MATH]):
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("Math implmentation stats"):
+ for _ in range(25):
+ o = F.scaled_dot_product_attention(query, key, value)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+ with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+ try:
+ print(
+ f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+ )
+ except RuntimeError:
+ print("FlashAttention is not supported. See warnings for reasons.")
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("FlashAttention stats"):
+ for _ in range(25):
+ o = F.scaled_dot_product_attention(query, key, value)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+ with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+ try:
+ print(
+ f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+ )
+ except RuntimeError:
+ print("EfficientAttention is not supported. See warnings for reasons.")
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("EfficientAttention stats"):
+ for _ in range(25):
+ o = F.scaled_dot_product_attention(query, key, value)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+
+def run_model(model, x, context):
+ return model(x, context)
+
+
+def benchmark_transformer_blocks():
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ import torch.utils.benchmark as benchmark
+
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+ t0 = benchmark.Timer(
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+ )
+ return t0.blocked_autorange().mean * 1e6
+
+ checkpoint = True
+ compile = False
+
+ batch_size = 32
+ h, w = 64, 64
+ context_len = 77
+ embed_dimension = 1024
+ context_dim = 1024
+ d_head = 64
+
+ transformer_depth = 4
+
+ n_heads = embed_dimension // d_head
+
+ dtype = torch.float16
+
+ model_native = SpatialTransformer(
+ embed_dimension,
+ n_heads,
+ d_head,
+ context_dim=context_dim,
+ use_linear=True,
+ use_checkpoint=checkpoint,
+ attn_type="softmax",
+ depth=transformer_depth,
+ sdp_backend=SDPBackend.FLASH_ATTENTION,
+ ).to(device)
+ model_efficient_attn = SpatialTransformer(
+ embed_dimension,
+ n_heads,
+ d_head,
+ context_dim=context_dim,
+ use_linear=True,
+ depth=transformer_depth,
+ use_checkpoint=checkpoint,
+ attn_type="softmax-xformers",
+ ).to(device)
+ if not checkpoint and compile:
+ print("compiling models")
+ model_native = torch.compile(model_native)
+ model_efficient_attn = torch.compile(model_efficient_attn)
+
+ x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
+ c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
+
+ from torch.profiler import ProfilerActivity, profile, record_function
+
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
+
+ with torch.autocast("cuda"):
+ print(
+ f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
+ )
+ print(
+ f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
+ )
+
+ print(75 * "+")
+ print("NATIVE")
+ print(75 * "+")
+ torch.cuda.reset_peak_memory_stats()
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("NativeAttention stats"):
+ for _ in range(25):
+ model_native(x, c)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
+
+ print(75 * "+")
+ print("Xformers")
+ print(75 * "+")
+ torch.cuda.reset_peak_memory_stats()
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("xformers stats"):
+ for _ in range(25):
+ model_efficient_attn(x, c)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
+
+
+def test01():
+ # conv1x1 vs linear
+ from ..util import count_params
+
+ conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
+ print(count_params(conv))
+ linear = torch.nn.Linear(3, 32).cuda()
+ print(count_params(linear))
+
+ print(conv.weight.shape)
+
+ # use same initialization
+ linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
+ linear.bias = torch.nn.Parameter(conv.bias)
+
+ print(linear.weight.shape)
+
+ x = torch.randn(11, 3, 64, 64).cuda()
+
+ xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
+ print(xr.shape)
+ out_linear = linear(xr)
+ print(out_linear.mean(), out_linear.shape)
+
+ out_conv = conv(x)
+ print(out_conv.mean(), out_conv.shape)
+ print("done with test01.\n")
+
+
+def test02():
+ # try cosine flash attention
+ import time
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cudnn.benchmark = True
+ print("testing cosine flash attention...")
+ DIM = 1024
+ SEQLEN = 4096
+ BS = 16
+
+ print(" softmax (vanilla) first...")
+ model = BasicTransformerBlock(
+ dim=DIM,
+ n_heads=16,
+ d_head=64,
+ dropout=0.0,
+ context_dim=None,
+ attn_mode="softmax",
+ ).cuda()
+ try:
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
+ tic = time.time()
+ y = model(x)
+ toc = time.time()
+ print(y.shape, toc - tic)
+ except RuntimeError as e:
+ # likely oom
+ print(str(e))
+
+ print("\n now flash-cosine...")
+ model = BasicTransformerBlock(
+ dim=DIM,
+ n_heads=16,
+ d_head=64,
+ dropout=0.0,
+ context_dim=None,
+ attn_mode="flash-cosine",
+ ).cuda()
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
+ tic = time.time()
+ y = model(x)
+ toc = time.time()
+ print(y.shape, toc - tic)
+ print("done with test02.\n")
+
+
+if __name__ == "__main__":
+ # test01()
+ # test02()
+ # test03()
+
+ # benchmark_attn()
+ benchmark_transformer_blocks()
+
+ print("done.")
diff --git a/sgm/modules/autoencoding/__init__.py b/sgm/modules/autoencoding/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/sgm/modules/autoencoding/losses/__init__.py b/sgm/modules/autoencoding/losses/__init__.py
new file mode 100644
index 000000000..6a3b54f72
--- /dev/null
+++ b/sgm/modules/autoencoding/losses/__init__.py
@@ -0,0 +1,246 @@
+from typing import Any, Union
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+from taming.modules.losses.lpips import LPIPS
+from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
+
+from ....util import default, instantiate_from_config
+
+
+def adopt_weight(weight, global_step, threshold=0, value=0.0):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+class LatentLPIPS(nn.Module):
+ def __init__(
+ self,
+ decoder_config,
+ perceptual_weight=1.0,
+ latent_weight=1.0,
+ scale_input_to_tgt_size=False,
+ scale_tgt_to_input_size=False,
+ perceptual_weight_on_inputs=0.0,
+ ):
+ super().__init__()
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
+ self.init_decoder(decoder_config)
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ self.latent_weight = latent_weight
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
+
+ def init_decoder(self, config):
+ self.decoder = instantiate_from_config(config)
+ if hasattr(self.decoder, "encoder"):
+ del self.decoder.encoder
+
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
+ log = dict()
+ loss = (latent_inputs - latent_predictions) ** 2
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
+ image_reconstructions = None
+ if self.perceptual_weight > 0.0:
+ image_reconstructions = self.decoder.decode(latent_predictions)
+ image_targets = self.decoder.decode(latent_inputs)
+ perceptual_loss = self.perceptual_loss(
+ image_targets.contiguous(), image_reconstructions.contiguous()
+ )
+ loss = (
+ self.latent_weight * loss.mean()
+ + self.perceptual_weight * perceptual_loss.mean()
+ )
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
+
+ if self.perceptual_weight_on_inputs > 0.0:
+ image_reconstructions = default(
+ image_reconstructions, self.decoder.decode(latent_predictions)
+ )
+ if self.scale_input_to_tgt_size:
+ image_inputs = torch.nn.functional.interpolate(
+ image_inputs,
+ image_reconstructions.shape[2:],
+ mode="bicubic",
+ antialias=True,
+ )
+ elif self.scale_tgt_to_input_size:
+ image_reconstructions = torch.nn.functional.interpolate(
+ image_reconstructions,
+ image_inputs.shape[2:],
+ mode="bicubic",
+ antialias=True,
+ )
+
+ perceptual_loss2 = self.perceptual_loss(
+ image_inputs.contiguous(), image_reconstructions.contiguous()
+ )
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
+ return loss, log
+
+
+class GeneralLPIPSWithDiscriminator(nn.Module):
+ def __init__(
+ self,
+ disc_start: int,
+ logvar_init: float = 0.0,
+ pixelloss_weight=1.0,
+ disc_num_layers: int = 3,
+ disc_in_channels: int = 3,
+ disc_factor: float = 1.0,
+ disc_weight: float = 1.0,
+ perceptual_weight: float = 1.0,
+ disc_loss: str = "hinge",
+ scale_input_to_tgt_size: bool = False,
+ dims: int = 2,
+ learn_logvar: bool = False,
+ regularization_weights: Union[None, dict] = None,
+ ):
+ super().__init__()
+ self.dims = dims
+ if self.dims > 2:
+ print(
+ f"running with dims={dims}. This means that for perceptual loss calculation, "
+ f"the LPIPS loss will be applied to each frame independently. "
+ )
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
+ assert disc_loss in ["hinge", "vanilla"]
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+ self.learn_logvar = learn_logvar
+
+ self.discriminator = NLayerDiscriminator(
+ input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.regularization_weights = default(regularization_weights, {})
+
+ def get_trainable_parameters(self) -> Any:
+ return self.discriminator.parameters()
+
+ def get_trainable_autoencoder_parameters(self) -> Any:
+ if self.learn_logvar:
+ yield self.logvar
+ yield from ()
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(
+ nll_loss, self.last_layer[0], retain_graph=True
+ )[0]
+ g_grads = torch.autograd.grad(
+ g_loss, self.last_layer[0], retain_graph=True
+ )[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(
+ self,
+ regularization_log,
+ inputs,
+ reconstructions,
+ optimizer_idx,
+ global_step,
+ last_layer=None,
+ split="train",
+ weights=None,
+ ):
+ if self.scale_input_to_tgt_size:
+ inputs = torch.nn.functional.interpolate(
+ inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
+ )
+
+ if self.dims > 2:
+ inputs, reconstructions = map(
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
+ (inputs, reconstructions),
+ )
+
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(
+ inputs.contiguous(), reconstructions.contiguous()
+ )
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights * nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ g_loss = -torch.mean(logits_fake)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(
+ nll_loss, g_loss, last_layer=last_layer
+ )
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+ else:
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
+ )
+ loss = weighted_nll_loss + d_weight * disc_factor * g_loss
+ log = dict()
+ for k in regularization_log:
+ if k in self.regularization_weights:
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
+ log[f"{split}/{k}"] = regularization_log[k].detach().mean()
+
+ log.update(
+ {
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/logvar".format(split): self.logvar.detach(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ )
+
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+
+ disc_factor = adopt_weight(
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
+ )
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
+ }
+ return d_loss, log
diff --git a/sgm/modules/autoencoding/regularizers/__init__.py b/sgm/modules/autoencoding/regularizers/__init__.py
new file mode 100644
index 000000000..8de3212d3
--- /dev/null
+++ b/sgm/modules/autoencoding/regularizers/__init__.py
@@ -0,0 +1,53 @@
+from abc import abstractmethod
+from typing import Any, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ....modules.distributions.distributions import DiagonalGaussianDistribution
+
+
+class AbstractRegularizer(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_trainable_parameters(self) -> Any:
+ raise NotImplementedError()
+
+
+class DiagonalGaussianRegularizer(AbstractRegularizer):
+ def __init__(self, sample: bool = True):
+ super().__init__()
+ self.sample = sample
+
+ def get_trainable_parameters(self) -> Any:
+ yield from ()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ log = dict()
+ posterior = DiagonalGaussianDistribution(z)
+ if self.sample:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ kl_loss = posterior.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+ log["kl_loss"] = kl_loss
+ return z, log
+
+
+def measure_perplexity(predicted_indices, num_centroids):
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+ encodings = (
+ F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
+ )
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
diff --git a/sgm/modules/diffusionmodules/__init__.py b/sgm/modules/diffusionmodules/__init__.py
new file mode 100644
index 000000000..ce7968af9
--- /dev/null
+++ b/sgm/modules/diffusionmodules/__init__.py
@@ -0,0 +1,7 @@
+from .denoiser import Denoiser
+from .discretizer import Discretization
+from .loss import StandardDiffusionLoss
+from .model import Model, Encoder, Decoder
+from .openaimodel import UNetModel
+from .sampling import BaseDiffusionSampler
+from .wrappers import OpenAIWrapper
diff --git a/sgm/modules/diffusionmodules/denoiser.py b/sgm/modules/diffusionmodules/denoiser.py
new file mode 100644
index 000000000..4651e7de5
--- /dev/null
+++ b/sgm/modules/diffusionmodules/denoiser.py
@@ -0,0 +1,63 @@
+import torch.nn as nn
+
+from ...util import append_dims, instantiate_from_config
+
+
+class Denoiser(nn.Module):
+ def __init__(self, weighting_config, scaling_config):
+ super().__init__()
+
+ self.weighting = instantiate_from_config(weighting_config)
+ self.scaling = instantiate_from_config(scaling_config)
+
+ def possibly_quantize_sigma(self, sigma):
+ return sigma
+
+ def possibly_quantize_c_noise(self, c_noise):
+ return c_noise
+
+ def w(self, sigma):
+ return self.weighting(sigma)
+
+ def __call__(self, network, input, sigma, cond):
+ sigma = self.possibly_quantize_sigma(sigma)
+ sigma_shape = sigma.shape
+ sigma = append_dims(sigma, input.ndim)
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma)
+ c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
+ return network(input * c_in, c_noise, cond) * c_out + input * c_skip
+
+
+class DiscreteDenoiser(Denoiser):
+ def __init__(
+ self,
+ weighting_config,
+ scaling_config,
+ num_idx,
+ discretization_config,
+ do_append_zero=False,
+ quantize_c_noise=True,
+ flip=True,
+ ):
+ super().__init__(weighting_config, scaling_config)
+ sigmas = instantiate_from_config(discretization_config)(
+ num_idx, do_append_zero=do_append_zero, flip=flip
+ )
+ self.register_buffer("sigmas", sigmas)
+ self.quantize_c_noise = quantize_c_noise
+
+ def sigma_to_idx(self, sigma):
+ dists = sigma - self.sigmas[:, None]
+ return dists.abs().argmin(dim=0).view(sigma.shape)
+
+ def idx_to_sigma(self, idx):
+ return self.sigmas[idx]
+
+ def possibly_quantize_sigma(self, sigma):
+ return self.idx_to_sigma(self.sigma_to_idx(sigma))
+
+ def possibly_quantize_c_noise(self, c_noise):
+ if self.quantize_c_noise:
+ return self.sigma_to_idx(c_noise)
+ else:
+ return c_noise
diff --git a/sgm/modules/diffusionmodules/denoiser_scaling.py b/sgm/modules/diffusionmodules/denoiser_scaling.py
new file mode 100644
index 000000000..f8a2ac673
--- /dev/null
+++ b/sgm/modules/diffusionmodules/denoiser_scaling.py
@@ -0,0 +1,31 @@
+import torch
+
+
+class EDMScaling:
+ def __init__(self, sigma_data=0.5):
+ self.sigma_data = sigma_data
+
+ def __call__(self, sigma):
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
+ c_noise = 0.25 * sigma.log()
+ return c_skip, c_out, c_in, c_noise
+
+
+class EpsScaling:
+ def __call__(self, sigma):
+ c_skip = torch.ones_like(sigma, device=sigma.device)
+ c_out = -sigma
+ c_in = 1 / (sigma**2 + 1.0) ** 0.5
+ c_noise = sigma.clone()
+ return c_skip, c_out, c_in, c_noise
+
+
+class VScaling:
+ def __call__(self, sigma):
+ c_skip = 1.0 / (sigma**2 + 1.0)
+ c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+ c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+ c_noise = sigma.clone()
+ return c_skip, c_out, c_in, c_noise
diff --git a/sgm/modules/diffusionmodules/denoiser_weighting.py b/sgm/modules/diffusionmodules/denoiser_weighting.py
new file mode 100644
index 000000000..b8b03ca58
--- /dev/null
+++ b/sgm/modules/diffusionmodules/denoiser_weighting.py
@@ -0,0 +1,24 @@
+import torch
+
+
+class UnitWeighting:
+ def __call__(self, sigma):
+ return torch.ones_like(sigma, device=sigma.device)
+
+
+class EDMWeighting:
+ def __init__(self, sigma_data=0.5):
+ self.sigma_data = sigma_data
+
+ def __call__(self, sigma):
+ return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
+
+
+class VWeighting(EDMWeighting):
+ def __init__(self):
+ super().__init__(sigma_data=1.0)
+
+
+class EpsWeighting:
+ def __call__(self, sigma):
+ return sigma**-2.0
diff --git a/sgm/modules/diffusionmodules/discretizer.py b/sgm/modules/diffusionmodules/discretizer.py
new file mode 100644
index 000000000..f63218653
--- /dev/null
+++ b/sgm/modules/diffusionmodules/discretizer.py
@@ -0,0 +1,65 @@
+import torch
+import numpy as np
+from functools import partial
+
+from ...util import append_zero
+from ...modules.diffusionmodules.util import make_beta_schedule
+
+
+class Discretization:
+ def __call__(self, n, do_append_zero=True, device="cuda", flip=False):
+ sigmas = self.get_sigmas(n, device)
+ sigmas = append_zero(sigmas) if do_append_zero else sigmas
+ return sigmas if not flip else torch.flip(sigmas, (0,))
+
+
+class EDMDiscretization(Discretization):
+ def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
+ self.sigma_min = sigma_min
+ self.sigma_max = sigma_max
+ self.rho = rho
+
+ def get_sigmas(self, n, device):
+ ramp = torch.linspace(0, 1, n, device=device)
+ min_inv_rho = self.sigma_min ** (1 / self.rho)
+ max_inv_rho = self.sigma_max ** (1 / self.rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
+ return sigmas
+
+
+class LegacyDDPMDiscretization(Discretization):
+ def __init__(
+ self,
+ linear_start=0.00085,
+ linear_end=0.0120,
+ num_timesteps=1000,
+ legacy_range=True,
+ ):
+ self.num_timesteps = num_timesteps
+ betas = make_beta_schedule(
+ "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
+ )
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.to_torch = partial(torch.tensor, dtype=torch.float32)
+ self.legacy_range = legacy_range
+
+ def get_sigmas(self, n, device):
+ if n < self.num_timesteps:
+ c = self.num_timesteps // n
+
+ if self.legacy_range:
+ timesteps = np.asarray(list(range(0, self.num_timesteps, c)))
+ timesteps += 1 # Legacy LDM Hack
+ else:
+ timesteps = np.asarray(list(range(0, self.num_timesteps + 1, c)))
+ timesteps -= 1
+ timesteps = timesteps[1:]
+
+ alphas_cumprod = self.alphas_cumprod[timesteps]
+ else:
+ alphas_cumprod = self.alphas_cumprod
+
+ to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
+ sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ return torch.flip(sigmas, (0,))
diff --git a/sgm/modules/diffusionmodules/guiders.py b/sgm/modules/diffusionmodules/guiders.py
new file mode 100644
index 000000000..7d33361d5
--- /dev/null
+++ b/sgm/modules/diffusionmodules/guiders.py
@@ -0,0 +1,53 @@
+from functools import partial
+
+import torch
+
+from ...util import default, instantiate_from_config
+
+
+class VanillaCFG:
+ """
+ implements parallelized CFG
+ """
+
+ def __init__(self, scale, dyn_thresh_config=None):
+ scale_schedule = lambda scale, sigma: scale # independent of step
+ self.scale_schedule = partial(scale_schedule, scale)
+ self.dyn_thresh = instantiate_from_config(
+ default(
+ dyn_thresh_config,
+ {
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
+ },
+ )
+ )
+
+ def __call__(self, x, sigma):
+ x_u, x_c = x.chunk(2)
+ scale_value = self.scale_schedule(sigma)
+ x_pred = self.dyn_thresh(x_u, x_c, scale_value)
+ return x_pred
+
+ def prepare_inputs(self, x, s, c, uc):
+ c_out = dict()
+
+ for k in c:
+ if k in ["vector", "crossattn", "concat"]:
+ c_out[k] = torch.cat((uc[k], c[k]), 0)
+ else:
+ assert c[k] == uc[k]
+ c_out[k] = c[k]
+ return torch.cat([x] * 2), torch.cat([s] * 2), c_out
+
+
+class IdentityGuider:
+ def __call__(self, x, sigma):
+ return x
+
+ def prepare_inputs(self, x, s, c, uc):
+ c_out = dict()
+
+ for k in c:
+ c_out[k] = c[k]
+
+ return x, s, c_out
diff --git a/sgm/modules/diffusionmodules/loss.py b/sgm/modules/diffusionmodules/loss.py
new file mode 100644
index 000000000..555abc1c3
--- /dev/null
+++ b/sgm/modules/diffusionmodules/loss.py
@@ -0,0 +1,69 @@
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+from omegaconf import ListConfig
+from taming.modules.losses.lpips import LPIPS
+
+from ...util import append_dims, instantiate_from_config
+
+
+class StandardDiffusionLoss(nn.Module):
+ def __init__(
+ self,
+ sigma_sampler_config,
+ type="l2",
+ offset_noise_level=0.0,
+ batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
+ ):
+ super().__init__()
+
+ assert type in ["l2", "l1", "lpips"]
+
+ self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
+
+ self.type = type
+ self.offset_noise_level = offset_noise_level
+
+ if type == "lpips":
+ self.lpips = LPIPS().eval()
+
+ if not batch2model_keys:
+ batch2model_keys = []
+
+ if isinstance(batch2model_keys, str):
+ batch2model_keys = [batch2model_keys]
+
+ self.batch2model_keys = set(batch2model_keys)
+
+ def __call__(self, network, denoiser, conditioner, input, batch):
+ cond = conditioner(batch)
+ additional_model_inputs = {
+ key: batch[key] for key in self.batch2model_keys.intersection(batch)
+ }
+
+ sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
+ noise = torch.randn_like(input)
+ if self.offset_noise_level > 0.0:
+ noise = noise + self.offset_noise_level * append_dims(
+ torch.randn(input.shape[0], device=input.device), input.ndim
+ )
+ noised_input = input + noise * append_dims(sigmas, input.ndim)
+ model_output = denoiser(
+ network, noised_input, sigmas, cond, **additional_model_inputs
+ )
+ w = append_dims(denoiser.w(sigmas), input.ndim)
+ return self.get_loss(model_output, input, w)
+
+ def get_loss(self, model_output, target, w):
+ if self.type == "l2":
+ return torch.mean(
+ (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
+ )
+ elif self.type == "l1":
+ return torch.mean(
+ (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
+ )
+ elif self.type == "lpips":
+ loss = self.lpips(model_output, target).reshape(-1)
+ return loss
diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py
new file mode 100644
index 000000000..26efd0784
--- /dev/null
+++ b/sgm/modules/diffusionmodules/model.py
@@ -0,0 +1,743 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+from typing import Any, Callable, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from packaging import version
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except:
+ XFORMERS_IS_AVAILABLE = False
+ print("no module 'xformers'. Processing without...")
+
+from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
+ )
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ b, c, h, w = q.shape
+ q, k, v = map(
+ lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
+ )
+ h_ = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v
+ ) # scale is dim ** -0.5 per default
+ # compute attention
+
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+
+ def forward(self, x, **kwargs):
+ h_ = x
+ h_ = self.attention(h_)
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class MemoryEfficientAttnBlock(nn.Module):
+ """
+ Uses xformers efficient implementation,
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ Note: this is a single-head self-attention operation
+ """
+
+ #
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.attention_op: Optional[Any] = None
+
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ B, C, H, W = q.shape
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
+
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(B, t.shape[1], 1, C)
+ .permute(0, 2, 1, 3)
+ .reshape(B * 1, t.shape[1], C)
+ .contiguous(),
+ (q, k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=None, op=self.attention_op
+ )
+
+ out = (
+ out.unsqueeze(0)
+ .reshape(B, 1, out.shape[1], C)
+ .permute(0, 2, 1, 3)
+ .reshape(B, out.shape[1], C)
+ )
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
+
+ def forward(self, x, **kwargs):
+ h_ = x
+ h_ = self.attention(h_)
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+ def forward(self, x, context=None, mask=None, **unused_kwargs):
+ b, c, h, w = x.shape
+ x = rearrange(x, "b c h w -> b (h w) c")
+ out = super().forward(x, context=context, mask=mask)
+ out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
+ return x + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+ assert attn_type in [
+ "vanilla",
+ "vanilla-xformers",
+ "memory-efficient-cross-attn",
+ "linear",
+ "none",
+ ], f"attn_type {attn_type} unknown"
+ if (
+ version.parse(torch.__version__) < version.parse("2.0.0")
+ and attn_type != "none"
+ ):
+ assert XFORMERS_IS_AVAILABLE, (
+ f"We do not support vanilla attention in {torch.__version__} anymore, "
+ f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ )
+ attn_type = "vanilla-xformers"
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla-xformers":
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return MemoryEfficientAttnBlock(in_channels)
+ elif type == "memory-efficient-cross-attn":
+ attn_kwargs["query_dim"] = in_channels
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ use_timestep=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch * 4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList(
+ [
+ torch.nn.Linear(self.ch, self.temb_ch),
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
+ ]
+ )
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(
+ ResnetBlock(
+ in_channels=block_in + skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x, t=None, context=None):
+ # assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb
+ )
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignorekwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print(
+ "Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)
+ )
+ )
+
+ make_attn_cls = self._make_attn()
+ make_resblock_cls = self._make_resblock()
+ make_conv_cls = self._make_conv()
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
+ )
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
+ self.mid.block_2 = make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn_cls(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = make_conv_cls(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def _make_attn(self) -> Callable:
+ return make_attn
+
+ def _make_resblock(self) -> Callable:
+ return ResnetBlock
+
+ def _make_conv(self) -> Callable:
+ return torch.nn.Conv2d
+
+ def get_last_layer(self, **kwargs):
+ return self.conv_out.weight
+
+ def forward(self, z, **kwargs):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb, **kwargs)
+ h = self.mid.attn_1(h, **kwargs)
+ h = self.mid.block_2(h, temb, **kwargs)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, **kwargs)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h, **kwargs)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 000000000..e19b83f98
--- /dev/null
+++ b/sgm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,1262 @@
+import math
+from abc import abstractmethod
+from functools import partial
+from typing import Iterable
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from ...modules.attention import SpatialTransformer
+from ...modules.diffusionmodules.util import (
+ avg_pool_nd,
+ checkpoint,
+ conv_nd,
+ linear,
+ normalization,
+ timestep_embedding,
+ zero_module,
+)
+from ...util import default, exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
+ )
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(
+ self,
+ x,
+ emb,
+ context=None,
+ skip_time_mix=False,
+ time_context=None,
+ num_video_frames=None,
+ time_context_cat=None,
+ use_crossframe_attention_in_spatial_layers=False,
+ ):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(
+ self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ self.third_up = third_up
+ if use_conv:
+ self.conv = conv_nd(
+ dims, self.channels, self.out_channels, 3, padding=padding
+ )
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ t_factor = 1 if not self.third_up else 2
+ x = F.interpolate(
+ x,
+ (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
+ mode="nearest",
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class TransposedUpsample(nn.Module):
+ "Learned 2x upsampling without padding"
+
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(
+ self.channels, self.out_channels, kernel_size=ks, stride=2
+ )
+
+ def forward(self, x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(
+ self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
+ if use_conv:
+ print(f"Building a Downsample layer with {dims} dims.")
+ print(
+ f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
+ f"kernel-size: 3, stride: {stride}, padding: {padding}"
+ )
+ if dims == 3:
+ print(f" --> Downsampling third axis (time): {third_down}")
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding,
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ kernel_size=3,
+ exchange_temb_dims=False,
+ skip_t_emb=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.exchange_temb_dims = exchange_temb_dims
+
+ if isinstance(kernel_size, Iterable):
+ padding = [k // 2 for k in kernel_size]
+ else:
+ padding = kernel_size // 2
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.skip_t_emb = skip_t_emb
+ self.emb_out_channels = (
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels
+ )
+ if self.skip_t_emb:
+ print(f"Skipping timestep embedding in {self.__class__.__name__}")
+ assert not self.use_scale_shift_norm
+ self.emb_layers = None
+ self.exchange_temb_dims = False
+ else:
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ self.emb_out_channels,
+ ),
+ )
+
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(
+ dims,
+ self.out_channels,
+ self.out_channels,
+ kernel_size,
+ padding=padding,
+ )
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, kernel_size, padding=padding
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+
+ if self.skip_t_emb:
+ emb_out = th.zeros_like(h)
+ else:
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ if self.exchange_temb_dims:
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x, **kwargs):
+ # TODO add crossframe attention and use mixed checkpoint
+ return checkpoint(
+ self._forward, (x,), self.parameters(), True
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ # return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial**2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class Timestep(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, t):
+ return timestep_embedding(t, self.dim)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ spatial_transformer_attn_type="softmax",
+ adm_in_channels=None,
+ use_fairscale_checkpoint=False,
+ offload_to_cpu=False,
+ transformer_depth_middle=None,
+ ):
+ super().__init__()
+ from omegaconf.listconfig import ListConfig
+
+ if use_spatial_transformer:
+ assert (
+ context_dim is not None
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
+
+ if context_dim is not None:
+ assert (
+ use_spatial_transformer
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert (
+ num_head_channels != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ if num_head_channels == -1:
+ assert (
+ num_heads != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(transformer_depth, int):
+ transformer_depth = len(channel_mult) * [transformer_depth]
+ elif isinstance(transformer_depth, ListConfig):
+ transformer_depth = list(transformer_depth)
+ transformer_depth_middle = default(
+ transformer_depth_middle, transformer_depth[-1]
+ )
+
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError(
+ "provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult"
+ )
+ self.num_res_blocks = num_res_blocks
+ # self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(
+ map(
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
+ range(len(num_attention_blocks)),
+ )
+ )
+ print(
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set."
+ ) # todo: convert to warning
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ if use_fp16:
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
+ # self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ assert use_fairscale_checkpoint != use_checkpoint or not (
+ use_checkpoint or use_fairscale_checkpoint
+ )
+
+ self.use_fairscale_checkpoint = False
+ checkpoint_wrapper_fn = (
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
+ if self.use_fairscale_checkpoint
+ else lambda x: x
+ )
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = checkpoint_wrapper_fn(
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "timestep":
+ self.label_emb = checkpoint_wrapper_fn(
+ nn.Sequential(
+ Timestep(model_channels),
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ ),
+ )
+ )
+ elif self.num_classes == "sequential":
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ linear(adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if (
+ not exists(num_attention_blocks)
+ or nr < num_attention_blocks[level]
+ ):
+ layers.append(
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ),
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer( # always uses a self-attn
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth_middle,
+ context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ ),
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if (
+ not exists(num_attention_blocks)
+ or i < num_attention_blocks[level]
+ ):
+ layers.append(
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = checkpoint_wrapper_fn(
+ nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = checkpoint_wrapper_fn(
+ nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ # h = x.type(self.dtype)
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ assert False, "not supported anymore. what the f*** are you doing?"
+ else:
+ return self.out(h)
+
+
+class NoTimeUNetModel(UNetModel):
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ timesteps = th.zeros_like(timesteps)
+ return super().forward(x, timesteps, context, y, **kwargs)
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+ For usage, see UNet.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == "adaptive":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
+ nn.Flatten(),
+ )
+ elif pool == "attention":
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ AttentionPool2d(
+ (image_size // ds), ch, num_head_channels, out_channels
+ ),
+ )
+ elif pool == "spatial":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == "spatial_v2":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ normalization(2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {pool} pooling")
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ # h = x.type(self.dtype)
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
+
+
+if __name__ == "__main__":
+
+ class Dummy(nn.Module):
+ def __init__(self, in_channels=3, model_channels=64):
+ super().__init__()
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(2, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+
+ model = UNetModel(
+ use_checkpoint=True,
+ image_size=64,
+ in_channels=4,
+ out_channels=4,
+ model_channels=128,
+ attention_resolutions=[4, 2],
+ num_res_blocks=2,
+ channel_mult=[1, 2, 4],
+ num_head_channels=64,
+ use_spatial_transformer=False,
+ use_linear_in_transformer=True,
+ transformer_depth=1,
+ legacy=False,
+ ).cuda()
+ x = th.randn(11, 4, 64, 64).cuda()
+ t = th.randint(low=0, high=10, size=(11,), device="cuda")
+ o = model(x, t)
+ print("done.")
diff --git a/sgm/modules/diffusionmodules/sampling.py b/sgm/modules/diffusionmodules/sampling.py
new file mode 100644
index 000000000..6346829c8
--- /dev/null
+++ b/sgm/modules/diffusionmodules/sampling.py
@@ -0,0 +1,365 @@
+"""
+ Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
+"""
+
+
+from typing import Dict, Union
+
+import torch
+from omegaconf import ListConfig, OmegaConf
+from tqdm import tqdm
+
+from ...modules.diffusionmodules.sampling_utils import (
+ get_ancestral_step,
+ linear_multistep_coeff,
+ to_d,
+ to_neg_log_sigma,
+ to_sigma,
+)
+from ...util import append_dims, default, instantiate_from_config
+
+DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
+
+
+class BaseDiffusionSampler:
+ def __init__(
+ self,
+ discretization_config: Union[Dict, ListConfig, OmegaConf],
+ num_steps: Union[int, None] = None,
+ guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
+ verbose: bool = False,
+ device: str = "cuda",
+ ):
+ self.num_steps = num_steps
+ self.discretization = instantiate_from_config(discretization_config)
+ self.guider = instantiate_from_config(
+ default(
+ guider_config,
+ DEFAULT_GUIDER,
+ )
+ )
+ self.verbose = verbose
+ self.device = device
+
+ def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
+ sigmas = self.discretization(
+ self.num_steps if num_steps is None else num_steps, device=self.device
+ )
+ uc = default(uc, cond)
+
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
+ num_sigmas = len(sigmas)
+
+ s_in = x.new_ones([x.shape[0]])
+
+ return x, s_in, sigmas, num_sigmas, cond, uc
+
+ def denoise(self, x, denoiser, sigma, cond, uc):
+ denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
+ denoised = self.guider(denoised, sigma)
+ return denoised
+
+ def get_sigma_gen(self, num_sigmas):
+ sigma_generator = range(num_sigmas - 1)
+ if self.verbose:
+ print("#" * 30, " Sampling setting ", "#" * 30)
+ print(f"Sampler: {self.__class__.__name__}")
+ print(f"Discretization: {self.discretization.__class__.__name__}")
+ print(f"Guider: {self.guider.__class__.__name__}")
+ sigma_generator = tqdm(
+ sigma_generator,
+ total=num_sigmas,
+ desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
+ )
+ return sigma_generator
+
+
+class SingleStepDiffusionSampler(BaseDiffusionSampler):
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
+ raise NotImplementedError
+
+ def euler_step(self, x, d, dt):
+ return x + dt * d
+
+
+class EDMSampler(SingleStepDiffusionSampler):
+ def __init__(
+ self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.s_churn = s_churn
+ self.s_tmin = s_tmin
+ self.s_tmax = s_tmax
+ self.s_noise = s_noise
+
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
+ sigma_hat = sigma * (gamma + 1.0)
+ if gamma > 0:
+ eps = torch.randn_like(x) * self.s_noise
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
+
+ denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
+ d = to_d(x, sigma_hat, denoised)
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
+
+ euler_step = self.euler_step(x, d, dt)
+ x = self.possible_correction_step(
+ euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ )
+ return x
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ for i in self.get_sigma_gen(num_sigmas):
+ gamma = (
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
+ else 0.0
+ )
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ gamma,
+ )
+
+ return x
+
+
+class AncestralSampler(SingleStepDiffusionSampler):
+ def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.eta = eta
+ self.s_noise = s_noise
+ self.noise_sampler = lambda x: torch.randn_like(x)
+
+ def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
+ d = to_d(x, sigma, denoised)
+ dt = append_dims(sigma_down - sigma, x.ndim)
+
+ return self.euler_step(x, d, dt)
+
+ def ancestral_step(self, x, sigma, next_sigma, sigma_up):
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0,
+ x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
+ x,
+ )
+ return x
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ for i in self.get_sigma_gen(num_sigmas):
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ )
+
+ return x
+
+
+class LinearMultistepSampler(BaseDiffusionSampler):
+ def __init__(
+ self,
+ order=4,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.order = order
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ ds = []
+ sigmas_cpu = sigmas.detach().cpu().numpy()
+ for i in self.get_sigma_gen(num_sigmas):
+ sigma = s_in * sigmas[i]
+ denoised = denoiser(
+ *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
+ )
+ denoised = self.guider(denoised, sigma)
+ d = to_d(x, sigma, denoised)
+ ds.append(d)
+ if len(ds) > self.order:
+ ds.pop(0)
+ cur_order = min(i + 1, self.order)
+ coeffs = [
+ linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
+ for j in range(cur_order)
+ ]
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
+
+ return x
+
+
+class EulerEDMSampler(EDMSampler):
+ def possible_correction_step(
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ ):
+ return euler_step
+
+
+class HeunEDMSampler(EDMSampler):
+ def possible_correction_step(
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ ):
+ if torch.sum(next_sigma) < 1e-14:
+ # Save a network evaluation if all noise levels are 0
+ return euler_step
+ else:
+ denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
+ d_new = to_d(euler_step, next_sigma, denoised)
+ d_prime = (d + d_new) / 2.0
+
+ # apply correction if noise level is not 0
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
+ )
+ return x
+
+
+class EulerAncestralSampler(AncestralSampler):
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+ x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
+
+ return x
+
+
+class DPMPP2SAncestralSampler(AncestralSampler):
+ def get_variables(self, sigma, sigma_down):
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
+ h = t_next - t
+ s = t + 0.5 * h
+ return h, s, t, t_next
+
+ def get_mult(self, h, s, t, t_next):
+ mult1 = to_sigma(s) / to_sigma(t)
+ mult2 = (-0.5 * h).expm1()
+ mult3 = to_sigma(t_next) / to_sigma(t)
+ mult4 = (-h).expm1()
+
+ return mult1, mult2, mult3, mult4
+
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+ x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
+
+ if torch.sum(sigma_down) < 1e-14:
+ # Save a network evaluation if all noise levels are 0
+ x = x_euler
+ else:
+ h, s, t, t_next = self.get_variables(sigma, sigma_down)
+ mult = [
+ append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
+ ]
+
+ x2 = mult[0] * x - mult[1] * denoised
+ denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
+ x_dpmpp2s = mult[2] * x - mult[3] * denoised2
+
+ # apply correction if noise level is not 0
+ x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
+
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
+ return x
+
+
+class DPMPP2MSampler(BaseDiffusionSampler):
+ def get_variables(self, sigma, next_sigma, previous_sigma=None):
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
+ h = t_next - t
+
+ if previous_sigma is not None:
+ h_last = t - to_neg_log_sigma(previous_sigma)
+ r = h_last / h
+ return h, r, t, t_next
+ else:
+ return h, None, t, t_next
+
+ def get_mult(self, h, r, t, t_next, previous_sigma):
+ mult1 = to_sigma(t_next) / to_sigma(t)
+ mult2 = (-h).expm1()
+
+ if previous_sigma is not None:
+ mult3 = 1 + 1 / (2 * r)
+ mult4 = 1 / (2 * r)
+ return mult1, mult2, mult3, mult4
+ else:
+ return mult1, mult2
+
+ def sampler_step(
+ self,
+ old_denoised,
+ previous_sigma,
+ sigma,
+ next_sigma,
+ denoiser,
+ x,
+ cond,
+ uc=None,
+ ):
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+
+ h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
+ mult = [
+ append_dims(mult, x.ndim)
+ for mult in self.get_mult(h, r, t, t_next, previous_sigma)
+ ]
+
+ x_standard = mult[0] * x - mult[1] * denoised
+ if old_denoised is None or torch.sum(next_sigma) < 1e-14:
+ # Save a network evaluation if all noise levels are 0 or on the first step
+ return x_standard, denoised
+ else:
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
+ x_advanced = mult[0] * x - mult[1] * denoised_d
+
+ # apply correction if noise level is not 0 and not first step
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
+ )
+
+ return x, denoised
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ old_denoised = None
+ for i in self.get_sigma_gen(num_sigmas):
+ x, old_denoised = self.sampler_step(
+ old_denoised,
+ None if i == 0 else s_in * sigmas[i - 1],
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc=uc,
+ )
+
+ return x
diff --git a/sgm/modules/diffusionmodules/sampling_utils.py b/sgm/modules/diffusionmodules/sampling_utils.py
new file mode 100644
index 000000000..7cca6361c
--- /dev/null
+++ b/sgm/modules/diffusionmodules/sampling_utils.py
@@ -0,0 +1,48 @@
+import torch
+from scipy import integrate
+
+from ...util import append_dims
+
+
+class NoDynamicThresholding:
+ def __call__(self, uncond, cond, scale):
+ return uncond + scale * (cond - uncond)
+
+
+def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
+ if order - 1 > i:
+ raise ValueError(f"Order {order} too high for step {i}")
+
+ def fn(tau):
+ prod = 1.0
+ for k in range(order):
+ if j == k:
+ continue
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
+ return prod
+
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
+
+
+def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
+ if not eta:
+ return sigma_to, 0.0
+ sigma_up = torch.minimum(
+ sigma_to,
+ eta
+ * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
+ )
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+ return sigma_down, sigma_up
+
+
+def to_d(x, sigma, denoised):
+ return (x - denoised) / append_dims(sigma, x.ndim)
+
+
+def to_neg_log_sigma(sigma):
+ return sigma.log().neg()
+
+
+def to_sigma(neg_log_sigma):
+ return neg_log_sigma.neg().exp()
diff --git a/sgm/modules/diffusionmodules/sigma_sampling.py b/sgm/modules/diffusionmodules/sigma_sampling.py
new file mode 100644
index 000000000..d54724c6e
--- /dev/null
+++ b/sgm/modules/diffusionmodules/sigma_sampling.py
@@ -0,0 +1,31 @@
+import torch
+
+from ...util import default, instantiate_from_config
+
+
+class EDMSampling:
+ def __init__(self, p_mean=-1.2, p_std=1.2):
+ self.p_mean = p_mean
+ self.p_std = p_std
+
+ def __call__(self, n_samples, rand=None):
+ log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
+ return log_sigma.exp()
+
+
+class DiscreteSampling:
+ def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
+ self.num_idx = num_idx
+ self.sigmas = instantiate_from_config(discretization_config)(
+ num_idx, do_append_zero=do_append_zero, flip=flip
+ )
+
+ def idx_to_sigma(self, idx):
+ return self.sigmas[idx]
+
+ def __call__(self, n_samples, rand=None):
+ idx = default(
+ rand,
+ torch.randint(0, self.num_idx, (n_samples,)),
+ )
+ return self.idx_to_sigma(idx)
diff --git a/sgm/modules/diffusionmodules/util.py b/sgm/modules/diffusionmodules/util.py
new file mode 100644
index 000000000..069ff131f
--- /dev/null
+++ b/sgm/modules/diffusionmodules/util.py
@@ -0,0 +1,308 @@
+"""
+adopted from
+https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+and
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+and
+https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+
+thanks!
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import repeat
+
+
+def make_beta_schedule(
+ schedule,
+ n_timestep,
+ linear_start=1e-4,
+ linear_end=2e-2,
+):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
+ )
+ ** 2
+ )
+ return betas.numpy()
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def mixed_checkpoint(func, inputs: dict, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
+ borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
+ it also works with non-tensor inputs
+ :param func: the function to evaluate.
+ :param inputs: the argument dictionary to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
+ tensor_inputs = [
+ inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)
+ ]
+ non_tensor_keys = [
+ key for key in inputs if not isinstance(inputs[key], torch.Tensor)
+ ]
+ non_tensor_inputs = [
+ inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
+ ]
+ args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
+ return MixedCheckpointFunction.apply(
+ func,
+ len(tensor_inputs),
+ len(non_tensor_inputs),
+ tensor_keys,
+ non_tensor_keys,
+ *args,
+ )
+ else:
+ return func(**inputs)
+
+
+class MixedCheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ run_function,
+ length_tensors,
+ length_non_tensors,
+ tensor_keys,
+ non_tensor_keys,
+ *args,
+ ):
+ ctx.end_tensors = length_tensors
+ ctx.end_non_tensors = length_tensors + length_non_tensors
+ ctx.gpu_autocast_kwargs = {
+ "enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ assert (
+ len(tensor_keys) == length_tensors
+ and len(non_tensor_keys) == length_non_tensors
+ )
+
+ ctx.input_tensors = {
+ key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
+ }
+ ctx.input_non_tensors = {
+ key: val
+ for (key, val) in zip(
+ non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
+ )
+ }
+ ctx.run_function = run_function
+ ctx.input_params = list(args[ctx.end_non_tensors :])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(
+ **ctx.input_tensors, **ctx.input_non_tensors
+ )
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
+ ctx.input_tensors = {
+ key: ctx.input_tensors[key].detach().requires_grad_(True)
+ for key in ctx.input_tensors
+ }
+
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = {
+ key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
+ for key in ctx.input_tensors
+ }
+ # shallow_copies.update(additional_args)
+ output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ list(ctx.input_tensors.values()) + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (
+ (None, None, None, None, None)
+ + input_grads[: ctx.end_tensors]
+ + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
+ + input_grads[ctx.end_tensors :]
+ )
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ ctx.gpu_autocast_kwargs = {
+ "enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ else:
+ embedding = repeat(timesteps, "b -> b d", d=dim)
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
diff --git a/sgm/modules/diffusionmodules/wrappers.py b/sgm/modules/diffusionmodules/wrappers.py
new file mode 100644
index 000000000..87ede6061
--- /dev/null
+++ b/sgm/modules/diffusionmodules/wrappers.py
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+from packaging import version
+
+OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
+
+
+class IdentityWrapper(nn.Module):
+ def __init__(self, diffusion_model, compile_model: bool = False):
+ super().__init__()
+ compile = (
+ torch.compile
+ if (version.parse(torch.__version__) >= version.parse("2.0.0"))
+ and compile_model
+ else lambda x: x
+ )
+ self.diffusion_model = compile(diffusion_model)
+
+ def forward(self, *args, **kwargs):
+ return self.diffusion_model(*args, **kwargs)
+
+
+class OpenAIWrapper(IdentityWrapper):
+ def forward(
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
+ ) -> torch.Tensor:
+ x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
+ return self.diffusion_model(
+ x,
+ timesteps=t,
+ context=c.get("crossattn", None),
+ y=c.get("vector", None),
+ **kwargs
+ )
diff --git a/sgm/modules/distributions/__init__.py b/sgm/modules/distributions/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/sgm/modules/distributions/distributions.py b/sgm/modules/distributions/distributions.py
new file mode 100644
index 000000000..0b61f0307
--- /dev/null
+++ b/sgm/modules/distributions/distributions.py
@@ -0,0 +1,102 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(
+ device=self.parameters.device
+ )
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
+ device=self.parameters.device
+ )
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/sgm/modules/ema.py b/sgm/modules/ema.py
new file mode 100644
index 000000000..97b5ae2b2
--- /dev/null
+++ b/sgm/modules/ema.py
@@ -0,0 +1,86 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+
+ self.m_name2s_name = {}
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer(
+ "num_updates",
+ torch.tensor(0, dtype=torch.int)
+ if use_num_upates
+ else torch.tensor(-1, dtype=torch.int),
+ )
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # remove as '.'-character is not allowed in buffers
+ s_name = name.replace(".", "")
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def reset_num_updates(self):
+ del self.num_updates
+ self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(
+ one_minus_decay * (shadow_params[sname] - m_param[key])
+ )
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/sgm/modules/encoders/__init__.py b/sgm/modules/encoders/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py
new file mode 100644
index 000000000..ed3f2d215
--- /dev/null
+++ b/sgm/modules/encoders/modules.py
@@ -0,0 +1,960 @@
+from contextlib import nullcontext
+from functools import partial
+from typing import Dict, List, Optional, Tuple, Union
+
+import kornia
+import numpy as np
+import open_clip
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from omegaconf import ListConfig
+from torch.utils.checkpoint import checkpoint
+from transformers import (
+ ByT5Tokenizer,
+ CLIPTextModel,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5Tokenizer,
+)
+
+from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
+from ...modules.diffusionmodules.model import Encoder
+from ...modules.diffusionmodules.openaimodel import Timestep
+from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ...modules.distributions.distributions import DiagonalGaussianDistribution
+from ...util import (
+ autocast,
+ count_params,
+ default,
+ disabled_train,
+ expand_dims_like,
+ instantiate_from_config,
+)
+
+
+class AbstractEmbModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self._is_trainable = None
+ self._ucg_rate = None
+ self._input_key = None
+
+ @property
+ def is_trainable(self) -> bool:
+ return self._is_trainable
+
+ @property
+ def ucg_rate(self) -> Union[float, torch.Tensor]:
+ return self._ucg_rate
+
+ @property
+ def input_key(self) -> str:
+ return self._input_key
+
+ @is_trainable.setter
+ def is_trainable(self, value: bool):
+ self._is_trainable = value
+
+ @ucg_rate.setter
+ def ucg_rate(self, value: Union[float, torch.Tensor]):
+ self._ucg_rate = value
+
+ @input_key.setter
+ def input_key(self, value: str):
+ self._input_key = value
+
+ @is_trainable.deleter
+ def is_trainable(self):
+ del self._is_trainable
+
+ @ucg_rate.deleter
+ def ucg_rate(self):
+ del self._ucg_rate
+
+ @input_key.deleter
+ def input_key(self):
+ del self._input_key
+
+
+class GeneralConditioner(nn.Module):
+ OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
+ KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
+
+ def __init__(self, emb_models: Union[List, ListConfig]):
+ super().__init__()
+ embedders = []
+ for n, embconfig in enumerate(emb_models):
+ embedder = instantiate_from_config(embconfig)
+ assert isinstance(
+ embedder, AbstractEmbModel
+ ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
+ embedder.is_trainable = embconfig.get("is_trainable", False)
+ embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
+ if not embedder.is_trainable:
+ embedder.train = disabled_train
+ for param in embedder.parameters():
+ param.requires_grad = False
+ embedder.eval()
+ print(
+ f"Initialized embedder #{n}: {embedder.__class__.__name__} "
+ f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
+ )
+
+ if "input_key" in embconfig:
+ embedder.input_key = embconfig["input_key"]
+ elif "input_keys" in embconfig:
+ embedder.input_keys = embconfig["input_keys"]
+ else:
+ raise KeyError(
+ f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
+ )
+
+ embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
+ if embedder.legacy_ucg_val is not None:
+ embedder.ucg_prng = np.random.RandomState()
+
+ embedders.append(embedder)
+ self.embedders = nn.ModuleList(embedders)
+
+ def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
+ assert embedder.legacy_ucg_val is not None
+ p = embedder.ucg_rate
+ val = embedder.legacy_ucg_val
+ for i in range(len(batch[embedder.input_key])):
+ if embedder.ucg_prng.choice(2, p=[1 - p, p]):
+ batch[embedder.input_key][i] = val
+ return batch
+
+ def forward(
+ self, batch: Dict, force_zero_embeddings: Optional[List] = None
+ ) -> Dict:
+ output = dict()
+ if force_zero_embeddings is None:
+ force_zero_embeddings = []
+ for embedder in self.embedders:
+ embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
+ with embedding_context():
+ if hasattr(embedder, "input_key") and (embedder.input_key is not None):
+ if embedder.legacy_ucg_val is not None:
+ batch = self.possibly_get_ucg_val(embedder, batch)
+ emb_out = embedder(batch[embedder.input_key])
+ elif hasattr(embedder, "input_keys"):
+ emb_out = embedder(*[batch[k] for k in embedder.input_keys])
+ assert isinstance(
+ emb_out, (torch.Tensor, list, tuple)
+ ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
+ if not isinstance(emb_out, (list, tuple)):
+ emb_out = [emb_out]
+ for emb in emb_out:
+ out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
+ if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
+ emb = (
+ expand_dims_like(
+ torch.bernoulli(
+ (1.0 - embedder.ucg_rate)
+ * torch.ones(emb.shape[0], device=emb.device)
+ ),
+ emb,
+ )
+ * emb
+ )
+ if (
+ hasattr(embedder, "input_key")
+ and embedder.input_key in force_zero_embeddings
+ ):
+ emb = torch.zeros_like(emb)
+ if out_key in output:
+ output[out_key] = torch.cat(
+ (output[out_key], emb), self.KEY2CATDIM[out_key]
+ )
+ else:
+ output[out_key] = emb
+ return output
+
+ def get_unconditional_conditioning(
+ self, batch_c, batch_uc=None, force_uc_zero_embeddings=None
+ ):
+ if force_uc_zero_embeddings is None:
+ force_uc_zero_embeddings = []
+ ucg_rates = list()
+ for embedder in self.embedders:
+ ucg_rates.append(embedder.ucg_rate)
+ embedder.ucg_rate = 0.0
+ c = self(batch_c)
+ uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
+
+ for embedder, rate in zip(self.embedders, ucg_rates):
+ embedder.ucg_rate = rate
+ return c, uc
+
+
+class InceptionV3(nn.Module):
+ """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
+ port with an additional squeeze at the end"""
+
+ def __init__(self, normalize_input=False, **kwargs):
+ super().__init__()
+ from pytorch_fid import inception
+
+ kwargs["resize_input"] = True
+ self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)
+
+ def forward(self, inp):
+ # inp = kornia.geometry.resize(inp, (299, 299),
+ # interpolation='bicubic',
+ # align_corners=False,
+ # antialias=True)
+ # inp = inp.clamp(min=-1, max=1)
+
+ outp = self.model(inp)
+
+ if len(outp) == 1:
+ return outp[0].squeeze()
+
+ return outp
+
+
+class IdentityEncoder(AbstractEmbModel):
+ def encode(self, x):
+ return x
+
+ def forward(self, x):
+ return x
+
+
+class ClassEmbedder(AbstractEmbModel):
+ def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
+ super().__init__()
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+ self.n_classes = n_classes
+ self.add_sequence_dim = add_sequence_dim
+
+ def forward(self, c):
+ c = self.embedding(c)
+ if self.add_sequence_dim:
+ c = c[:, None, :]
+ return c
+
+ def get_unconditional_conditioning(self, bs, device="cuda"):
+ uc_class = (
+ self.n_classes - 1
+ ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+ uc = torch.ones((bs,), device=device) * uc_class
+ uc = {self.key: uc.long()}
+ return uc
+
+
+class ClassEmbedderForMultiCond(ClassEmbedder):
+ def forward(self, batch, key=None, disable_dropout=False):
+ out = batch
+ key = default(key, self.key)
+ islist = isinstance(batch[key], list)
+ if islist:
+ batch[key] = batch[key][0]
+ c_out = super().forward(batch, key, disable_dropout)
+ out[key] = [c_out] if islist else c_out
+ return out
+
+
+class FrozenT5Embedder(AbstractEmbModel):
+ """Uses the T5 transformer encoder for text"""
+
+ def __init__(
+ self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ # @autocast
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ with torch.autocast("cuda", enabled=False):
+ outputs = self.transformer(input_ids=tokens)
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenByT5Embedder(AbstractEmbModel):
+ """
+ Uses the ByT5 transformer encoder for text. Is character-aware.
+ """
+
+ def __init__(
+ self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = ByT5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ with torch.autocast("cuda", enabled=False):
+ outputs = self.transformer(input_ids=tokens)
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEmbModel):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+
+ LAYERS = ["last", "pooled", "hidden"]
+
+ def __init__(
+ self,
+ version="openai/clip-vit-large-patch14",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ layer_idx=None,
+ always_return_pooled=False,
+ ): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ self.layer_idx = layer_idx
+ self.return_pooled = always_return_pooled
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert 0 <= abs(layer_idx) <= 12
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(
+ input_ids=tokens, output_hidden_states=self.layer == "hidden"
+ )
+ if self.layer == "last":
+ z = outputs.last_hidden_state
+ elif self.layer == "pooled":
+ z = outputs.pooler_output[:, None, :]
+ else:
+ z = outputs.hidden_states[self.layer_idx]
+ if self.return_pooled:
+ return z, outputs.pooler_output
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+
+ LAYERS = ["pooled", "last", "penultimate"]
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ always_return_pooled=False,
+ legacy=True,
+ ):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device("cpu"),
+ pretrained=version,
+ )
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ self.return_pooled = always_return_pooled
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+ self.legacy = legacy
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ if not self.return_pooled and self.legacy:
+ return z
+ if self.return_pooled:
+ assert not self.legacy
+ return z[self.layer], z["pooled"]
+ return z[self.layer]
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ if self.legacy:
+ x = x[self.layer]
+ x = self.model.ln_final(x)
+ return x
+ else:
+ # x is a dict and will stay a dict
+ o = x["last"]
+ o = self.model.ln_final(o)
+ pooled = self.pool(o, text)
+ x["pooled"] = pooled
+ return x
+
+ def pool(self, x, text):
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = (
+ x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
+ @ self.model.text_projection
+ )
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ outputs = {}
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - 1:
+ outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD
+ if (
+ self.model.transformer.grad_checkpointing
+ and not torch.jit.is_scripting()
+ ):
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ outputs["last"] = x.permute(1, 0, 2) # LND -> NLD
+ return outputs
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder(AbstractEmbModel):
+ LAYERS = [
+ # "pooled",
+ "last",
+ "penultimate",
+ ]
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ ):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch, device=torch.device("cpu"), pretrained=version
+ )
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if (
+ self.model.transformer.grad_checkpointing
+ and not torch.jit.is_scripting()
+ ):
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
+ """
+ Uses the OpenCLIP vision transformer encoder for images
+ """
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ antialias=True,
+ ucg_rate=0.0,
+ unsqueeze_dim=False,
+ repeat_to_max_len=False,
+ num_image_crops=0,
+ output_tokens=False,
+ ):
+ super().__init__()
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device("cpu"),
+ pretrained=version,
+ )
+ del model.transformer
+ self.model = model
+ self.max_crops = num_image_crops
+ self.pad_to_max_len = self.max_crops > 0
+ self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ self.antialias = antialias
+
+ self.register_buffer(
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+ )
+ self.register_buffer(
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+ )
+ self.ucg_rate = ucg_rate
+ self.unsqueeze_dim = unsqueeze_dim
+ self.stored_batch = None
+ self.model.visual.output_tokens = output_tokens
+ self.output_tokens = output_tokens
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(
+ x,
+ (224, 224),
+ interpolation="bicubic",
+ align_corners=True,
+ antialias=self.antialias,
+ )
+ x = (x + 1.0) / 2.0
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, image, no_dropout=False):
+ z = self.encode_with_vision_transformer(image)
+ tokens = None
+ if self.output_tokens:
+ z, tokens = z[0], z[1]
+ z = z.to(image.dtype)
+ if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
+ z = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
+ )[:, None]
+ * z
+ )
+ if tokens is not None:
+ tokens = (
+ expand_dims_like(
+ torch.bernoulli(
+ (1.0 - self.ucg_rate)
+ * torch.ones(tokens.shape[0], device=tokens.device)
+ ),
+ tokens,
+ )
+ * tokens
+ )
+ if self.unsqueeze_dim:
+ z = z[:, None, :]
+ if self.output_tokens:
+ assert not self.repeat_to_max_len
+ assert not self.pad_to_max_len
+ return tokens, z
+ if self.repeat_to_max_len:
+ if z.dim() == 2:
+ z_ = z[:, None, :]
+ else:
+ z_ = z
+ return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
+ elif self.pad_to_max_len:
+ assert z.dim() == 3
+ z_pad = torch.cat(
+ (
+ z,
+ torch.zeros(
+ z.shape[0],
+ self.max_length - z.shape[1],
+ z.shape[2],
+ device=z.device,
+ ),
+ ),
+ 1,
+ )
+ return z_pad, z_pad[:, 0, ...]
+ return z
+
+ def encode_with_vision_transformer(self, img):
+ # if self.max_crops > 0:
+ # img = self.preprocess_by_cropping(img)
+ if img.dim() == 5:
+ assert self.max_crops == img.shape[1]
+ img = rearrange(img, "b n c h w -> (b n) c h w")
+ img = self.preprocess(img)
+ if not self.output_tokens:
+ assert not self.model.visual.output_tokens
+ x = self.model.visual(img)
+ tokens = None
+ else:
+ assert self.model.visual.output_tokens
+ x, tokens = self.model.visual(img)
+ if self.max_crops > 0:
+ x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
+ # drop out between 0 and all along the sequence axis
+ x = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate)
+ * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
+ )
+ * x
+ )
+ if tokens is not None:
+ tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
+ print(
+ f"You are running very experimental token-concat in {self.__class__.__name__}. "
+ f"Check what you are doing, and then remove this message."
+ )
+ if self.output_tokens:
+ return x, tokens
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPT5Encoder(AbstractEmbModel):
+ def __init__(
+ self,
+ clip_version="openai/clip-vit-large-patch14",
+ t5_version="google/t5-v1_1-xl",
+ device="cuda",
+ clip_max_length=77,
+ t5_max_length=77,
+ ):
+ super().__init__()
+ self.clip_encoder = FrozenCLIPEmbedder(
+ clip_version, device, max_length=clip_max_length
+ )
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+ print(
+ f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
+ )
+
+ def encode(self, text):
+ return self(text)
+
+ def forward(self, text):
+ clip_z = self.clip_encoder.encode(text)
+ t5_z = self.t5_encoder.encode(text)
+ return [clip_z, t5_z]
+
+
+class SpatialRescaler(nn.Module):
+ def __init__(
+ self,
+ n_stages=1,
+ method="bilinear",
+ multiplier=0.5,
+ in_channels=3,
+ out_channels=None,
+ bias=False,
+ wrap_video=False,
+ kernel_size=1,
+ remap_output=False,
+ ):
+ super().__init__()
+ self.n_stages = n_stages
+ assert self.n_stages >= 0
+ assert method in [
+ "nearest",
+ "linear",
+ "bilinear",
+ "trilinear",
+ "bicubic",
+ "area",
+ ]
+ self.multiplier = multiplier
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
+ self.remap_output = out_channels is not None or remap_output
+ if self.remap_output:
+ print(
+ f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
+ )
+ self.channel_mapper = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ bias=bias,
+ padding=kernel_size // 2,
+ )
+ self.wrap_video = wrap_video
+
+ def forward(self, x):
+ if self.wrap_video and x.ndim == 5:
+ B, C, T, H, W = x.shape
+ x = rearrange(x, "b c t h w -> b t c h w")
+ x = rearrange(x, "b t c h w -> (b t) c h w")
+
+ for stage in range(self.n_stages):
+ x = self.interpolator(x, scale_factor=self.multiplier)
+
+ if self.wrap_video:
+ x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C)
+ x = rearrange(x, "b t c h w -> b c t h w")
+ if self.remap_output:
+ x = self.channel_mapper(x)
+ return x
+
+ def encode(self, x):
+ return self(x)
+
+
+class LowScaleEncoder(nn.Module):
+ def __init__(
+ self,
+ model_config,
+ linear_start,
+ linear_end,
+ timesteps=1000,
+ max_noise_level=250,
+ output_size=64,
+ scale_factor=1.0,
+ ):
+ super().__init__()
+ self.max_noise_level = max_noise_level
+ self.model = instantiate_from_config(model_config)
+ self.augmentation_schedule = self.register_schedule(
+ timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
+ )
+ self.out_size = output_size
+ self.scale_factor = scale_factor
+
+ def register_schedule(
+ self,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ ):
+ betas = make_beta_schedule(
+ beta_schedule,
+ timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s,
+ )
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+ (timesteps,) = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert (
+ alphas_cumprod.shape[0] == self.num_timesteps
+ ), "alphas have to be defined for each timestep"
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer("betas", to_torch(betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+ )
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def forward(self, x):
+ z = self.model.encode(x)
+ if isinstance(z, DiagonalGaussianDistribution):
+ z = z.sample()
+ z = z * self.scale_factor
+ noise_level = torch.randint(
+ 0, self.max_noise_level, (x.shape[0],), device=x.device
+ ).long()
+ z = self.q_sample(z, noise_level)
+ if self.out_size is not None:
+ z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
+ # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
+ return z, noise_level
+
+ def decode(self, z):
+ z = z / self.scale_factor
+ return self.model.decode(z)
+
+
+class ConcatTimestepEmbedderND(AbstractEmbModel):
+ """embeds each dimension independently and concatenates them"""
+
+ def __init__(self, outdim):
+ super().__init__()
+ self.timestep = Timestep(outdim)
+ self.outdim = outdim
+
+ def forward(self, x):
+ if x.ndim == 1:
+ x = x[:, None]
+ assert len(x.shape) == 2
+ b, dims = x.shape[0], x.shape[1]
+ x = rearrange(x, "b d -> (b d)")
+ emb = self.timestep(x)
+ emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
+ return emb
+
+
+class GaussianEncoder(Encoder, AbstractEmbModel):
+ def __init__(
+ self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.posterior = DiagonalGaussianRegularizer()
+ self.weight = weight
+ self.flatten_output = flatten_output
+
+ def forward(self, x) -> Tuple[Dict, torch.Tensor]:
+ z = super().forward(x)
+ z, log = self.posterior(z)
+ log["loss"] = log["kl_loss"]
+ log["weight"] = self.weight
+ if self.flatten_output:
+ z = rearrange(z, "b c h w -> b (h w ) c")
+ return log, z
diff --git a/sgm/util.py b/sgm/util.py
new file mode 100644
index 000000000..06f48a882
--- /dev/null
+++ b/sgm/util.py
@@ -0,0 +1,231 @@
+import functools
+import importlib
+import os
+from functools import partial
+from inspect import isfunction
+
+import fsspec
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from safetensors.torch import load_file as load_safetensors
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def get_string_from_tuple(s):
+ try:
+ # Check if the string starts and ends with parentheses
+ if s[0] == "(" and s[-1] == ")":
+ # Convert the string to a tuple
+ t = eval(s)
+ # Check if the type of t is tuple
+ if type(t) == tuple:
+ return t[0]
+ else:
+ pass
+ except:
+ pass
+ return s
+
+
+def is_power_of_two(n):
+ """
+ chat.openai.com/chat
+ Return True if n is a power of 2, otherwise return False.
+
+ The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
+ The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
+ If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
+ Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
+
+ """
+ if n <= 0:
+ return False
+ return (n & (n - 1)) == 0
+
+
+def autocast(f, enabled=True):
+ def do_autocast(*args, **kwargs):
+ with torch.cuda.amp.autocast(
+ enabled=enabled,
+ dtype=torch.get_autocast_gpu_dtype(),
+ cache_enabled=torch.is_autocast_cache_enabled(),
+ ):
+ return f(*args, **kwargs)
+
+ return do_autocast
+
+
+def load_partial_from_config(config):
+ return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
+ nc = int(40 * (wh[0] / 256))
+ if isinstance(xc[bi], list):
+ text_seq = xc[bi][0]
+ else:
+ text_seq = xc[bi]
+ lines = "\n".join(
+ text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
+ )
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def partialclass(cls, *args, **kwargs):
+ class NewCls(cls):
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
+
+ return NewCls
+
+
+def make_path_absolute(path):
+ fs, p = fsspec.core.url_to_fs(path)
+ if fs.protocol == "file":
+ return os.path.abspath(p)
+ return path
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def isheatmap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+
+ return x.ndim == 2
+
+
+def isneighbors(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def expand_dims_like(x, y):
+ while x.dim() != y.dim():
+ x = x.unsqueeze(-1)
+ return x
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == "__is_first_stage__":
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False, invalidate_cache=True):
+ module, cls = string.rsplit(".", 1)
+ if invalidate_cache:
+ importlib.invalidate_caches()
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def append_zero(x):
+ return torch.cat([x, x.new_zeros([1])])
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
+ )
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def load_model_from_config(config, ckpt, verbose=True, freeze=True):
+ print(f"Loading model from {ckpt}")
+ if ckpt.endswith("ckpt"):
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ elif ckpt.endswith("safetensors"):
+ sd = load_safetensors(ckpt)
+ else:
+ raise NotImplementedError
+
+ model = instantiate_from_config(config.model)
+ sd = pl_sd["state_dict"]
+
+ m, u = model.load_state_dict(sd, strict=False)
+
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ if freeze:
+ for param in model.parameters():
+ param.requires_grad = False
+
+ model.eval()
+ return model