diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 8811fca5f5a2..d1805ff605d8 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -81,6 +81,8 @@
title: Overview
- local: hybrid_inference/vae_decode
title: VAE Decode
+ - local: hybrid_inference/vae_encode
+ title: VAE Encode
- local: hybrid_inference/api_reference
title: API Reference
title: Hybrid Inference
diff --git a/docs/source/en/hybrid_inference/api_reference.md b/docs/source/en/hybrid_inference/api_reference.md
index aa0a5e5ae58f..865aaba5ebb6 100644
--- a/docs/source/en/hybrid_inference/api_reference.md
+++ b/docs/source/en/hybrid_inference/api_reference.md
@@ -3,3 +3,7 @@
## Remote Decode
[[autodoc]] utils.remote_utils.remote_decode
+
+## Remote Encode
+
+[[autodoc]] utils.remote_utils.remote_encode
diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md
index 9bbe245901df..b44393c77cbd 100644
--- a/docs/source/en/hybrid_inference/overview.md
+++ b/docs/source/en/hybrid_inference/overview.md
@@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
## Available Models
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
-* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
+* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
---
@@ -46,9 +46,15 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
+## Changelog
+
+- March 10, 2025: Added VAE encode
+- March 2, 2025: Initial release with VAE decode
+
## Contents
-The documentation is organized into two sections:
+The documentation is organized into three sections:
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
+* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
* **API Reference** Dive into task-specific settings and parameters.
diff --git a/docs/source/en/hybrid_inference/vae_encode.md b/docs/source/en/hybrid_inference/vae_encode.md
new file mode 100644
index 000000000000..dd285fa25c03
--- /dev/null
+++ b/docs/source/en/hybrid_inference/vae_encode.md
@@ -0,0 +1,183 @@
+# Getting Started: VAE Encode with Hybrid Inference
+
+VAE encode is used for training, image-to-image, and image-to-video workflows: it turns images or videos into latent representations.
+
+## Memory
+
+These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
+
+For the majority of these GPUs, the memory usage % means that other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding must be used, which increases the time taken and can impact quality (see the local tiled-encoding sketch after the tables below).
+
+SD v1.5
+
+| GPU                           | Resolution   |   Time (seconds) |   Memory (%) |   Tiled Time (seconds) |   Tiled Memory (%) |
+|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
+| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
+| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
+| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
+| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
+| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
+| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
+| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
+| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
+| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
+| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
+| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
+| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
+
+SDXL
+
+| GPU                           | Resolution   |   Time (seconds) |   Memory (%) |   Tiled Time (seconds) |   Tiled Memory (%) |
+|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
+| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
+| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
+| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
+| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
+| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
+| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
+| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
+| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
+| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
+| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
+| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
+| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
+
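+As a point of comparison, tiled encoding is the usual way to reduce peak VRAM locally. A minimal sketch, assuming a CUDA device and the SD v1 VAE [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) (the checkpoint behind the Stable Diffusion v1 endpoint below):
+
+```python
+import torch
+from diffusers import AutoencoderKL
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.utils import load_image
+
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
+vae.enable_tiling()  # lowers peak VRAM at the cost of extra time and possible tile seams
+
+processor = VaeImageProcessor()
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
+pixel_values = processor.preprocess(image).to(device="cuda", dtype=torch.float16)
+
+with torch.no_grad():
+    latent = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor
+```
+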
+## Available VAEs
+
+| | **Endpoint** | **Model** |
+|:-:|:-----------:|:--------:|
+| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
+| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
+| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
+
+
+> [!TIP]
+> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
+
+
+## Code
+
+> [!TIP]
+> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
+
+
+A helper function simplifies interacting with Hybrid Inference.
+
+```python
+from diffusers.utils.remote_utils import remote_encode
+```
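+
+The endpoint URLs used in the examples below are also defined as constants in `diffusers.utils.constants` (added alongside `remote_encode`): `ENCODE_ENDPOINT_SD_V1`, `ENCODE_ENDPOINT_SD_XL`, `ENCODE_ENDPOINT_FLUX`, plus matching `DECODE_ENDPOINT_*` values. They are internal constants and may change, but they avoid hard-coding URLs, e.g.:
+
+```python
+from diffusers.utils.constants import DECODE_ENDPOINT_FLUX, ENCODE_ENDPOINT_FLUX
+```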
+
+### Basic example
+
+Let's encode an image, then decode it to demonstrate.
+
+```python
+from diffusers.utils import load_image
+from diffusers.utils.remote_utils import remote_decode, remote_encode
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
+
+# Flux VAE encode endpoint (see the table above)
+latent = remote_encode(
+    endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
+    image=image,
+    scaling_factor=0.3611,
+    shift_factor=0.1159,
+)
+
+decoded = remote_decode(
+ endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=latent,
+ scaling_factor=0.3611,
+ shift_factor=0.1159,
+)
+```
+
+### Generation
+
+Now let's look at a generation example: we'll encode the image, generate, and then remotely decode too!
+
+```python
+import torch
+from diffusers import StableDiffusionImg2ImgPipeline
+from diffusers.utils import load_image
+from diffusers.utils.remote_utils import remote_decode, remote_encode
+
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ vae=None,
+).to("cuda")
+
+init_image = load_image(
+ "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+)
+init_image = init_image.resize((768, 512))
+
+init_latent = remote_encode(
+ endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
+ image=init_image,
+ scaling_factor=0.18215,
+)
+
+prompt = "A fantasy landscape, trending on artstation"
+latent = pipe(
+ prompt=prompt,
+ image=init_latent,
+ strength=0.75,
+ output_type="latent",
+).images
+
+image = remote_decode(
+ endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=latent,
+ scaling_factor=0.18215,
+)
+image.save("fantasy_landscape.jpg")
+```
+
+## Integrations
+
+* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with built-in support for Hybrid Inference.
+* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 3f88f347710f..fa12318f4714 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -56,3 +56,14 @@
if USE_PEFT_BACKEND and _CHECK_PEFT:
dep_version_check("peft")
+
+
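+# Remote VAE decode endpoints for Hybrid Inference.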
+DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
+DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
+DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
+DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
+
+
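+# Remote VAE encode endpoints for Hybrid Inference.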
+ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
+ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
+ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
diff --git a/src/diffusers/utils/remote_utils.py b/src/diffusers/utils/remote_utils.py
index 12bcc94af74f..fbce33d97f54 100644
--- a/src/diffusers/utils/remote_utils.py
+++ b/src/diffusers/utils/remote_utils.py
@@ -55,7 +55,7 @@ def detect_image_type(data: bytes) -> str:
return "unknown"
-def check_inputs(
+def check_inputs_decode(
endpoint: str,
tensor: "torch.Tensor",
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
@@ -89,7 +89,7 @@ def check_inputs(
)
-def postprocess(
+def postprocess_decode(
response: requests.Response,
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
output_type: Literal["mp4", "pil", "pt"] = "pil",
@@ -142,7 +142,7 @@ def postprocess(
return output
-def prepare(
+def prepare_decode(
tensor: "torch.Tensor",
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
do_scaling: bool = True,
@@ -293,7 +293,7 @@ def remote_decode(
standard_warn=False,
)
output_tensor_type = "binary"
- check_inputs(
+ check_inputs_decode(
endpoint,
tensor,
processor,
@@ -309,7 +309,7 @@ def remote_decode(
height,
width,
)
- kwargs = prepare(
+ kwargs = prepare_decode(
tensor=tensor,
processor=processor,
do_scaling=do_scaling,
@@ -324,7 +324,7 @@ def remote_decode(
response = requests.post(endpoint, **kwargs)
if not response.ok:
raise RuntimeError(response.json())
- output = postprocess(
+ output = postprocess_decode(
response=response,
processor=processor,
output_type=output_type,
@@ -332,3 +332,94 @@ def remote_decode(
partial_postprocess=partial_postprocess,
)
return output
+
+
+def check_inputs_encode(
+ endpoint: str,
+ image: Union["torch.Tensor", Image.Image],
+ scaling_factor: Optional[float] = None,
+ shift_factor: Optional[float] = None,
+):
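+    # No validation is currently performed on encode inputs; placeholder for future checks.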
+ pass
+
+
+def postprocess_encode(
+ response: requests.Response,
+):
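+    # The endpoint returns raw tensor bytes; shape and dtype are read back from the response headers.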
+ output_tensor = response.content
+ parameters = response.headers
+ shape = json.loads(parameters["shape"])
+ dtype = parameters["dtype"]
+ torch_dtype = DTYPE_MAP[dtype]
+ output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
+ return output_tensor
+
+
+def prepare_encode(
+ image: Union["torch.Tensor", Image.Image],
+ scaling_factor: Optional[float] = None,
+ shift_factor: Optional[float] = None,
+):
+ headers = {}
+ parameters = {}
+ if scaling_factor is not None:
+ parameters["scaling_factor"] = scaling_factor
+ if shift_factor is not None:
+ parameters["shift_factor"] = shift_factor
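+    # Tensors are sent as raw bytes with their shape/dtype passed as request parameters;
+    # PIL images are serialized to PNG bytes.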
+ if isinstance(image, torch.Tensor):
+ data = safetensors.torch._tobytes(image, "tensor")
+ parameters["shape"] = list(image.shape)
+ parameters["dtype"] = str(image.dtype).split(".")[-1]
+ else:
+ buffer = io.BytesIO()
+ image.save(buffer, format="PNG")
+ data = buffer.getvalue()
+ return {"data": data, "params": parameters, "headers": headers}
+
+
+def remote_encode(
+ endpoint: str,
+ image: Union["torch.Tensor", Image.Image],
+ scaling_factor: Optional[float] = None,
+ shift_factor: Optional[float] = None,
+) -> "torch.Tensor":
+ """
+    Hugging Face Hybrid Inference that allows running VAE encode remotely.
+
+ Args:
+ endpoint (`str`):
+            Endpoint for Remote Encode.
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ Image to be encoded.
+ scaling_factor (`float`, *optional*):
+            Scaling is applied when passed, e.g. `latents * self.vae.config.scaling_factor`.
+            - SD v1: 0.18215
+            - SD XL: 0.13025
+            - Flux: 0.3611
+            If `None`, no scaling is applied to the returned latent.
+ shift_factor (`float`, *optional*):
+ Shift is applied when passed e.g. `latents - self.vae.config.shift_factor`.
+ - Flux: 0.1159
+            If `None`, no shift is applied to the returned latent.
+
+ Returns:
+ output (`torch.Tensor`).
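+
+    Example:
+        A minimal usage sketch (SD v1 encode endpoint and scaling factor from the Hybrid Inference docs):
+
+        ```python
+        from diffusers.utils import load_image
+        from diffusers.utils.remote_utils import remote_encode
+
+        init_image = load_image(
+            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"
+        )
+        latent = remote_encode(
+            endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
+            image=init_image,
+            scaling_factor=0.18215,
+        )
+        ```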
+ """
+ check_inputs_encode(
+ endpoint,
+ image,
+ scaling_factor,
+ shift_factor,
+ )
+ kwargs = prepare_encode(
+ image=image,
+ scaling_factor=scaling_factor,
+ shift_factor=shift_factor,
+ )
+ response = requests.post(endpoint, **kwargs)
+ if not response.ok:
+ raise RuntimeError(response.json())
+ output = postprocess_encode(
+ response=response,
+ )
+ return output
diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py
index 11f9c24d16f6..cec96e729a48 100644
--- a/tests/remote/test_remote_decode.py
+++ b/tests/remote/test_remote_decode.py
@@ -21,7 +21,15 @@
import torch
from diffusers.image_processor import VaeImageProcessor
-from diffusers.utils.remote_utils import remote_decode
+from diffusers.utils.constants import (
+ DECODE_ENDPOINT_FLUX,
+ DECODE_ENDPOINT_HUNYUAN_VIDEO,
+ DECODE_ENDPOINT_SD_V1,
+ DECODE_ENDPOINT_SD_XL,
+)
+from diffusers.utils.remote_utils import (
+ remote_decode,
+)
from diffusers.utils.testing_utils import (
enable_full_determinism,
slow,
@@ -33,11 +41,6 @@
enable_full_determinism()
-ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
-ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
-ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
-ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
-
class RemoteAutoencoderKLMixin:
shape: Tuple[int, ...] = None
@@ -350,7 +353,7 @@ class RemoteAutoencoderKLSDv1Tests(
512,
512,
)
- endpoint = ENDPOINT_SD_V1
+ endpoint = DECODE_ENDPOINT_SD_V1
dtype = torch.float16
scaling_factor = 0.18215
shift_factor = None
@@ -374,7 +377,7 @@ class RemoteAutoencoderKLSDXLTests(
1024,
1024,
)
- endpoint = ENDPOINT_SD_XL
+ endpoint = DECODE_ENDPOINT_SD_XL
dtype = torch.float16
scaling_factor = 0.13025
shift_factor = None
@@ -398,7 +401,7 @@ class RemoteAutoencoderKLFluxTests(
1024,
1024,
)
- endpoint = ENDPOINT_FLUX
+ endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159
@@ -425,7 +428,7 @@ class RemoteAutoencoderKLFluxPackedTests(
)
height = 1024
width = 1024
- endpoint = ENDPOINT_FLUX
+ endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159
@@ -453,7 +456,7 @@ class RemoteAutoencoderKLHunyuanVideoTests(
320,
512,
)
- endpoint = ENDPOINT_HUNYUAN_VIDEO
+ endpoint = DECODE_ENDPOINT_HUNYUAN_VIDEO
dtype = torch.float16
scaling_factor = 0.476986
processor_cls = VideoProcessor
@@ -504,7 +507,7 @@ class RemoteAutoencoderKLSDv1SlowTests(
RemoteAutoencoderKLSlowTestMixin,
unittest.TestCase,
):
- endpoint = ENDPOINT_SD_V1
+ endpoint = DECODE_ENDPOINT_SD_V1
dtype = torch.float16
scaling_factor = 0.18215
shift_factor = None
@@ -515,7 +518,7 @@ class RemoteAutoencoderKLSDXLSlowTests(
RemoteAutoencoderKLSlowTestMixin,
unittest.TestCase,
):
- endpoint = ENDPOINT_SD_XL
+ endpoint = DECODE_ENDPOINT_SD_XL
dtype = torch.float16
scaling_factor = 0.13025
shift_factor = None
@@ -527,7 +530,7 @@ class RemoteAutoencoderKLFluxSlowTests(
unittest.TestCase,
):
channels = 16
- endpoint = ENDPOINT_FLUX
+ endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159
diff --git a/tests/remote/test_remote_encode.py b/tests/remote/test_remote_encode.py
new file mode 100644
index 000000000000..62ed97ee8f49
--- /dev/null
+++ b/tests/remote/test_remote_encode.py
@@ -0,0 +1,224 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import PIL.Image
+import torch
+
+from diffusers.utils import load_image
+from diffusers.utils.constants import (
+ DECODE_ENDPOINT_FLUX,
+ DECODE_ENDPOINT_SD_V1,
+ DECODE_ENDPOINT_SD_XL,
+ ENCODE_ENDPOINT_FLUX,
+ ENCODE_ENDPOINT_SD_V1,
+ ENCODE_ENDPOINT_SD_XL,
+)
+from diffusers.utils.remote_utils import (
+ remote_decode,
+ remote_encode,
+)
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ slow,
+)
+
+
+enable_full_determinism()
+
+IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"
+
+
+class RemoteAutoencoderKLEncodeMixin:
+ channels: int = None
+ endpoint: str = None
+ decode_endpoint: str = None
+ dtype: torch.dtype = None
+ scaling_factor: float = None
+ shift_factor: float = None
+ image: PIL.Image.Image = None
+
+ def get_dummy_inputs(self):
+ if self.image is None:
+ self.image = load_image(IMAGE)
+ inputs = {
+ "endpoint": self.endpoint,
+ "image": self.image,
+ "scaling_factor": self.scaling_factor,
+ "shift_factor": self.shift_factor,
+ }
+ return inputs
+
+ def test_image_input(self):
+ inputs = self.get_dummy_inputs()
+ height, width = inputs["image"].height, inputs["image"].width
+ output = remote_encode(**inputs)
+ self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
+ decoded = remote_decode(
+ tensor=output,
+ endpoint=self.decode_endpoint,
+ scaling_factor=self.scaling_factor,
+ shift_factor=self.shift_factor,
+ image_format="png",
+ )
+ self.assertEqual(decoded.height, height)
+ self.assertEqual(decoded.width, width)
+ # image_slice = torch.from_numpy(np.array(inputs["image"])[0, -3:, -3:].flatten())
+ # decoded_slice = torch.from_numpy(np.array(decoded)[0, -3:, -3:].flatten())
+ # TODO: how to test this? encode->decode is lossy. expected slice of encoded latent?
+
+
+class RemoteAutoencoderKLSDv1Tests(
+ RemoteAutoencoderKLEncodeMixin,
+ unittest.TestCase,
+):
+ channels = 4
+ endpoint = ENCODE_ENDPOINT_SD_V1
+ decode_endpoint = DECODE_ENDPOINT_SD_V1
+ dtype = torch.float16
+ scaling_factor = 0.18215
+ shift_factor = None
+
+
+class RemoteAutoencoderKLSDXLTests(
+ RemoteAutoencoderKLEncodeMixin,
+ unittest.TestCase,
+):
+ channels = 4
+ endpoint = ENCODE_ENDPOINT_SD_XL
+ decode_endpoint = DECODE_ENDPOINT_SD_XL
+ dtype = torch.float16
+ scaling_factor = 0.13025
+ shift_factor = None
+
+
+class RemoteAutoencoderKLFluxTests(
+ RemoteAutoencoderKLEncodeMixin,
+ unittest.TestCase,
+):
+ channels = 16
+ endpoint = ENCODE_ENDPOINT_FLUX
+ decode_endpoint = DECODE_ENDPOINT_FLUX
+ dtype = torch.bfloat16
+ scaling_factor = 0.3611
+ shift_factor = 0.1159
+
+
+class RemoteAutoencoderKLEncodeSlowTestMixin:
+ channels: int = 4
+ endpoint: str = None
+ decode_endpoint: str = None
+ dtype: torch.dtype = None
+ scaling_factor: float = None
+ shift_factor: float = None
+ image: PIL.Image.Image = None
+
+ def get_dummy_inputs(self):
+ if self.image is None:
+ self.image = load_image(IMAGE)
+ inputs = {
+ "endpoint": self.endpoint,
+ "image": self.image,
+ "scaling_factor": self.scaling_factor,
+ "shift_factor": self.shift_factor,
+ }
+ return inputs
+
+ def test_multi_res(self):
+ inputs = self.get_dummy_inputs()
+ for height in {
+ 320,
+ 512,
+ 640,
+ 704,
+ 896,
+ 1024,
+ 1208,
+ 1384,
+ 1536,
+ 1608,
+ 1864,
+ 2048,
+ }:
+ for width in {
+ 320,
+ 512,
+ 640,
+ 704,
+ 896,
+ 1024,
+ 1208,
+ 1384,
+ 1536,
+ 1608,
+ 1864,
+ 2048,
+ }:
+ inputs["image"] = inputs["image"].resize(
+ (
+ width,
+ height,
+ )
+ )
+ output = remote_encode(**inputs)
+ self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
+ decoded = remote_decode(
+ tensor=output,
+ endpoint=self.decode_endpoint,
+ scaling_factor=self.scaling_factor,
+ shift_factor=self.shift_factor,
+ image_format="png",
+ )
+ self.assertEqual(decoded.height, height)
+ self.assertEqual(decoded.width, width)
+ decoded.save(f"test_multi_res_{height}_{width}.png")
+
+
+@slow
+class RemoteAutoencoderKLSDv1SlowTests(
+ RemoteAutoencoderKLEncodeSlowTestMixin,
+ unittest.TestCase,
+):
+ endpoint = ENCODE_ENDPOINT_SD_V1
+ decode_endpoint = DECODE_ENDPOINT_SD_V1
+ dtype = torch.float16
+ scaling_factor = 0.18215
+ shift_factor = None
+
+
+@slow
+class RemoteAutoencoderKLSDXLSlowTests(
+ RemoteAutoencoderKLEncodeSlowTestMixin,
+ unittest.TestCase,
+):
+ endpoint = ENCODE_ENDPOINT_SD_XL
+ decode_endpoint = DECODE_ENDPOINT_SD_XL
+ dtype = torch.float16
+ scaling_factor = 0.13025
+ shift_factor = None
+
+
+@slow
+class RemoteAutoencoderKLFluxSlowTests(
+ RemoteAutoencoderKLEncodeSlowTestMixin,
+ unittest.TestCase,
+):
+ channels = 16
+ endpoint = ENCODE_ENDPOINT_FLUX
+ decode_endpoint = DECODE_ENDPOINT_FLUX
+ dtype = torch.bfloat16
+ scaling_factor = 0.3611
+ shift_factor = 0.1159