forked from hidet-org/hidet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Stable Diffusion App Infra (hidet-org#103)
Infrastructure for compiled stable diffusion app. Towards hidet-org#57
- Loading branch information
Showing 17 changed files with 321 additions and 23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .app import DiffusionApp | ||
from .builder import create_stable_diffusion |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch | ||
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler | ||
|
||
from hidet.graph.tensor import from_torch, full | ||
from hidet.runtime.compiled_app import CompiledApp | ||
|
||
|
||
class DiffusionApp:
    """Stable-diffusion application that routes the UNet through compiled graphs.

    The Hugging Face pipeline is kept for everything except the UNet forward
    pass, which is overridden to call the "unet_down", "unet_mid" and
    "unet_up" graphs stored in the compiled app.
    """

    def __init__(self, compiled_app: CompiledApp, hf_pipeline: StableDiffusionPipeline, height: int, width: int):
        """Wire a compiled app into a Hugging Face diffusion pipeline.

        Args:
            compiled_app: Compiled app whose ``graphs`` mapping holds the
                "unet_down", "unet_mid" and "unet_up" entries.
            hf_pipeline: Pipeline whose scheduler and ``unet.forward`` are
                replaced, and which is moved to "cuda".
            height: Image height passed to the pipeline; must be a multiple of 8.
            width: Image width passed to the pipeline; must be a multiple of 8.
        """
        super().__init__()
        dims_ok = height % 8 == 0 and width % 8 == 0
        assert dims_ok, "Height and width must be multiples of 8"

        self.compiled_app: CompiledApp = compiled_app
        self.hf_pipeline: StableDiffusionPipeline = hf_pipeline
        self.height = height
        self.width = width

        # Rebuild the scheduler from the configuration of the one already attached.
        scheduler_config = self.hf_pipeline.scheduler.config
        self.hf_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)
        self.hf_pipeline = self.hf_pipeline.to("cuda")

        def _unet_forward(sample: torch.Tensor, timesteps: torch.Tensor, encoder_hidden_states: torch.Tensor, **kwargs):
            # Graphs are looked up on every call so they always come from the
            # app currently attached to this instance.
            graphs = self.compiled_app.graphs

            latents = from_torch(sample)
            hidden_states = from_torch(encoder_hidden_states)
            # `.item()` is used, so `timesteps` must hold exactly one value;
            # that value is repeated once per row of `sample`.
            steps = full([sample.shape[0]], timesteps.item(), dtype="int64", device="cuda")

            down_results = graphs["unet_down"](latents, steps, hidden_states)
            latents = down_results[0]
            embedding = down_results[1]
            residuals = down_results[2:]

            latents = graphs["unet_mid"](latents, embedding, hidden_states)
            latents = graphs["unet_up"](latents, embedding, hidden_states, *residuals)

            # A one-element tuple is returned, as in the original forward override.
            return (latents.torch(),)

        self.hf_pipeline.unet.forward = _unet_forward

    def generate_image(self, prompt: str, negative_prompt: str):
        """Run the pipeline for one prompt pair and return its ``images`` attribute.

        Args:
            prompt: Text prompt forwarded to the pipeline.
            negative_prompt: Negative text prompt forwarded to the pipeline.
        """
        result = self.hf_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=self.height,
            width=self.width,
        )
        return result.images
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
from typing import Optional, Tuple | ||
from diffusers import StableDiffusionPipeline | ||
|
||
import hidet | ||
from hidet.apps.diffusion.app import DiffusionApp | ||
from hidet.apps.diffusion.modeling.pretrained import PretrainedModelForDiffusion | ||
from hidet.apps.hf import load_diffusion_pipeline | ||
from hidet.graph.flow_graph import trace_from | ||
from hidet.graph import FlowGraph | ||
from hidet.graph.tensor import symbol, Tensor | ||
from hidet.runtime.compiled_app import create_compiled_app | ||
|
||
|
||
def _build_unet_down_graph(
    model: PretrainedModelForDiffusion,
    dtype: str = "float32",
    device: str = "cuda",
    batch_size: int = 2,
    num_channels_latents: int = 4,
    height: int = 96,
    width: int = 96,
    embed_length: int = 77,
    embed_hidden_dim: int = 1024,
    kernel_search_space: int = 2,
):
    """Trace, optimize and build the graph for ``model.forward_down``.

    Symbolic inputs are created for the latents
    (``[batch_size, num_channels_latents, height, width]``), the timesteps
    (``[batch_size]``, int64) and the prompt embeddings
    (``[batch_size, embed_length, embed_hidden_dim]``).

    Returns:
        A 3-tuple ``(compiled_graph, inputs, outputs)`` where ``inputs`` is the
        tuple of symbolic inputs and ``outputs`` is whatever
        ``model.forward_down`` returned for them.
    """
    latents_sym: Tensor = symbol([batch_size, num_channels_latents, height, width], dtype=dtype, device=device)
    steps_sym: Tensor = symbol([batch_size], dtype="int64", device=device)
    embeds_sym: Tensor = symbol([batch_size, embed_length, embed_hidden_dim], dtype=dtype, device=device)
    inputs = (latents_sym, steps_sym, embeds_sym)

    outputs = model.forward_down(*inputs)
    sample, emb, down_block_residual_samples = outputs

    # The residuals are flattened so the traced graph has a flat output list.
    traced: FlowGraph = trace_from([sample, emb, *down_block_residual_samples], list(inputs))
    optimized = hidet.graph.optimize(traced)

    return optimized.build(space=kernel_search_space), inputs, outputs
|
||
|
||
def _build_unet_mid_graph(
    model: PretrainedModelForDiffusion,
    sample: Tensor,
    emb: Tensor,
    encoder_hidden_states: Tensor,
    kernel_search_space: int = 2,
):
    """Trace, optimize and build the graph for ``model.forward_mid``.

    Each argument tensor only supplies a shape, dtype and device; a fresh
    symbol with those properties is created and used as the graph input.

    Returns:
        A pair ``(compiled_graph, output)`` where ``output`` is the result of
        ``model.forward_mid`` on the fresh symbols.
    """
    placeholders = [
        symbol(list(ref.shape), dtype=ref.dtype, device=ref.device)
        for ref in (sample, emb, encoder_hidden_states)
    ]

    output = model.forward_mid(*placeholders)

    traced: FlowGraph = trace_from(output, placeholders)
    optimized = hidet.graph.optimize(traced)

    return optimized.build(space=kernel_search_space), output
|
||
|
||
def _build_unet_up_graph(
    model: PretrainedModelForDiffusion,
    sample: Tensor,
    emb: Tensor,
    encoder_hidden_states: Tensor,
    down_block_residuals: Tuple[Tensor, ...],
    kernel_search_space: int = 2,
):
    """Trace, optimize and build the graph for ``model.forward_up``.

    Every argument tensor, including each residual, is replaced by a fresh
    symbol with the same shape, dtype and device before tracing.

    Returns:
        A pair ``(compiled_graph, output)`` where ``output`` is the result of
        ``model.forward_up`` on the fresh symbols.
    """

    def _like(ref: Tensor) -> Tensor:
        # New symbolic tensor mirroring the reference's shape, dtype and device.
        return symbol(list(ref.shape), dtype=ref.dtype, device=ref.device)

    sample_sym = _like(sample)
    emb_sym = _like(emb)
    hidden_sym = _like(encoder_hidden_states)
    residual_syms = tuple(_like(residual) for residual in down_block_residuals)

    # The residuals are passed to the model as one tuple, but listed
    # individually as graph inputs.
    output = model.forward_up(sample_sym, emb_sym, hidden_sym, residual_syms)

    traced: FlowGraph = trace_from(output, [sample_sym, emb_sym, hidden_sym, *residual_syms])
    optimized = hidet.graph.optimize(traced)

    return optimized.build(space=kernel_search_space), output
|
||
|
||
def create_stable_diffusion(
    name: str,
    revision: Optional[str] = None,
    dtype: str = "float32",
    device: str = "cuda",
    batch_size: int = 1,
    height: int = 768,
    width: int = 768,
    kernel_search_space: int = 2,
):
    """Build a ``DiffusionApp`` whose UNet runs as three compiled graphs.

    Args:
        name: Model name passed to the pipeline loader and used as the
            compiled app's name.
        revision: Optional model revision passed to the loaders.
        dtype: Data type used for the hidet model and the symbolic inputs.
        device: Device used for loading the pipeline and building the graphs.
        batch_size: Number of images per call; the graphs are built for twice
            this size (prompt plus negative prompt).
        height: Output image height in pixels.
        width: Output image width in pixels.
        kernel_search_space: Search space forwarded to every graph build.

    Returns:
        A ``DiffusionApp`` wrapping the compiled graphs and the loaded pipeline.
    """
    hf_pipeline: StableDiffusionPipeline = load_diffusion_pipeline(name=name, revision=revision, device=device)

    # Create the hidet model and load the pretrained weights from huggingface.
    # The pipeline loaded above is handed over explicitly: when `hf_pipeline`
    # is omitted, `create_pretrained_model` loads the whole pipeline again.
    model: PretrainedModelForDiffusion = PretrainedModelForDiffusion.create_pretrained_model(
        name, revision=revision, hf_pipeline=hf_pipeline, device=device, dtype=dtype
    )

    unet_down_graph, inputs, outputs = _build_unet_down_graph(
        model,
        dtype=dtype,
        device=device,
        batch_size=batch_size * 2,  # double size for prompt/negative prompt
        num_channels_latents=model.config["in_channels"],
        height=height // model.vae_scale_factor,
        width=width // model.vae_scale_factor,
        embed_length=model.embed_max_length,
        embed_hidden_dim=model.embed_hidden_dim,
        kernel_search_space=kernel_search_space,
    )

    # The later stages are built from the symbolic tensors of the down stage,
    # which supply the shapes, dtypes and devices of their inputs.
    _, _, prompt_embeds = inputs
    sample, emb, down_block_residual_samples = outputs

    unet_mid_graph, sample = _build_unet_mid_graph(
        model, sample=sample, emb=emb, encoder_hidden_states=prompt_embeds, kernel_search_space=kernel_search_space
    )

    unet_up_graph, sample = _build_unet_up_graph(
        model,
        sample=sample,
        emb=emb,
        encoder_hidden_states=prompt_embeds,
        down_block_residuals=down_block_residual_samples,
        kernel_search_space=kernel_search_space,
    )

    return DiffusionApp(
        compiled_app=create_compiled_app(
            graphs={"unet_down": unet_down_graph, "unet_mid": unet_mid_graph, "unet_up": unet_up_graph},
            modules={},
            tensors={},
            attributes={},
            name=name,
        ),
        hf_pipeline=hf_pipeline,
        height=height,
        width=width,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .stable_diffusion import * | ||
from .pretrained import PretrainedModelForDiffusion |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from typing import Optional | ||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline | ||
|
||
from hidet.apps.hf import load_diffusion_pipeline | ||
from hidet.apps.pretrained import PretrainedModel | ||
|
||
|
||
class PretrainedModelForDiffusion(PretrainedModel):
    """Base class for hidet diffusion models created from a Hugging Face pipeline.

    Subclasses are expected to provide the three properties and the three
    ``forward_*`` stages; the versions here only raise ``NotImplementedError``.
    """

    @classmethod
    def create_pretrained_model(
        cls,
        name: str,
        revision: Optional[str] = None,
        hf_pipeline: Optional[StableDiffusionPipeline] = None,
        dtype: Optional[str] = None,
        device: str = "cuda",
    ):
        """Create the hidet UNet matching a pipeline and copy the torch weights into it.

        Args:
            name: Model name, used to load the pipeline when none is given.
            revision: Optional revision, used to load the pipeline when none is given.
            hf_pipeline: Already-loaded pipeline; loaded from ``name`` when ``None``.
            dtype: Data type passed to the hidet model's ``to``.
            device: Device passed to the pipeline loader and the hidet model's ``to``.

        Returns:
            The hidet UNet instance with weights copied from the torch UNet.
        """
        # note: diffusers pipeline is more similar to a model than transformers pipeline
        if hf_pipeline is None:
            hf_pipeline = load_diffusion_pipeline(name=name, revision=revision, device=device)

        source_unet = hf_pipeline.unet

        # The second element of the pipeline config's "unet" entry is the name
        # handed to `load_module` to resolve the hidet UNet class.
        unet_entry = hf_pipeline.config["unet"]
        unet_cls = cls.load_module(unet_entry[1])

        text_encoder_config = hf_pipeline.text_encoder.config
        target_unet = unet_cls(
            **dict(source_unet.config),
            vae_scale_factor=hf_pipeline.vae_scale_factor,
            embed_max_length=text_encoder_config.max_position_embeddings,
            embed_hidden_dim=text_encoder_config.hidden_size,
        )
        target_unet.to(dtype=dtype, device=device)

        cls.copy_weights(source_unet, target_unet)
        return target_unet

    @property
    def embed_max_length(self):
        """Prompt embedding length; must be provided by subclasses."""
        raise NotImplementedError()

    @property
    def embed_hidden_dim(self):
        """Prompt embedding hidden size; must be provided by subclasses."""
        raise NotImplementedError()

    @property
    def vae_scale_factor(self):
        """VAE scale factor; must be provided by subclasses."""
        raise NotImplementedError()

    def forward_down(self, *args, **kwargs):
        """Down stage of the UNet; must be provided by subclasses."""
        raise NotImplementedError()

    def forward_mid(self, *args, **kwargs):
        """Middle stage of the UNet; must be provided by subclasses."""
        raise NotImplementedError()

    def forward_up(self, *args, **kwargs):
        """Up stage of the UNet; must be provided by subclasses."""
        raise NotImplementedError()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,24 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
from transformers import AutoConfig, PretrainedConfig | ||
from diffusers import StableDiffusionPipeline | ||
|
||
import hidet | ||
|
||
|
||
def _get_hf_auth_token():
    """Return the value of the hidet option ``auth_tokens.for_huggingface``."""
    option_name = 'auth_tokens.for_huggingface'
    return hidet.option.get_option(option_name)
|
||
|
||
def load_pretrained_config(model: str, revision: Optional[str] = None) -> PretrainedConfig:
    """Load a model configuration through ``AutoConfig.from_pretrained``.

    Args:
        model: Model name or path passed to ``AutoConfig.from_pretrained``.
        revision: Optional revision passed to ``AutoConfig.from_pretrained``.

    Returns:
        The configuration object returned by ``AutoConfig.from_pretrained``.
    """
    # The token is fetched once through the shared helper; the earlier direct
    # option lookup was overwritten immediately and has been dropped.
    huggingface_token = _get_hf_auth_token()
    return AutoConfig.from_pretrained(model, revision=revision, token=huggingface_token)
|
||
|
||
def load_diffusion_pipeline(name: str, revision: Optional[str] = None, device: str = "cuda") -> StableDiffusionPipeline:
    """Load a ``StableDiffusionPipeline`` in float32 under a ``torch.device`` context.

    Args:
        name: Passed as ``pretrained_model_name_or_path``.
        revision: Optional revision passed to ``from_pretrained``.
        device: Device used for the ``torch.device`` context during loading.

    Returns:
        The pipeline returned by ``StableDiffusionPipeline.from_pretrained``.
    """
    load_kwargs = {
        "pretrained_model_name_or_path": name,
        "torch_dtype": torch.float32,
        "revision": revision,
        "token": _get_hf_auth_token(),
    }
    with torch.device(device):
        pipeline = StableDiffusionPipeline.from_pretrained(**load_kwargs)
    return pipeline
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.