diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index fe00d8c078ff..49cefcd8a142 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -749,6 +749,16 @@ def __init__( self.tile_sample_stride_height = 192 self.tile_sample_stride_width = 192 + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + def enable_tiling( self, tile_sample_min_height: Optional[int] = None, @@ -801,18 +811,12 @@ def disable_slicing(self) -> None: self.use_slicing = False def clear_cache(self): - def _count_conv3d(model): - count = 0 - for m in model.modules(): - if isinstance(m, WanCausalConv3d): - count += 1 - return count - - self._conv_num = _count_conv3d(self.decoder) + # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call + self._conv_num = self._cached_conv_counts["decoder"] self._conv_idx = [0] self._feat_map = [None] * self._conv_num # cache encode - self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_num = self._cached_conv_counts["encoder"] self._enc_conv_idx = [0] self._enc_feat_map = [None] * self._enc_conv_num