diff --git a/python/hidet/apps/diffusion/__init__.py b/python/hidet/apps/diffusion/__init__.py new file mode 100644 index 000000000..9d09a6723 --- /dev/null +++ b/python/hidet/apps/diffusion/__init__.py @@ -0,0 +1,2 @@ +from .app import DiffusionApp +from .builder import create_stable_diffusion diff --git a/python/hidet/apps/diffusion/app.py b/python/hidet/apps/diffusion/app.py new file mode 100644 index 000000000..cdf21e595 --- /dev/null +++ b/python/hidet/apps/diffusion/app.py @@ -0,0 +1,44 @@ +import torch +from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler + +from hidet.graph.tensor import from_torch, full +from hidet.runtime.compiled_app import CompiledApp + + +class DiffusionApp: + def __init__(self, compiled_app: CompiledApp, hf_pipeline: StableDiffusionPipeline, height: int, width: int): + super().__init__() + assert height % 8 == 0 and width % 8 == 0, "Height and width must be multiples of 8" + self.height = height + self.width = width + self.compiled_app: CompiledApp = compiled_app + self.hf_pipeline: StableDiffusionPipeline = hf_pipeline + + self.hf_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.hf_pipeline.scheduler.config) + self.hf_pipeline = self.hf_pipeline.to("cuda") + + def _unet_forward(sample: torch.Tensor, timesteps: torch.Tensor, encoder_hidden_states: torch.Tensor, **kwargs): + h_sample = from_torch(sample) + h_timesteps = full([sample.shape[0]], timesteps.item(), dtype="int64", device="cuda") + h_encoder_hidden_states = from_torch(encoder_hidden_states) + + down_outs = self.compiled_app.graphs["unet_down"](h_sample, h_timesteps, h_encoder_hidden_states) + h_sample = down_outs[0] + h_emb = down_outs[1] + h_down_block_residual_samples = down_outs[2:] + + h_sample = self.compiled_app.graphs["unet_mid"](h_sample, h_emb, h_encoder_hidden_states) + + h_sample = self.compiled_app.graphs["unet_up"]( + h_sample, h_emb, h_encoder_hidden_states, *h_down_block_residual_samples + ) + + return (h_sample.torch(),) 
+ + self.hf_pipeline.unet.forward = _unet_forward + + def generate_image(self, prompt: str, negative_prompt: str): + + return self.hf_pipeline( + prompt=prompt, negative_prompt=negative_prompt, height=self.height, width=self.width + ).images diff --git a/python/hidet/apps/diffusion/builder.py b/python/hidet/apps/diffusion/builder.py new file mode 100644 index 000000000..80ca26aa4 --- /dev/null +++ b/python/hidet/apps/diffusion/builder.py @@ -0,0 +1,142 @@ +from typing import Optional, Tuple +from diffusers import StableDiffusionPipeline + +import hidet +from hidet.apps.diffusion.app import DiffusionApp +from hidet.apps.diffusion.modeling.pretrained import PretrainedModelForDiffusion +from hidet.apps.hf import load_diffusion_pipeline +from hidet.graph.flow_graph import trace_from +from hidet.graph import FlowGraph +from hidet.graph.tensor import symbol, Tensor +from hidet.runtime.compiled_app import create_compiled_app + + +def _build_unet_down_graph( + model: PretrainedModelForDiffusion, + dtype: str = "float32", + device: str = "cuda", + batch_size: int = 2, + num_channels_latents: int = 4, + height: int = 96, + width: int = 96, + embed_length: int = 77, + embed_hidden_dim: int = 1024, + kernel_search_space: int = 2, +): + latent_model_input: Tensor = symbol([batch_size, num_channels_latents, height, width], dtype=dtype, device=device) + timesteps: Tensor = symbol([batch_size], dtype="int64", device=device) + prompt_embeds: Tensor = symbol([batch_size, embed_length, embed_hidden_dim], dtype=dtype, device=device) + + inputs = (latent_model_input, timesteps, prompt_embeds) + outputs = sample, emb, down_block_residual_samples = model.forward_down(*inputs) + graph: FlowGraph = trace_from([sample, emb, *down_block_residual_samples], list(inputs)) + + graph = hidet.graph.optimize(graph) + + compiled_graph = graph.build(space=kernel_search_space) + + return compiled_graph, inputs, outputs + + +def _build_unet_mid_graph( + model: PretrainedModelForDiffusion, + sample: 
Tensor, + emb: Tensor, + encoder_hidden_states: Tensor, + kernel_search_space: int = 2, +): + sample, emb, encoder_hidden_states = tuple( + symbol(list(x.shape), dtype=x.dtype, device=x.device) for x in (sample, emb, encoder_hidden_states) + ) + + output = model.forward_mid(sample, emb, encoder_hidden_states) + + graph: FlowGraph = trace_from(output, [sample, emb, encoder_hidden_states]) + + graph = hidet.graph.optimize(graph) + + compiled_graph = graph.build(space=kernel_search_space) + + return compiled_graph, output + + +def _build_unet_up_graph( + model: PretrainedModelForDiffusion, + sample: Tensor, + emb: Tensor, + encoder_hidden_states: Tensor, + down_block_residuals: Tuple[Tensor, ...], + kernel_search_space: int = 2, +): + sample, emb, encoder_hidden_states = tuple( + symbol(list(x.shape), dtype=x.dtype, device=x.device) for x in (sample, emb, encoder_hidden_states) + ) + down_block_residuals = tuple(symbol(list(x.shape), dtype=x.dtype, device=x.device) for x in down_block_residuals) + output = model.forward_up(sample, emb, encoder_hidden_states, down_block_residuals) + + graph: FlowGraph = trace_from(output, [sample, emb, encoder_hidden_states, *down_block_residuals]) + + graph = hidet.graph.optimize(graph) + + compiled_graph = graph.build(space=kernel_search_space) + + return compiled_graph, output + + +def create_stable_diffusion( + name: str, + revision: Optional[str] = None, + dtype: str = "float32", + device: str = "cuda", + batch_size: int = 1, + height: int = 768, + width: int = 768, + kernel_search_space: int = 2, +): + hf_pipeline: StableDiffusionPipeline = load_diffusion_pipeline(name=name, revision=revision, device=device) + # create the hidet model from the pipeline loaded above (avoids loading the huggingface weights twice) + model: PretrainedModelForDiffusion = PretrainedModelForDiffusion.create_pretrained_model( + name, revision=revision, hf_pipeline=hf_pipeline, device=device, dtype=dtype + ) + + unet_down_graph, inputs, outputs = _build_unet_down_graph( + model, + dtype=dtype, + 
device=device, + batch_size=batch_size * 2, # double size for prompt/negative prompt + num_channels_latents=model.config["in_channels"], + height=height // model.vae_scale_factor, + width=width // model.vae_scale_factor, + embed_length=model.embed_max_length, + embed_hidden_dim=model.embed_hidden_dim, + kernel_search_space=kernel_search_space, + ) + + _, _, prompt_embeds = inputs + sample, emb, down_block_residual_samples = outputs + + unet_mid_graph, sample = _build_unet_mid_graph( + model, sample=sample, emb=emb, encoder_hidden_states=prompt_embeds, kernel_search_space=kernel_search_space + ) + + unet_up_graph, sample = _build_unet_up_graph( + model, + sample=sample, + emb=emb, + encoder_hidden_states=prompt_embeds, + down_block_residuals=down_block_residual_samples, + kernel_search_space=kernel_search_space, + ) + + return DiffusionApp( + compiled_app=create_compiled_app( + graphs={"unet_down": unet_down_graph, "unet_mid": unet_mid_graph, "unet_up": unet_up_graph}, + modules={}, + tensors={}, + attributes={}, + name=name, + ), + hf_pipeline=hf_pipeline, + height=height, + width=width, + ) diff --git a/python/hidet/apps/diffusion/modeling/__init__.py b/python/hidet/apps/diffusion/modeling/__init__.py new file mode 100644 index 000000000..ac24614cc --- /dev/null +++ b/python/hidet/apps/diffusion/modeling/__init__.py @@ -0,0 +1,2 @@ +from .stable_diffusion import * +from .pretrained import PretrainedModelForDiffusion diff --git a/python/hidet/apps/diffusion/modeling/pretrained.py b/python/hidet/apps/diffusion/modeling/pretrained.py new file mode 100644 index 000000000..19149b9b8 --- /dev/null +++ b/python/hidet/apps/diffusion/modeling/pretrained.py @@ -0,0 +1,60 @@ +from typing import Optional +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline + +from hidet.apps.hf import load_diffusion_pipeline +from hidet.apps.pretrained import PretrainedModel + + +class PretrainedModelForDiffusion(PretrainedModel): + @classmethod 
+ def create_pretrained_model( + cls, + name: str, + revision: Optional[str] = None, + hf_pipeline: Optional[StableDiffusionPipeline] = None, + dtype: Optional[str] = None, + device: str = "cuda", + ): + # load the pretrained huggingface model + # note: diffusers pipeline is more similar to a model than transformers pipeline + if hf_pipeline is None: + hf_pipeline: StableDiffusionPipeline = load_diffusion_pipeline(name=name, revision=revision, device=device) + + pipeline_config = hf_pipeline.config + + torch_unet = hf_pipeline.unet + pretrained_unet_class = cls.load_module(pipeline_config["unet"][1]) + + hidet_unet = pretrained_unet_class( + **dict(torch_unet.config), + vae_scale_factor=hf_pipeline.vae_scale_factor, + embed_max_length=hf_pipeline.text_encoder.config.max_position_embeddings, + embed_hidden_dim=hf_pipeline.text_encoder.config.hidden_size + ) + + hidet_unet.to(dtype=dtype, device=device) + + cls.copy_weights(torch_unet, hidet_unet) + + return hidet_unet + + @property + def embed_max_length(self): + raise NotImplementedError() + + @property + def embed_hidden_dim(self): + raise NotImplementedError() + + @property + def vae_scale_factor(self): + raise NotImplementedError() + + def forward_down(self, *args, **kwargs): + raise NotImplementedError() + + def forward_mid(self, *args, **kwargs): + raise NotImplementedError() + + def forward_up(self, *args, **kwargs): + raise NotImplementedError() diff --git a/python/hidet/apps/diffusion/modeling/stable_diffusion/unet.py b/python/hidet/apps/diffusion/modeling/stable_diffusion/unet.py index a7f0a297e..d6bf72c10 100644 --- a/python/hidet/apps/diffusion/modeling/stable_diffusion/unet.py +++ b/python/hidet/apps/diffusion/modeling/stable_diffusion/unet.py @@ -1,6 +1,6 @@ from typing import Tuple from hidet import nn -from hidet.apps.diffusion.modeling.pretrained import PretrainedModelForText2Image +from hidet.apps.diffusion.modeling.pretrained import PretrainedModelForDiffusion from 
hidet.apps.diffusion.modeling.stable_diffusion.timestep import TimestepEmbedding, Timesteps from hidet.apps.diffusion.modeling.stable_diffusion.unet_blocks import ( CrossAttnDownBlock2D, @@ -9,7 +9,6 @@ MidBlock2DCrossAttn, UpBlock2D, ) -from hidet.apps.modeling_outputs import UNet2DConditionOutput from hidet.apps.pretrained import PretrainedModel from hidet.apps.registry import RegistryEntry from hidet.graph.tensor import Tensor @@ -23,7 +22,7 @@ ) -class UNet2DConditionModel(PretrainedModelForText2Image): +class UNet2DConditionModel(PretrainedModelForDiffusion): def __init__(self, **kwargs): super().__init__(kwargs) self.conv_in = nn.Conv2d( @@ -192,6 +191,18 @@ def __init__(self, **kwargs): bias=True, ) + @property + def embed_max_length(self): + return self.config["embed_max_length"] + + @property + def embed_hidden_dim(self): + return self.config["embed_hidden_dim"] + + @property + def vae_scale_factor(self): + return self.config["vae_scale_factor"] + def get_down_block(self, down_block_type: str, **kwargs): if down_block_type == "CrossAttnDownBlock2D": return CrossAttnDownBlock2D(**{**self.config, **kwargs}) # type: ignore @@ -276,9 +287,7 @@ def forward_up( return sample - def forward( - self, sample: Tensor, timesteps: Tensor, encoder_hidden_states: Tensor, **kwargs - ) -> UNet2DConditionOutput: + def forward(self, sample: Tensor, timesteps: Tensor, encoder_hidden_states: Tensor, **kwargs) -> Tensor: timesteps = broadcast(timesteps, shape=(sample.shape[0],)) sample, emb, down_block_residual_samples = self.forward_down(sample, timesteps, encoder_hidden_states) @@ -287,4 +296,4 @@ def forward( sample = self.forward_up(sample, emb, encoder_hidden_states, down_block_residual_samples) - return UNet2DConditionOutput(last_hidden_state=sample, hidden_states=[sample]) + return sample diff --git a/python/hidet/apps/hf.py b/python/hidet/apps/hf.py index 381afdd4d..8d3d7cd8a 100644 --- a/python/hidet/apps/hf.py +++ b/python/hidet/apps/hf.py @@ -1,10 +1,24 @@ from 
typing import Optional +import torch from transformers import AutoConfig, PretrainedConfig +from diffusers import StableDiffusionPipeline import hidet +def _get_hf_auth_token(): + return hidet.option.get_option('auth_tokens.for_huggingface') + + def load_pretrained_config(model: str, revision: Optional[str] = None) -> PretrainedConfig: - huggingface_token = hidet.option.get_option('auth_tokens.for_huggingface') + huggingface_token = _get_hf_auth_token() return AutoConfig.from_pretrained(model, revision=revision, token=huggingface_token) + + +def load_diffusion_pipeline(name: str, revision: Optional[str] = None, device: str = "cuda") -> StableDiffusionPipeline: + huggingface_token = _get_hf_auth_token() + with torch.device(device): + return StableDiffusionPipeline.from_pretrained( + pretrained_model_name_or_path=name, torch_dtype=torch.float32, revision=revision, token=huggingface_token + ) diff --git a/python/hidet/apps/image_classification/modeling/pretrained.py b/python/hidet/apps/image_classification/modeling/pretrained.py index b4f8f1f3f..a9d139335 100644 --- a/python/hidet/apps/image_classification/modeling/pretrained.py +++ b/python/hidet/apps/image_classification/modeling/pretrained.py @@ -15,7 +15,7 @@ def create_pretrained_model( cls, config: PretrainedConfig, revision: Optional[str] = None, dtype: Optional[str] = None, device: str = "cuda" ): # dynamically load model subclass - pretrained_model_class = cls.load_module(config) + pretrained_model_class = cls.load_module(config.architectures[0]) # load the pretrained huggingface model into cpu with torch.device("cuda"): # reduce the time to load the model diff --git a/python/hidet/apps/pretrained.py b/python/hidet/apps/pretrained.py index 5a0734188..53c6e49dc 100644 --- a/python/hidet/apps/pretrained.py +++ b/python/hidet/apps/pretrained.py @@ -1,4 +1,5 @@ from typing import Generic, List, Set +import logging import torch from transformers import PretrainedConfig @@ -7,6 +8,12 @@ from hidet.graph import 
Tensor, nn from hidet.graph.nn.module import R from hidet.graph.tensor import from_torch +from hidet.utils import prod + + +logger = logging.Logger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler()) class PretrainedModel(nn.Module[R], Registry, Generic[R]): @@ -33,8 +40,15 @@ def copy_weights(cls, torch_model: torch.nn.Module, hidet_model: nn.Module): ) src = from_torch(tensor).to(member.dtype, member.device) + if src.shape != member.shape: - raise ValueError(f"Parameter {name} shape mismatch, hidet: {member.shape}, torch: {src.shape}") + if prod(src.shape) == prod(member.shape): + # same number of elements, only the layout differs: reshape instead of rejecting the weight + logger.warning("Attempting to reshape parameter %s from %s to %s.", name, src.shape, member.shape) + src = src.reshape(member.shape) + else: + raise ValueError(f"Parameter {name} shape mismatch, hidet: {member.shape}, torch: {src.shape}") + found_tensors.append(member) member.copy_(src) diff --git a/python/hidet/apps/registry.py b/python/hidet/apps/registry.py index 57d5ad8d2..3953aa679 100644 --- a/python/hidet/apps/registry.py +++ b/python/hidet/apps/registry.py @@ -2,8 +2,6 @@ from dataclasses import astuple, dataclass from typing import Dict -from transformers import PretrainedConfig - @dataclass class RegistryEntry: @@ -46,13 +44,7 @@ class Registry: module_registry: Dict[str, RegistryEntry] = {} @classmethod - def load_module(cls, config: PretrainedConfig): - architectures = getattr(config, "architectures") - if not architectures: - raise ValueError(f"Config {config.name_or_path} has no architecture.") - - # assume only 1 architecture available for now - architecture = architectures[0] + def load_module(cls, architecture: str): if architecture not in cls.module_registry: raise KeyError( f"No model class with architecture {architecture} found." 
diff --git a/python/hidet/graph/nn/__init__.py b/python/hidet/graph/nn/__init__.py index def2ba30e..1702ad295 100644 --- a/python/hidet/graph/nn/__init__.py +++ b/python/hidet/graph/nn/__init__.py @@ -19,6 +19,6 @@ from .activations import Relu, Gelu, Geglu, Tanh from .convolutions import Conv2d from .linear import Linear, LinearTransposed -from .norms import BatchNorm2d, LayerNorm +from .norms import BatchNorm2d, LayerNorm, GroupNorm from .poolings import MaxPool2d, AvgPool2d, AdaptiveAvgPool2d from .transforms import Embedding diff --git a/python/hidet/option.py b/python/hidet/option.py index fe7e85287..6b93f6ac7 100644 --- a/python/hidet/option.py +++ b/python/hidet/option.py @@ -289,7 +289,7 @@ def register_hidet_options(): name='auth_tokens.for_huggingface', type_hint='str', default_value='', - description='The auth-tokens to use for accessing private huggingface models. ', + description='The auth token to use for accessing private huggingface models.', ) config_file_path = os.path.join(os.path.expanduser('~'), '.config', 'hidet') diff --git a/requirements-dev.txt b/requirements-dev.txt index afcfabba3..960a133db 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -15,6 +15,7 @@ pylint==2.13.9 torch torchvision datasets +diffusers==0.27.1 transformers==4.37 sentencepiece sacremoses diff --git a/tests/apps/diffusion/test_diffusion_builder.py b/tests/apps/diffusion/test_diffusion_builder.py new file mode 100644 index 000000000..26f020a8c --- /dev/null +++ b/tests/apps/diffusion/test_diffusion_builder.py @@ -0,0 +1,20 @@ +from hidet.apps.diffusion.app import DiffusionApp +from hidet.apps.diffusion.builder import create_stable_diffusion +import pytest + + +def test_create_stable_diffusion(): + diffusion_app: DiffusionApp = create_stable_diffusion( + "stabilityai/stable-diffusion-2-1", kernel_search_space=0, height=512, width=512 + ) + res = diffusion_app.generate_image( + "Software engineer writing code at desk with laptop, soft glow, detailed image.", 
"blurry, multiple fingers" + ) + + assert res[0].height == 512 + assert res[0].width == 512 + assert res[0].mode == "RGB" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/apps/image_classification/test_builder.py b/tests/apps/image_classification/test_image_classifier_builder.py similarity index 100% rename from tests/apps/image_classification/test_builder.py rename to tests/apps/image_classification/test_image_classifier_builder.py diff --git a/tests/apps/test_pretrained.py b/tests/apps/test_pretrained.py index 547f304d1..c06cc0799 100644 --- a/tests/apps/test_pretrained.py +++ b/tests/apps/test_pretrained.py @@ -2,7 +2,6 @@ import torch from hidet.apps import PretrainedModel, hf from hidet.apps.image_classification.modeling.resnet.modeling import ResNetForImageClassification -from hidet.graph.tensor import empty from hidet.option import get_option from transformers import AutoModelForImageClassification, PretrainedConfig, ResNetConfig diff --git a/tests/apps/test_registry.py b/tests/apps/test_registry.py index be24ddfa9..b32820bd0 100644 --- a/tests/apps/test_registry.py +++ b/tests/apps/test_registry.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize('model_name', ["microsoft/resnet-50"]) def test_load_module(model_name: str): config: PretrainedConfig = hf.load_pretrained_config(model_name) - assert Registry.load_module(config) is ResNetForImageClassification + assert Registry.load_module(config.architectures[0]) is ResNetForImageClassification if __name__ == '__main__':