diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 8811fca5f5a2..d1805ff605d8 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -81,6 +81,8 @@
         title: Overview
       - local: hybrid_inference/vae_decode
         title: VAE Decode
+      - local: hybrid_inference/vae_encode
+        title: VAE Encode
       - local: hybrid_inference/api_reference
         title: API Reference
       title: Hybrid Inference
diff --git a/docs/source/en/hybrid_inference/api_reference.md b/docs/source/en/hybrid_inference/api_reference.md
index aa0a5e5ae58f..865aaba5ebb6 100644
--- a/docs/source/en/hybrid_inference/api_reference.md
+++ b/docs/source/en/hybrid_inference/api_reference.md
@@ -3,3 +3,7 @@
 ## Remote Decode
 
 [[autodoc]] utils.remote_utils.remote_decode
+
+## Remote Encode
+
+[[autodoc]] utils.remote_utils.remote_encode
diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md
index 9bbe245901df..b44393c77cbd 100644
--- a/docs/source/en/hybrid_inference/overview.md
+++ b/docs/source/en/hybrid_inference/overview.md
@@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
 ## Available Models
 
 * **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
-* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
+* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
 * **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
 
 ---
@@ -46,9 +46,15 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
 * **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
 * **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
 
+## Changelog
+
+- March 10 2025: Added VAE encode
+- March 2 2025: Initial release with VAE decoding
+
 ## Contents
 
-The documentation is organized into two sections:
+The documentation is organized into three sections:
 
 * **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
+* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
 * **API Reference** Dive into task-specific settings and parameters.
diff --git a/docs/source/en/hybrid_inference/vae_encode.md b/docs/source/en/hybrid_inference/vae_encode.md
new file mode 100644
index 000000000000..dd285fa25c03
--- /dev/null
+++ b/docs/source/en/hybrid_inference/vae_encode.md
@@ -0,0 +1,183 @@
+# Getting Started: VAE Encode with Hybrid Inference
+
+VAE encode is used for training, image-to-image, and image-to-video, turning images or videos into latent representations.
+
+## Memory
+
+These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
+
+For the majority of these GPUs, the memory usage % dictates that other models (text encoders, UNet/Transformer) must be offloaded, or that tiled encoding must be used, which increases the time taken and impacts quality. For comparison, a sketch of local tiled encoding follows the tables below.
+
+<details><summary>SD v1.5</summary>
+
+| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (seconds) | Tiled Memory (%) |
+|:---|:---|---:|---:|---:|---:|
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
+| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
+| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
+| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
+| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
+| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
+| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
+| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
+| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
+| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
+| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
+| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
+| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
+
+</details>
+ +
+
+<details><summary>SDXL</summary>
+
+| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (seconds) | Tiled Memory (%) |
+|:---|:---|---:|---:|---:|---:|
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
+| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
+| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
+| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
+| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
+| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
+| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
+| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
+| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
+| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
+| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
+| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
+| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
+
+</details>
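+
+As a point of comparison, this is a minimal sketch of running the same encode locally with tiled encoding enabled. The checkpoint choice and preprocessing here are illustrative assumptions, not part of the remote API:
+
+```python
+import torch
+
+from diffusers import AutoencoderKL
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.utils import load_image
+
+# illustrative local fallback: tiled encoding trades extra time for lower peak VRAM
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
+vae.enable_tiling()  # encode in tiles instead of one full-resolution pass
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
+pixels = VaeImageProcessor().preprocess(image).to(device="cuda", dtype=torch.float16)
+
+with torch.no_grad():
+    latent = vae.encode(pixels).latent_dist.sample() * vae.config.scaling_factor
+```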
+
+## Available VAEs
+
+|   | **Endpoint** | **Model** |
+|:-:|:-:|:-:|
+| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
+| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
+| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
+
+> [!TIP]
+> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
+
+## Code
+
+> [!TIP]
+> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
+
+A helper method simplifies interacting with Hybrid Inference.
+
+```python
+from diffusers.utils.remote_utils import remote_encode
+```
+
+### Basic example
+
+Let's encode an image, then decode it to demonstrate.
+ +
+ +
+<details><summary>Code</summary>
+
+```python
+from diffusers.utils import load_image
+from diffusers.utils.remote_utils import remote_decode, remote_encode
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
+
+latent = remote_encode(
+    endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
+    image=image,
+    scaling_factor=0.3611,
+    shift_factor=0.1159,
+)
+
+decoded = remote_decode(
+    endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
+    tensor=latent,
+    scaling_factor=0.3611,
+    shift_factor=0.1159,
+)
+decoded.save("decoded.png")
+```
+
+</details>
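+
+`remote_encode` also accepts a preprocessed `torch.Tensor` in place of a PIL image. A minimal sketch, assuming the SD v1 endpoint from the table above and `VaeImageProcessor` preprocessing (any `(1, 3, H, W)` tensor in `[-1, 1]` is handled the same way):
+
+```python
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.utils import load_image
+from diffusers.utils.remote_utils import remote_encode
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
+# convert to a (1, 3, H, W) tensor in [-1, 1], the layout the VAE expects
+tensor = VaeImageProcessor().preprocess(image)
+
+latent = remote_encode(
+    endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
+    image=tensor,
+    scaling_factor=0.18215,
+)
+```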
+ +
+ +
+
+### Generation
+
+Now let's look at a generation example: we'll encode the image, generate, then remotely decode too!
+<details><summary>Code</summary>
+
+```python
+import torch
+from diffusers import StableDiffusionImg2ImgPipeline
+from diffusers.utils import load_image
+from diffusers.utils.remote_utils import remote_decode, remote_encode
+
+# load the pipeline without a VAE: encode and decode both run remotely
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+    "stable-diffusion-v1-5/stable-diffusion-v1-5",
+    torch_dtype=torch.float16,
+    variant="fp16",
+    vae=None,
+).to("cuda")
+
+init_image = load_image(
+    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+)
+init_image = init_image.resize((768, 512))
+
+init_latent = remote_encode(
+    endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
+    image=init_image,
+    scaling_factor=0.18215,
+)
+
+prompt = "A fantasy landscape, trending on artstation"
+latent = pipe(
+    prompt=prompt,
+    image=init_latent,
+    strength=0.75,
+    output_type="latent",  # keep the latents so they can be decoded remotely
+).images
+
+image = remote_decode(
+    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
+    tensor=latent,
+    scaling_factor=0.18215,
+)
+image.save("fantasy_landscape.jpg")
+```
+
+</details>
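+
+Because VAE encode is also used for training, one natural pattern is precomputing latents once and caching them. A hypothetical sketch; the dataset and output filename are illustrative assumptions:
+
+```python
+from datasets import load_dataset  # assumed available; any PIL-image source works
+from safetensors.torch import save_file
+
+from diffusers.utils.remote_utils import remote_encode
+
+dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")
+
+latents = {}
+for i, example in enumerate(dataset.select(range(8))):  # small subset for illustration
+    latents[f"latent_{i}"] = remote_encode(
+        endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
+        image=example["image"].convert("RGB").resize((512, 512)),
+        scaling_factor=0.18215,
+    )
+
+save_file(latents, "cached_latents.safetensors")
+```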
+ +
+ +
+
+## Integrations
+
+* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
+* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 3f88f347710f..fa12318f4714 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -56,3 +56,14 @@
 
 if USE_PEFT_BACKEND and _CHECK_PEFT:
     dep_version_check("peft")
+
+
+DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
+DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
+DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
+DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
+
+
+ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
+ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
+ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
diff --git a/src/diffusers/utils/remote_utils.py b/src/diffusers/utils/remote_utils.py
index 12bcc94af74f..fbce33d97f54 100644
--- a/src/diffusers/utils/remote_utils.py
+++ b/src/diffusers/utils/remote_utils.py
@@ -55,7 +55,7 @@ def detect_image_type(data: bytes) -> str:
     return "unknown"
 
 
-def check_inputs(
+def check_inputs_decode(
     endpoint: str,
     tensor: "torch.Tensor",
     processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
@@ -89,7 +89,7 @@
     )
 
 
-def postprocess(
+def postprocess_decode(
     response: requests.Response,
     processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
     output_type: Literal["mp4", "pil", "pt"] = "pil",
@@ -142,7 +142,7 @@
     return output
 
 
-def prepare(
+def prepare_decode(
     tensor: "torch.Tensor",
     processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
     do_scaling: bool = True,
@@ -293,7 +293,7 @@
         standard_warn=False,
     )
     output_tensor_type = "binary"
-    check_inputs(
+    check_inputs_decode(
         endpoint,
         tensor,
         processor,
@@ -309,7 +309,7 @@
         height,
         width,
     )
-    kwargs = prepare(
+    kwargs = prepare_decode(
         tensor=tensor,
         processor=processor,
         do_scaling=do_scaling,
@@ -324,7 +324,7 @@
     response = requests.post(endpoint, **kwargs)
     if not response.ok:
         raise RuntimeError(response.json())
-    output = postprocess(
+    output = postprocess_decode(
         response=response,
         processor=processor,
         output_type=output_type,
@@ -332,3 +332,94 @@
         partial_postprocess=partial_postprocess,
     )
     return output
+
+
+def check_inputs_encode(
+    endpoint: str,
+    image: Union["torch.Tensor", Image.Image],
+    scaling_factor: Optional[float] = None,
+    shift_factor: Optional[float] = None,
+):
+    pass
+
+
+def postprocess_encode(
+    response: requests.Response,
+):
+    output_tensor = response.content
+    parameters = response.headers
+    shape = json.loads(parameters["shape"])
+    dtype = parameters["dtype"]
+    torch_dtype = DTYPE_MAP[dtype]
+    output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
+    return output_tensor
+
+
+def prepare_encode(
+    image: Union["torch.Tensor", Image.Image],
+    scaling_factor: Optional[float] = None,
+    shift_factor: Optional[float] = None,
+):
+    headers = {}
+    parameters = {}
+    if scaling_factor is not None:
+        parameters["scaling_factor"] = scaling_factor
+    if shift_factor is not None:
+        parameters["shift_factor"] = shift_factor
+    if isinstance(image, torch.Tensor):
+        data = safetensors.torch._tobytes(image, "tensor")
+        parameters["shape"] = list(image.shape)
+        parameters["dtype"] = str(image.dtype).split(".")[-1]
+    else:
+        buffer = io.BytesIO()
+        image.save(buffer, format="PNG")
+        data = buffer.getvalue()
+    return {"data": data, "params": parameters, "headers": headers}
+
+
+def remote_encode(
+    endpoint: str,
+    image: Union["torch.Tensor", Image.Image],
+    scaling_factor: Optional[float] = None,
+    shift_factor: Optional[float] = None,
+) -> "torch.Tensor":
+    """
+    Hugging Face Hybrid Inference that allows running VAE encode remotely.
+
+    Args:
+        endpoint (`str`):
+            Endpoint for Remote Encode.
+        image (`torch.Tensor` or `PIL.Image.Image`):
+            Image to be encoded.
+        scaling_factor (`float`, *optional*):
+            Scaling is applied when passed, e.g. `latents * self.vae.config.scaling_factor`.
+            - SD v1: 0.18215
+            - SD XL: 0.13025
+            - Flux: 0.3611
+            If `None`, input must be passed with scaling applied.
+        shift_factor (`float`, *optional*):
+            Shift is applied when passed, e.g. `latents - self.vae.config.shift_factor`.
+            - Flux: 0.1159
+            If `None`, input must be passed with shift applied.
+
+    Returns:
+        output (`torch.Tensor`).
+    """
+    check_inputs_encode(
+        endpoint,
+        image,
+        scaling_factor,
+        shift_factor,
+    )
+    kwargs = prepare_encode(
+        image=image,
+        scaling_factor=scaling_factor,
+        shift_factor=shift_factor,
+    )
+    response = requests.post(endpoint, **kwargs)
+    if not response.ok:
+        raise RuntimeError(response.json())
+    output = postprocess_encode(
+        response=response,
+    )
+    return output
diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py
index 11f9c24d16f6..cec96e729a48 100644
--- a/tests/remote/test_remote_decode.py
+++ b/tests/remote/test_remote_decode.py
@@ -21,7 +21,15 @@
 import torch
 
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.utils.remote_utils import remote_decode
+from diffusers.utils.constants import (
+    DECODE_ENDPOINT_FLUX,
+    DECODE_ENDPOINT_HUNYUAN_VIDEO,
+    DECODE_ENDPOINT_SD_V1,
+    DECODE_ENDPOINT_SD_XL,
+)
+from diffusers.utils.remote_utils import (
+    remote_decode,
+)
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     slow,
@@ -33,11 +41,6 @@
 
 enable_full_determinism()
 
-ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
-ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
-ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
-ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
-
 
 class RemoteAutoencoderKLMixin:
     shape: Tuple[int, ...] = None
@@ -350,7 +353,7 @@ class RemoteAutoencoderKLSDv1Tests(
         512,
         512,
     )
-    endpoint = ENDPOINT_SD_V1
+    endpoint = DECODE_ENDPOINT_SD_V1
     dtype = torch.float16
     scaling_factor = 0.18215
     shift_factor = None
@@ -374,7 +377,7 @@ class RemoteAutoencoderKLSDXLTests(
         1024,
         1024,
     )
-    endpoint = ENDPOINT_SD_XL
+    endpoint = DECODE_ENDPOINT_SD_XL
     dtype = torch.float16
     scaling_factor = 0.13025
     shift_factor = None
@@ -398,7 +401,7 @@ class RemoteAutoencoderKLFluxTests(
         1024,
         1024,
    )
-    endpoint = ENDPOINT_FLUX
+    endpoint = DECODE_ENDPOINT_FLUX
     dtype = torch.bfloat16
     scaling_factor = 0.3611
     shift_factor = 0.1159
@@ -425,7 +428,7 @@ class RemoteAutoencoderKLFluxPackedTests(
     )
     height = 1024
     width = 1024
-    endpoint = ENDPOINT_FLUX
+    endpoint = DECODE_ENDPOINT_FLUX
     dtype = torch.bfloat16
     scaling_factor = 0.3611
     shift_factor = 0.1159
@@ -453,7 +456,7 @@ class RemoteAutoencoderKLHunyuanVideoTests(
         320,
         512,
     )
-    endpoint = ENDPOINT_HUNYUAN_VIDEO
+    endpoint = DECODE_ENDPOINT_HUNYUAN_VIDEO
     dtype = torch.float16
     scaling_factor = 0.476986
     processor_cls = VideoProcessor
@@ -504,7 +507,7 @@ class RemoteAutoencoderKLSDv1SlowTests(
     RemoteAutoencoderKLSlowTestMixin,
     unittest.TestCase,
 ):
-    endpoint = ENDPOINT_SD_V1
+    endpoint = DECODE_ENDPOINT_SD_V1
     dtype = torch.float16
     scaling_factor = 0.18215
     shift_factor = None
@@ -515,7 +518,7 @@ class RemoteAutoencoderKLSDXLSlowTests(
     RemoteAutoencoderKLSlowTestMixin,
     unittest.TestCase,
 ):
-    endpoint = ENDPOINT_SD_XL
+    endpoint = DECODE_ENDPOINT_SD_XL
     dtype = torch.float16
     scaling_factor = 0.13025
     shift_factor = None
@@ -527,7 +530,7 @@ class RemoteAutoencoderKLFluxSlowTests(
     unittest.TestCase,
 ):
     channels = 16
-    endpoint = ENDPOINT_FLUX
+    endpoint = DECODE_ENDPOINT_FLUX
     dtype = torch.bfloat16
     scaling_factor = 0.3611
     shift_factor = 0.1159
diff --git a/tests/remote/test_remote_encode.py b/tests/remote/test_remote_encode.py
new file mode 100644
index 000000000000..62ed97ee8f49
--- /dev/null
+++ b/tests/remote/test_remote_encode.py
@@ -0,0 +1,224 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import PIL.Image
+import torch
+
+from diffusers.utils import load_image
+from diffusers.utils.constants import (
+    DECODE_ENDPOINT_FLUX,
+    DECODE_ENDPOINT_SD_V1,
+    DECODE_ENDPOINT_SD_XL,
+    ENCODE_ENDPOINT_FLUX,
+    ENCODE_ENDPOINT_SD_V1,
+    ENCODE_ENDPOINT_SD_XL,
+)
+from diffusers.utils.remote_utils import (
+    remote_decode,
+    remote_encode,
+)
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    slow,
+)
+
+
+enable_full_determinism()
+
+IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"
+
+
+class RemoteAutoencoderKLEncodeMixin:
+    channels: int = None
+    endpoint: str = None
+    decode_endpoint: str = None
+    dtype: torch.dtype = None
+    scaling_factor: float = None
+    shift_factor: float = None
+    image: PIL.Image.Image = None
+
+    def get_dummy_inputs(self):
+        if self.image is None:
+            self.image = load_image(IMAGE)
+        inputs = {
+            "endpoint": self.endpoint,
+            "image": self.image,
+            "scaling_factor": self.scaling_factor,
+            "shift_factor": self.shift_factor,
+        }
+        return inputs
+
+    def test_image_input(self):
+        inputs = self.get_dummy_inputs()
+        height, width = inputs["image"].height, inputs["image"].width
+        output = remote_encode(**inputs)
+        self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
+        decoded = remote_decode(
+            tensor=output,
+            endpoint=self.decode_endpoint,
+            scaling_factor=self.scaling_factor,
+            shift_factor=self.shift_factor,
+            image_format="png",
+        )
+        self.assertEqual(decoded.height, height)
+        self.assertEqual(decoded.width, width)
+        # image_slice = torch.from_numpy(np.array(inputs["image"])[0, -3:, -3:].flatten())
+        # decoded_slice = torch.from_numpy(np.array(decoded)[0, -3:, -3:].flatten())
+        # TODO: how to test this? encode->decode is lossy. expected slice of encoded latent?
+
+
+class RemoteAutoencoderKLSDv1Tests(
+    RemoteAutoencoderKLEncodeMixin,
+    unittest.TestCase,
+):
+    channels = 4
+    endpoint = ENCODE_ENDPOINT_SD_V1
+    decode_endpoint = DECODE_ENDPOINT_SD_V1
+    dtype = torch.float16
+    scaling_factor = 0.18215
+    shift_factor = None
+
+
+class RemoteAutoencoderKLSDXLTests(
+    RemoteAutoencoderKLEncodeMixin,
+    unittest.TestCase,
+):
+    channels = 4
+    endpoint = ENCODE_ENDPOINT_SD_XL
+    decode_endpoint = DECODE_ENDPOINT_SD_XL
+    dtype = torch.float16
+    scaling_factor = 0.13025
+    shift_factor = None
+
+
+class RemoteAutoencoderKLFluxTests(
+    RemoteAutoencoderKLEncodeMixin,
+    unittest.TestCase,
+):
+    channels = 16
+    endpoint = ENCODE_ENDPOINT_FLUX
+    decode_endpoint = DECODE_ENDPOINT_FLUX
+    dtype = torch.bfloat16
+    scaling_factor = 0.3611
+    shift_factor = 0.1159
+
+
+class RemoteAutoencoderKLEncodeSlowTestMixin:
+    channels: int = 4
+    endpoint: str = None
+    decode_endpoint: str = None
+    dtype: torch.dtype = None
+    scaling_factor: float = None
+    shift_factor: float = None
+    image: PIL.Image.Image = None
+
+    def get_dummy_inputs(self):
+        if self.image is None:
+            self.image = load_image(IMAGE)
+        inputs = {
+            "endpoint": self.endpoint,
+            "image": self.image,
+            "scaling_factor": self.scaling_factor,
+            "shift_factor": self.shift_factor,
+        }
+        return inputs
+
+    def test_multi_res(self):
+        inputs = self.get_dummy_inputs()
+        for height in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}:
+            for width in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}:
+                inputs["image"] = inputs["image"].resize((width, height))
+                output = remote_encode(**inputs)
+                self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
+                decoded = remote_decode(
+                    tensor=output,
+                    endpoint=self.decode_endpoint,
+                    scaling_factor=self.scaling_factor,
+                    shift_factor=self.shift_factor,
+                    image_format="png",
+                )
+                self.assertEqual(decoded.height, height)
+                self.assertEqual(decoded.width, width)
+                decoded.save(f"test_multi_res_{height}_{width}.png")
+
+
+@slow
+class RemoteAutoencoderKLSDv1SlowTests(
+    RemoteAutoencoderKLEncodeSlowTestMixin,
+    unittest.TestCase,
+):
+    endpoint = ENCODE_ENDPOINT_SD_V1
+    decode_endpoint = DECODE_ENDPOINT_SD_V1
+    dtype = torch.float16
+    scaling_factor = 0.18215
+    shift_factor = None
+
+
+@slow
+class RemoteAutoencoderKLSDXLSlowTests(
+    RemoteAutoencoderKLEncodeSlowTestMixin,
+    unittest.TestCase,
+):
+    endpoint = ENCODE_ENDPOINT_SD_XL
+    decode_endpoint = DECODE_ENDPOINT_SD_XL
+    dtype = torch.float16
+    scaling_factor = 0.13025
+    shift_factor = None
+
+
+@slow
+class RemoteAutoencoderKLFluxSlowTests(
+    RemoteAutoencoderKLEncodeSlowTestMixin,
+    unittest.TestCase,
+):
+    channels = 16
+    endpoint = ENCODE_ENDPOINT_FLUX
+    decode_endpoint = DECODE_ENDPOINT_FLUX
+    dtype = torch.bfloat16
+    scaling_factor = 0.3611
+    shift_factor = 0.1159