# Copyright (c) Intel Corporation
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the LICENSE file found in the
# root directory of this source tree.

"""
Stable Diffusion / LCM model definitions.

This module provides reusable model wrappers that can be used with any backend
(OpenVINO, XNNPACK, etc.) for exporting Latent Consistency Models.
"""
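
# Design note: diffusers model calls return ModelOutput objects, while export
# tracers generally expect plain tensors in and out. The wrappers below unwrap
# those outputs so each component exposes a tensor-only forward() signature.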

import logging
from typing import Any, Optional

import torch

try:
    from diffusers import DiffusionPipeline
except ImportError as e:
    raise ImportError(
        "Please install diffusers and transformers: pip install diffusers transformers"
    ) from e

logger = logging.getLogger(__name__)


class TextEncoderWrapper(torch.nn.Module):
    """Wrapper for the CLIP text encoder that returns last_hidden_state."""

    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

    def forward(self, input_ids):
        # Run the text encoder and keep only the last_hidden_state tensor
        output = self.text_encoder(input_ids, return_dict=True)
        return output.last_hidden_state

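# Note: Stable Diffusion's CLIP tokenizer pads/truncates prompts to a maximum
# of 77 tokens, so `input_ids` is typically shaped (batch_size, 77).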

class UNetWrapper(torch.nn.Module):
    """Wrapper for the UNet that returns the predicted sample tensor."""

    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, latents, timestep, encoder_hidden_states):
        # Run the UNet and keep only the predicted sample tensor
        output = self.unet(latents, timestep, encoder_hidden_states, return_dict=True)
        return output.sample

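# A single denoising step with this wrapper might look like the following
# (a sketch only; `unet_wrapper`, `prompt_embeds`, `t`, and `scheduler` are
# illustrative names, and a real loop would use the pipeline's scheduler to
# choose timesteps and update the latents):
#
#     noise_pred = unet_wrapper(latents, torch.tensor([t]), prompt_embeds)
#     latents = scheduler.step(noise_pred, t, latents).prev_sample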

class VAEDecoder(torch.nn.Module):
    """Wrapper for the VAE decoder with scaling and normalization."""

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, latents):
        # Undo the latent scaling applied by the pipeline during encoding
        latents = latents / self.vae.config.scaling_factor
        # Decode the latents into an image tensor in [-1, 1]
        image = self.vae.decode(latents).sample
        # Rescale from [-1, 1] to [0, 1]
        image = (image / 2 + 0.5).clamp(0, 1)
        return image

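# The decoded tensor is shaped (batch, 3, H, W) with values in [0, 1]. One
# common way to turn the first image into a PIL image (assuming Pillow is
# installed) is:
#
#     import numpy as np
#     from PIL import Image
#
#     arr = (image[0].permute(1, 2, 0).float().cpu().numpy() * 255).round()
#     Image.fromarray(arr.astype(np.uint8))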

class LCMModelLoader:
    """
    Backend-agnostic loader for Latent Consistency Model components.

    This class handles loading the LCM pipeline from HuggingFace and extracting
    individual components (text_encoder, unet, vae) as PyTorch modules ready
    for export to any backend.
    """

    def __init__(
        self,
        model_id: str = "SimianLuo/LCM_Dreamshaper_v7",
        dtype: torch.dtype = torch.float16,
    ):
        """
        Initialize the LCM model loader.

        Args:
            model_id: HuggingFace model ID for the LCM model
            dtype: Target dtype for the models (fp16 or fp32)
        """
        self.model_id = model_id
        self.dtype = dtype
        self.pipeline: Optional[DiffusionPipeline] = None
        self.text_encoder: Any = None
        self.unet: Any = None
        self.vae: Any = None
        self.tokenizer: Any = None

    def load_models(self) -> bool:
        """
        Load the LCM pipeline and extract components.

        Returns:
            True if successful, False otherwise
        """
        try:
            logger.info(f"Loading LCM pipeline: {self.model_id} (dtype: {self.dtype})")
            self.pipeline = DiffusionPipeline.from_pretrained(
                self.model_id, use_safetensors=True
            )

            # Extract individual components and convert to the desired dtype
            self.text_encoder = self.pipeline.text_encoder.to(dtype=self.dtype)
            self.unet = self.pipeline.unet.to(dtype=self.dtype)
            self.vae = self.pipeline.vae.to(dtype=self.dtype)
            self.tokenizer = self.pipeline.tokenizer

            # Set models to evaluation mode
            self.text_encoder.eval()
            self.unet.eval()
            self.vae.eval()

            logger.info("Successfully loaded all LCM model components")
            return True

        except Exception as e:
            # logger.exception records the message together with the traceback
            logger.exception(f"Failed to load models: {e}")
            return False

    def get_text_encoder_wrapper(self) -> TextEncoderWrapper:
        """Get the wrapped text encoder, ready for export."""
        if self.text_encoder is None:
            raise ValueError("Models not loaded. Call load_models() first.")
        return TextEncoderWrapper(self.text_encoder)

    def get_unet_wrapper(self) -> UNetWrapper:
        """Get the wrapped UNet, ready for export."""
        if self.unet is None:
            raise ValueError("Models not loaded. Call load_models() first.")
        return UNetWrapper(self.unet)

    def get_vae_decoder(self) -> VAEDecoder:
        """Get the wrapped VAE decoder, ready for export."""
        if self.vae is None:
            raise ValueError("Models not loaded. Call load_models() first.")
        return VAEDecoder(self.vae)

    def get_dummy_inputs(self):
        """
        Get dummy inputs for each model component.

        Returns:
            Dictionary with dummy inputs for text_encoder, unet, and vae_decoder
        """
        if self.unet is None:
            raise ValueError("Models not loaded. Call load_models() first.")

        # Text encoder dummy input: CLIP's fixed sequence length of 77 token ids
        text_encoder_input = torch.ones(1, 77, dtype=torch.long)

        # UNet dummy inputs: 64x64 latents correspond to 512x512 images with
        # the VAE's 8x spatial downsampling
        batch_size = 1
        latent_channels = 4
        latent_height = 64
        latent_width = 64
        text_embed_dim = self.unet.config.cross_attention_dim
        text_seq_len = 77

        unet_inputs = (
            torch.randn(
                batch_size,
                latent_channels,
                latent_height,
                latent_width,
                dtype=self.dtype,
            ),
            torch.tensor([981]),  # Example timestep within the scheduler's range
            torch.randn(batch_size, text_seq_len, text_embed_dim, dtype=self.dtype),
        )

        # VAE decoder dummy input: same shape as the UNet's latent input
        vae_input = torch.randn(1, 4, 64, 64, dtype=self.dtype)

        return {
            "text_encoder": (text_encoder_input,),
            "unet": unet_inputs,
            "vae_decoder": (vae_input,),
        }
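

if __name__ == "__main__":
    # Minimal smoke test (a sketch; downloads the checkpoint on first use).
    # fp32 is assumed here so the check also runs on CPU-only machines.
    logging.basicConfig(level=logging.INFO)
    loader = LCMModelLoader(dtype=torch.float32)
    if loader.load_models():
        dummy = loader.get_dummy_inputs()
        with torch.no_grad():
            embeds = loader.get_text_encoder_wrapper()(*dummy["text_encoder"])
            noise_pred = loader.get_unet_wrapper()(*dummy["unet"])
            image = loader.get_vae_decoder()(*dummy["vae_decoder"])
        logger.info(
            f"Output shapes: text_encoder={tuple(embeds.shape)}, "
            f"unet={tuple(noise_pred.shape)}, vae_decoder={tuple(image.shape)}"
        )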