Stable Diffusion App Infra (hidet-org#103)
Infrastructure for the compiled Stable Diffusion app.

Towards hidet-org#57
KTong821 authored Apr 11, 2024
1 parent a6ce95a commit d53acab
Showing 17 changed files with 321 additions and 23 deletions.
2 changes: 2 additions & 0 deletions python/hidet/apps/diffusion/__init__.py
@@ -0,0 +1,2 @@
from .app import DiffusionApp
from .builder import create_stable_diffusion
44 changes: 44 additions & 0 deletions python/hidet/apps/diffusion/app.py
@@ -0,0 +1,44 @@
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

from hidet.graph.tensor import from_torch, full
from hidet.runtime.compiled_app import CompiledApp


class DiffusionApp:
    def __init__(self, compiled_app: CompiledApp, hf_pipeline: StableDiffusionPipeline, height: int, width: int):
        super().__init__()
        assert height % 8 == 0 and width % 8 == 0, "Height and width must be multiples of 8"
        self.height = height
        self.width = width
        self.compiled_app: CompiledApp = compiled_app
        self.hf_pipeline: StableDiffusionPipeline = hf_pipeline

        self.hf_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.hf_pipeline.scheduler.config)
        self.hf_pipeline = self.hf_pipeline.to("cuda")

        def _unet_forward(sample: torch.Tensor, timesteps: torch.Tensor, encoder_hidden_states: torch.Tensor, **kwargs):
            h_sample = from_torch(sample)
            h_timesteps = full([sample.shape[0]], timesteps.item(), dtype="int64", device="cuda")
            h_encoder_hidden_states = from_torch(encoder_hidden_states)

            down_outs = self.compiled_app.graphs["unet_down"](h_sample, h_timesteps, h_encoder_hidden_states)
            h_sample = down_outs[0]
            h_emb = down_outs[1]
            h_down_block_residual_samples = down_outs[2:]

            h_sample = self.compiled_app.graphs["unet_mid"](h_sample, h_emb, h_encoder_hidden_states)

            h_sample = self.compiled_app.graphs["unet_up"](
                h_sample, h_emb, h_encoder_hidden_states, *h_down_block_residual_samples
            )

            return (h_sample.torch(),)

        self.hf_pipeline.unet.forward = _unet_forward

    def generate_image(self, prompt: str, negative_prompt: str):

        return self.hf_pipeline(
            prompt=prompt, negative_prompt=negative_prompt, height=self.height, width=self.width
        ).images
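A minimal usage sketch of the app above; the checkpoint name, prompts, and output path are illustrative assumptions, not part of this change:

from hidet.apps.diffusion import create_stable_diffusion

# build the compiled app once (example checkpoint), then reuse it for many prompts
app = create_stable_diffusion(name="stabilityai/stable-diffusion-2-1", height=768, width=768)
images = app.generate_image(
    prompt="a watercolor painting of a lighthouse at dawn",
    negative_prompt="blurry, low quality",
)
images[0].save("lighthouse.png")  # generate_image returns the pipeline's list of PIL images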
142 changes: 142 additions & 0 deletions python/hidet/apps/diffusion/builder.py
@@ -0,0 +1,142 @@
from typing import Optional, Tuple
from diffusers import StableDiffusionPipeline

import hidet
from hidet.apps.diffusion.app import DiffusionApp
from hidet.apps.diffusion.modeling.pretrained import PretrainedModelForDiffusion
from hidet.apps.hf import load_diffusion_pipeline
from hidet.graph.flow_graph import trace_from
from hidet.graph import FlowGraph
from hidet.graph.tensor import symbol, Tensor
from hidet.runtime.compiled_app import create_compiled_app


def _build_unet_down_graph(
    model: PretrainedModelForDiffusion,
    dtype: str = "float32",
    device: str = "cuda",
    batch_size: int = 2,
    num_channels_latents: int = 4,
    height: int = 96,
    width: int = 96,
    embed_length: int = 77,
    embed_hidden_dim: int = 1024,
    kernel_search_space: int = 2,
):
    latent_model_input: Tensor = symbol([batch_size, num_channels_latents, height, width], dtype=dtype, device=device)
    timesteps: Tensor = symbol([batch_size], dtype="int64", device=device)
    prompt_embeds: Tensor = symbol([batch_size, embed_length, embed_hidden_dim], dtype=dtype, device=device)

    inputs = (latent_model_input, timesteps, prompt_embeds)
    outputs = sample, emb, down_block_residual_samples = model.forward_down(*inputs)
    graph: FlowGraph = trace_from([sample, emb, *down_block_residual_samples], list(inputs))

    graph = hidet.graph.optimize(graph)

    compiled_graph = graph.build(space=kernel_search_space)

    return compiled_graph, inputs, outputs


def _build_unet_mid_graph(
    model: PretrainedModelForDiffusion,
    sample: Tensor,
    emb: Tensor,
    encoder_hidden_states: Tensor,
    kernel_search_space: int = 2,
):
    sample, emb, encoder_hidden_states = tuple(
        symbol(list(x.shape), dtype=x.dtype, device=x.device) for x in (sample, emb, encoder_hidden_states)
    )

    output = model.forward_mid(sample, emb, encoder_hidden_states)

    graph: FlowGraph = trace_from(output, [sample, emb, encoder_hidden_states])

    graph = hidet.graph.optimize(graph)

    compiled_graph = graph.build(space=kernel_search_space)

    return compiled_graph, output


def _build_unet_up_graph(
    model: PretrainedModelForDiffusion,
    sample: Tensor,
    emb: Tensor,
    encoder_hidden_states: Tensor,
    down_block_residuals: Tuple[Tensor, ...],
    kernel_search_space: int = 2,
):
    sample, emb, encoder_hidden_states = tuple(
        symbol(list(x.shape), dtype=x.dtype, device=x.device) for x in (sample, emb, encoder_hidden_states)
    )
    down_block_residuals = tuple(symbol(list(x.shape), dtype=x.dtype, device=x.device) for x in down_block_residuals)
    output = model.forward_up(sample, emb, encoder_hidden_states, down_block_residuals)

    graph: FlowGraph = trace_from(output, [sample, emb, encoder_hidden_states, *down_block_residuals])

    graph = hidet.graph.optimize(graph)

    compiled_graph = graph.build(space=kernel_search_space)

    return compiled_graph, output


def create_stable_diffusion(
    name: str,
    revision: Optional[str] = None,
    dtype: str = "float32",
    device: str = "cuda",
    batch_size: int = 1,
    height: int = 768,
    width: int = 768,
    kernel_search_space: int = 2,
):
    hf_pipeline: StableDiffusionPipeline = load_diffusion_pipeline(name=name, revision=revision, device=device)
    # create the hidet model and load the pretrained weights from huggingface
    model: PretrainedModelForDiffusion = PretrainedModelForDiffusion.create_pretrained_model(
        name, revision=revision, device=device, dtype=dtype
    )

    unet_down_graph, inputs, outputs = _build_unet_down_graph(
        model,
        dtype=dtype,
        device=device,
        batch_size=batch_size * 2,  # double size for prompt/negative prompt
        num_channels_latents=model.config["in_channels"],
        height=height // model.vae_scale_factor,
        width=width // model.vae_scale_factor,
        embed_length=model.embed_max_length,
        embed_hidden_dim=model.embed_hidden_dim,
        kernel_search_space=kernel_search_space,
    )

    _, _, prompt_embeds = inputs
    sample, emb, down_block_residual_samples = outputs

    unet_mid_graph, sample = _build_unet_mid_graph(
        model, sample=sample, emb=emb, encoder_hidden_states=prompt_embeds, kernel_search_space=kernel_search_space
    )

    unet_up_graph, sample = _build_unet_up_graph(
        model,
        sample=sample,
        emb=emb,
        encoder_hidden_states=prompt_embeds,
        down_block_residuals=down_block_residual_samples,
        kernel_search_space=kernel_search_space,
    )

    return DiffusionApp(
        compiled_app=create_compiled_app(
            graphs={"unet_down": unet_down_graph, "unet_mid": unet_mid_graph, "unet_up": unet_up_graph},
            modules={},
            tensors={},
            attributes={},
            name=name,
        ),
        hf_pipeline=hf_pipeline,
        height=height,
        width=width,
    )
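The three builders above share one pattern: create symbolic inputs, run the hidet model to record a FlowGraph, optimize it, and build a compiled graph with the requested kernel search space. A stripped-down sketch of that pattern, with a placeholder computation standing in for the UNet stages:

import hidet
from hidet.graph.flow_graph import trace_from
from hidet.graph.tensor import symbol

x = symbol([2, 4, 96, 96], dtype="float32", device="cuda")  # symbolic input, like latent_model_input
y = x * 2.0 + 1.0                                           # placeholder computation instead of model.forward_down(...)
graph = trace_from(y, [x])                                  # record the computation as a FlowGraph
graph = hidet.graph.optimize(graph)                         # graph-level optimizations
compiled_graph = graph.build(space=2)                       # space=2 matches the kernel_search_space default above
output = compiled_graph(hidet.randn([2, 4, 96, 96], device="cuda"))  # run with a concrete tensor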
2 changes: 2 additions & 0 deletions python/hidet/apps/diffusion/modeling/__init__.py
@@ -0,0 +1,2 @@
from .stable_diffusion import *
from .pretrained import PretrainedModelForDiffusion
60 changes: 60 additions & 0 deletions python/hidet/apps/diffusion/modeling/pretrained.py
@@ -0,0 +1,60 @@
from typing import Optional
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline

from hidet.apps.hf import load_diffusion_pipeline
from hidet.apps.pretrained import PretrainedModel


class PretrainedModelForDiffusion(PretrainedModel):
    @classmethod
    def create_pretrained_model(
        cls,
        name: str,
        revision: Optional[str] = None,
        hf_pipeline: Optional[StableDiffusionPipeline] = None,
        dtype: Optional[str] = None,
        device: str = "cuda",
    ):
        # load the pretrained huggingface model
        # note: diffusers pipeline is more similar to a model than transformers pipeline
        if hf_pipeline is None:
            hf_pipeline: StableDiffusionPipeline = load_diffusion_pipeline(name=name, revision=revision, device=device)

        pipeline_config = hf_pipeline.config

        torch_unet = hf_pipeline.unet
        pretrained_unet_class = cls.load_module(pipeline_config["unet"][1])

        hidet_unet = pretrained_unet_class(
            **dict(torch_unet.config),
            vae_scale_factor=hf_pipeline.vae_scale_factor,
            embed_max_length=hf_pipeline.text_encoder.config.max_position_embeddings,
            embed_hidden_dim=hf_pipeline.text_encoder.config.hidden_size
        )

        hidet_unet.to(dtype=dtype, device=device)

        cls.copy_weights(torch_unet, hidet_unet)

        return hidet_unet

    @property
    def embed_max_length(self):
        raise NotImplementedError()

    @property
    def embed_hidden_dim(self):
        raise NotImplementedError()

    @property
    def vae_scale_factor(self):
        raise NotImplementedError()

    def forward_down(self, *args, **kwargs):
        raise NotImplementedError()

    def forward_mid(self, *args, **kwargs):
        raise NotImplementedError()

    def forward_up(self, *args, **kwargs):
        raise NotImplementedError()
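If a diffusers pipeline is already in memory, it can be passed in to avoid loading the checkpoint twice. A hedged sketch of that path; the model id is an example, not part of this change:

from hidet.apps.hf import load_diffusion_pipeline
from hidet.apps.diffusion.modeling import PretrainedModelForDiffusion

pipeline = load_diffusion_pipeline("stabilityai/stable-diffusion-2-1")  # example checkpoint
unet = PretrainedModelForDiffusion.create_pretrained_model(
    name="stabilityai/stable-diffusion-2-1",
    hf_pipeline=pipeline,  # reuse the loaded pipeline instead of downloading again
    dtype="float32",
    device="cuda",
)
print(unet.embed_max_length, unet.embed_hidden_dim, unet.vae_scale_factor)  # filled in by the UNet subclass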
23 changes: 16 additions & 7 deletions python/hidet/apps/diffusion/modeling/stable_diffusion/unet.py
@@ -1,6 +1,6 @@
from typing import Tuple
from hidet import nn
from hidet.apps.diffusion.modeling.pretrained import PretrainedModelForText2Image
from hidet.apps.diffusion.modeling.pretrained import PretrainedModelForDiffusion
from hidet.apps.diffusion.modeling.stable_diffusion.timestep import TimestepEmbedding, Timesteps
from hidet.apps.diffusion.modeling.stable_diffusion.unet_blocks import (
    CrossAttnDownBlock2D,
@@ -9,7 +9,6 @@
    MidBlock2DCrossAttn,
    UpBlock2D,
)
from hidet.apps.modeling_outputs import UNet2DConditionOutput
from hidet.apps.pretrained import PretrainedModel
from hidet.apps.registry import RegistryEntry
from hidet.graph.tensor import Tensor
@@ -23,7 +22,7 @@
)


class UNet2DConditionModel(PretrainedModelForText2Image):
class UNet2DConditionModel(PretrainedModelForDiffusion):
    def __init__(self, **kwargs):
        super().__init__(kwargs)
        self.conv_in = nn.Conv2d(
@@ -192,6 +191,18 @@ def __init__(self, **kwargs):
            bias=True,
        )

    @property
    def embed_max_length(self):
        return self.config["embed_max_length"]

    @property
    def embed_hidden_dim(self):
        return self.config["embed_hidden_dim"]

    @property
    def vae_scale_factor(self):
        return self.config["vae_scale_factor"]

    def get_down_block(self, down_block_type: str, **kwargs):
        if down_block_type == "CrossAttnDownBlock2D":
            return CrossAttnDownBlock2D(**{**self.config, **kwargs})  # type: ignore
@@ -276,9 +287,7 @@ def forward_up(

        return sample

    def forward(
        self, sample: Tensor, timesteps: Tensor, encoder_hidden_states: Tensor, **kwargs
    ) -> UNet2DConditionOutput:
    def forward(self, sample: Tensor, timesteps: Tensor, encoder_hidden_states: Tensor, **kwargs) -> Tensor:
        timesteps = broadcast(timesteps, shape=(sample.shape[0],))

        sample, emb, down_block_residual_samples = self.forward_down(sample, timesteps, encoder_hidden_states)
@@ -287,4 +296,4 @@ def forward(

        sample = self.forward_up(sample, emb, encoder_hidden_states, down_block_residual_samples)

        return UNet2DConditionOutput(last_hidden_state=sample, hidden_states=[sample])
        return sample
16 changes: 15 additions & 1 deletion python/hidet/apps/hf.py
@@ -1,10 +1,24 @@
from typing import Optional

import torch
from transformers import AutoConfig, PretrainedConfig
from diffusers import StableDiffusionPipeline

import hidet


def _get_hf_auth_token():
    return hidet.option.get_option('auth_tokens.for_huggingface')


def load_pretrained_config(model: str, revision: Optional[str] = None) -> PretrainedConfig:
    huggingface_token = hidet.option.get_option('auth_tokens.for_huggingface')
    huggingface_token = _get_hf_auth_token()
    return AutoConfig.from_pretrained(model, revision=revision, token=huggingface_token)


def load_diffusion_pipeline(name: str, revision: Optional[str] = None, device: str = "cuda") -> StableDiffusionPipeline:
    huggingface_token = _get_hf_auth_token()
    with torch.device(device):
        return StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path=name, torch_dtype=torch.float32, revision=revision, token=huggingface_token
        )
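Both loaders read the Hugging Face token from hidet's 'auth_tokens.for_huggingface' option; when it resolves to None, the hub call falls back to cached or anonymous credentials. A hedged usage sketch; the model ids are examples only:

from hidet.apps.hf import load_pretrained_config, load_diffusion_pipeline

config = load_pretrained_config("bert-base-uncased")  # AutoConfig works for transformers-style checkpoints
pipeline = load_diffusion_pipeline("stabilityai/stable-diffusion-2-1", device="cuda")  # diffusers pipeline on GPU
print(type(pipeline).__name__)  # StableDiffusionPipeline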
@@ -15,7 +15,7 @@ def create_pretrained_model(
        cls, config: PretrainedConfig, revision: Optional[str] = None, dtype: Optional[str] = None, device: str = "cuda"
    ):
        # dynamically load model subclass
        pretrained_model_class = cls.load_module(config)
        pretrained_model_class = cls.load_module(config.architectures[0])

        # load the pretrained huggingface model into cpu
        with torch.device("cuda"):  # reduce the time to load the model
15 changes: 14 additions & 1 deletion python/hidet/apps/pretrained.py
@@ -1,4 +1,5 @@
from typing import Generic, List, Set
import logging

import torch
from transformers import PretrainedConfig
@@ -7,6 +8,12 @@
from hidet.graph import Tensor, nn
from hidet.graph.nn.module import R
from hidet.graph.tensor import from_torch
from hidet.utils import prod


logger = logging.Logger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())


class PretrainedModel(nn.Module[R], Registry, Generic[R]):
@@ -33,8 +40,14 @@ def copy_weights(cls, torch_model: torch.nn.Module, hidet_model: nn.Module):
                )

            src = from_torch(tensor).to(member.dtype, member.device)

            if src.shape != member.shape:
                raise ValueError(f"Parameter {name} shape mismatch, hidet: {member.shape}, torch: {src.shape}")
                if prod(src.shape) == prod(member.shape):
                    logging.warning("Attempting to reshape parameter %s from %s to %s.", name, src.shape, member.shape)
                    src = src.reshape(member.shape)
                else:
                    raise ValueError(f"Parameter {name} shape mismatch, hidet: {member.shape}, torch: {src.shape}")

            found_tensors.append(member)
            member.copy_(src)

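copy_weights now tolerates parameters whose torch and hidet shapes differ but hold the same number of elements, reshaping with a warning instead of failing. A standalone illustration of that check with toy shapes (the shapes are assumptions, not taken from the diff):

from hidet.utils import prod

src_shape, dst_shape = (320, 4, 3, 3), (320, 36)  # same element count, different layout
if src_shape != dst_shape:
    if prod(src_shape) == prod(dst_shape):
        print(f"reshaping {src_shape} -> {dst_shape}")  # copy_weights warns, then reshapes src to member.shape
    else:
        raise ValueError(f"shape mismatch: {src_shape} vs {dst_shape}")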