Skip to content
7 changes: 2 additions & 5 deletions src/diffusers/models/autoencoders/autoencoder_asym_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder


class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with
KL loss for encoding images into latents and decoding latent representations into images.
Expand Down Expand Up @@ -107,9 +107,6 @@ def __init__(
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)

self.use_slicing = False
self.use_tiling = False

self.register_to_config(block_out_channels=up_block_out_channels)
self.register_to_config(force_upcast=False)

Expand Down
25 changes: 2 additions & 23 deletions src/diffusers/models/autoencoders/autoencoder_dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm, get_normalization
from ..transformers.sana_transformer import GLUMBConv
from .vae import DecoderOutput, EncoderOutput
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput


class ResBlock(nn.Module):
Expand Down Expand Up @@ -378,7 +378,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in
[SANA](https://huggingface.co/papers/2410.10629).
Expand Down Expand Up @@ -536,27 +536,6 @@ def enable_tiling(
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

def disable_tiling(self) -> None:
r"""
Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape

Expand Down
33 changes: 2 additions & 31 deletions src/diffusers/models/autoencoders/autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
)
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder


class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

Expand Down Expand Up @@ -138,35 +138,6 @@ def __init__(
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25

def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling

def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)

def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
Expand Down
32 changes: 2 additions & 30 deletions src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ..modeling_utils import ModelMixin
from ..resnet import ResnetBlock2D
from ..upsampling import Upsample2D
from .vae import AutoencoderMixin


class AllegroTemporalConvLayer(nn.Module):
Expand Down Expand Up @@ -673,7 +674,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
return sample


class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[Allegro](https://github.com/rhymes-ai/Allegro).
Expand Down Expand Up @@ -795,35 +796,6 @@ def __init__(
sample_size - self.tile_overlap_w,
)

def enable_tiling(self) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = True

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: torch.Tensor) -> torch.Tensor:
# TODO(aryan)
# if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
Expand Down
25 changes: 2 additions & 23 deletions src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..upsampling import CogVideoXUpsample3D
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -955,7 +955,7 @@ def forward(
return hidden_states, new_conv_cache


class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[CogVideoX](https://github.com/THUDM/CogVideo).
Expand Down Expand Up @@ -1124,27 +1124,6 @@ def enable_tiling(
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape

Expand Down
25 changes: 2 additions & 23 deletions src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, IdentityDistribution
from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution


logger = get_logger(__name__)
Expand Down Expand Up @@ -875,7 +875,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).

Expand Down Expand Up @@ -1031,27 +1031,6 @@ def enable_tiling(
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
enc = self.quant_conv(x)
Expand Down
25 changes: 2 additions & 23 deletions src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..attention_processor import Attention
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -624,7 +624,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
Expand Down Expand Up @@ -763,27 +763,6 @@ def enable_tiling(
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape

Expand Down
25 changes: 2 additions & 23 deletions src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


class LTXVideoCausalConv3d(nn.Module):
Expand Down Expand Up @@ -1034,7 +1034,7 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No
return hidden_states


class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[LTX](https://huggingface.co/Lightricks/LTX-Video).
Expand Down Expand Up @@ -1219,27 +1219,6 @@ def enable_tiling(
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape

Expand Down
Loading
Loading