Turbo support #2
monorimet committed Jun 17, 2024
1 parent 8348ff7 commit 11430ee
Showing 4 changed files with 32 additions and 5 deletions.
2 changes: 1 addition & 1 deletion models/turbine_models/custom_models/sd_inference/utils.py
@@ -125,7 +125,7 @@ def compile_to_vmfb(
elif ireec_flags == None:
ireec_flags = []

- debug = True
+ debug = False
if debug:
flags.extend(
[
22 changes: 22 additions & 0 deletions models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py
@@ -72,6 +72,28 @@
}
"""

sdxl_turbo_sched_unet_bench_f16 = """
module @sdxl_compiled_pipeline {
func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<1x4x128x128xf16>) -> (tensor<1x4x128x128xf16>, tensor<1x6xf16>, tensor<i64>) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"}
func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<1x4x128x128xf16>, %arg1: tensor<1x64x2048xf16>, %arg2: tensor<1x1280xf16>, %arg3: tensor<1x6xf16>, %arg4: tensor<1xf16>, %arg5: tensor<1xi64>) -> tensor<1x4x128x128xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"}
func.func @produce_image_latents(%sample: tensor<1x4x128x128xf16>, %p_embeds: tensor<1x64x2048xf16>, %t_embeds: tensor<1x1280xf16>, %guidance_scale: tensor<1xf16>) -> tensor<1x4x128x128xf16> {
%noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<1x4x128x128xf16>) -> (tensor<1x4x128x128xf16>, tensor<1x6xf16>, tensor<i64>)
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%steps_int = tensor.extract %steps[] : tensor<i64>
%n_steps = arith.index_cast %steps_int: i64 to index
%res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<1x4x128x128xf16>) {
%step_64 = arith.index_cast %arg0 : index to i64
%this_step = tensor.from_elements %step_64 : tensor<1xi64>
%inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<1x4x128x128xf16>, tensor<1x64x2048xf16>, tensor<1x1280xf16>, tensor<1x6xf16>, tensor<1xf16>, tensor<1xi64>) -> tensor<1x4x128x128xf16>
scf.yield %inner : tensor<1x4x128x128xf16>
}
return %res : tensor<1x4x128x128xf16>
}
}
"""

sdxl_sched_unet_bench_f32 = """
module @sdxl_compiled_pipeline {
func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<1x4x128x128xf32>) -> (tensor<1x4x128x128xf32>, tensor<2x6xf32>, tensor<i64>) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"}
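For orientation, the turbo produce_image_latents entrypoint above amounts to the following loop, paraphrased in Python (illustration only, not part of the commit); compiled_scheduled_unet stands in for the separately compiled scheduled-unet module whose run_initialize/run_forward functions the pipeline calls:

def produce_image_latents(compiled_scheduled_unet, sample, p_embeds, t_embeds, guidance_scale):
    # run_initialize scales the input latents and returns the packed time ids
    # plus the number of scheduler steps.
    noisy_sample, time_ids, steps = compiled_scheduled_unet.run_initialize(sample)
    latents = noisy_sample
    for step in range(int(steps)):
        # One denoising step per iteration, mirroring the scf.for loop above.
        latents = compiled_scheduled_unet.run_forward(
            latents, p_embeds, t_embeds, time_ids, guidance_scale, step
        )
    return latents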
Expand Up @@ -16,6 +16,7 @@
from turbine_models.custom_models.sdxl_inference.pipeline_ir import (
sdxl_sched_unet_bench_f32,
sdxl_sched_unet_bench_f16,
sdxl_turbo_sched_unet_bench_f16,
sdxl_pipeline_bench_f32,
sdxl_pipeline_bench_f16,
)
@@ -354,6 +355,9 @@ def export_submodel(
if self.precision == "fp32"
else sdxl_sched_unet_bench_f16
)
if self.do_classifier_free_guidance == False:
assert self.precision == "fp16", "turbo only supported in fp16 precision."
pipeline_file = sdxl_turbo_sched_unet_bench_f16
pipeline_vmfb = utils.compile_to_vmfb(
pipeline_file,
self.device,
@@ -551,7 +555,6 @@ def generate_images(

for i in range(batch_count):
unet_start = time.time()

latents = self.runners["pipe"].ctx.modules.sdxl_compiled_pipeline[
"produce_image_latents"
](samples[i], prompt_embeds, add_text_embeds, guidance_scale)
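The inputs to that call follow the f16 tensor types declared in sdxl_turbo_sched_unet_bench_f16. A minimal sketch with dummy data, assuming runner is the pipeline's self.runners["pipe"] (an IREE context with both the compiled pipeline and the scheduled unet loaded); the values and the guidance scale of 0.0 are placeholders:

import numpy as np

# Dummy inputs matching the turbo pipeline's tensor types (batch 1, 1024x1024, fp16).
sample = np.random.randn(1, 4, 128, 128).astype(np.float16)
prompt_embeds = np.random.randn(1, 64, 2048).astype(np.float16)
add_text_embeds = np.random.randn(1, 1280).astype(np.float16)
guidance_scale = np.asarray([0.0], dtype=np.float16)  # placeholder; CFG is disabled on the turbo path

def run_turbo_latents(runner):
    # Same entrypoint as the generate_images() call above.
    return runner.ctx.modules.sdxl_compiled_pipeline["produce_image_latents"](
        sample, prompt_embeds, add_text_embeds, guidance_scale
    )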
@@ -14,6 +14,7 @@
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
import shark_turbine.ops as ops
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
@@ -79,9 +80,10 @@ def initialize(self, sample):
target_size = (height, width)
crops_coords_top_left = (0, 0)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
add_time_ids = add_time_ids.repeat(sample.shape[0], 1).type(self.dtype)
add_time_ids = torch.tensor([add_time_ids], dtype=self.dtype)
if self.do_classifier_free_guidance:
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
add_time_ids = add_time_ids.repeat(sample.shape[0], 1).type(self.dtype)
timesteps = self.scheduler.timesteps
step_indexes = torch.tensor(len(timesteps))
sample = sample * self.scheduler.init_noise_sigma
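The initialize() change above duplicates add_time_ids along the batch dimension only when classifier-free guidance is enabled, which is what lets the turbo pipeline IR take tensor<1x6xf16> time ids instead of 2x6. A minimal standalone sketch of the resulting shapes (batch-1, 1024x1024 defaults assumed; not the pipeline code itself):

import torch

def make_add_time_ids(do_classifier_free_guidance: bool, dtype=torch.float16):
    original_size = (1024, 1024)       # assumed (height, width)
    crops_coords_top_left = (0, 0)
    target_size = (1024, 1024)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)  # shape (1, 6)
    if do_classifier_free_guidance:
        # Duplicated for the unconditional + conditional halves of the CFG batch.
        add_time_ids = torch.cat([add_time_ids] * 2, dim=0)   # shape (2, 6)
    return add_time_ids

print(make_add_time_ids(False).shape)  # torch.Size([1, 6]) -> turbo path
print(make_add_time_ids(True).shape)   # torch.Size([2, 6]) -> CFG path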
