def replace_with_tk_kernels(
    flow_dialect_ir,
    kernels=None,
):
    """Swap selected matmul dispatches in flow-dialect IR for tk kernels.

    For every kernel URL, the kernel file name (without extension) is the
    replacement name, and its trailing ``BxMxNxK`` component identifies the
    dispatch function to replace (``matmul_transpose_b_{B}x{M}x{N}x{K}``).
    A file name ending in ``_bias<n>`` additionally requires the matched
    ``func.func`` line to contain exactly ``3 + n`` occurrences of ``arg``.

    The IR is processed line by line in two passes:

    1. ``flow.dispatch`` lines that reference a matched function are
       rewritten to reference the tk kernel name instead.
    2. The downloaded kernel text is inserted immediately before the
       ``flow.executable ... private`` line of each matched executable.

    Kernels with no matching ``func.func`` line in the IR are skipped, and
    nothing is downloaded for them.

    Args:
        flow_dialect_ir: Flow-dialect IR as a single newline-separated string.
        kernels: Optional list of kernel URLs. Defaults to the built-in
            int8 fused GEMM kernel.

    Returns:
        A list of IR lines. Inserted kernels appear as single multi-line
        list entries. Join with ``"\\n"`` to obtain the module text.
    """
    if kernels is None:
        kernels = [
            "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/tk_int8/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir"
        ]

    print("Inserting kernels and updating calls to kernels...")

    # Everything derived from the kernel URL is line-independent, so parse
    # it once up front instead of once per IR line.
    kernel_name = {}
    old_kernel_name = {}
    expected_args = {}
    for kernel in kernels:
        # Parse the file-name stem only: splitting the whole URL on "."
        # would cut it inside the host name.
        stem = kernel.split("/")[-1].split(".")[0]
        parts = stem.split("_")
        shape = parts[-1]
        expected_args[kernel] = None
        if "bias" in shape:
            # "bias<n>" suffix: the shape is the component before it.
            expected_args[kernel] = 3 + int(shape[4:])
            shape = parts[-2]
        B, M, N, K = shape.split("x")
        kernel_name[kernel] = stem
        old_kernel_name[kernel] = f"matmul_transpose_b_{B}x{M}x{N}x{K}"

    # Pass 1: locate the functions to replace and rewrite calls to them.
    kernel_map = {}  # kernel URL -> name of the function being replaced
    prefix_map = {}  # kernel URL -> name of the executable holding it
    new_base = []
    for line in flow_dialect_ir.split("\n"):
        for kernel in kernels:
            old_kernel = old_kernel_name[kernel]
            if old_kernel not in line:
                continue
            if "func.func" in line:
                if (
                    expected_args[kernel] is not None
                    and line.count("arg") != expected_args[kernel]
                ):
                    continue
                # Second token is "@<name>(%arg0:"; drop the leading "@"
                # and the last 7 characters (presumably "(%arg0:").
                kernel_map[kernel] = line.strip().split(" ")[1][1:-7]
                # Text before the matmul name, minus its last 7 characters
                # (presumably "_batch_" -- TODO confirm against real IR).
                prefix_map[kernel] = kernel_map[kernel].split(old_kernel)[0][:-7]
            elif "flow.dispatch" in line and kernel in kernel_map:
                # Only rewrite once the definition has been seen; a
                # dispatch with no known definition is left untouched.
                line = line.replace(kernel_map[kernel], kernel_name[kernel])
                line = line.replace(prefix_map[kernel], kernel_name[kernel])
        new_base.append(line)

    # Pass 2: insert each kernel's source ahead of the executable it replaces.
    kernel_source = {}  # kernel URL -> prepared text, fetched at most once
    final_ir = []
    for line in new_base:
        if "flow.executable" in line and "private" in line:
            for kernel in kernels:
                prefix = prefix_map.get(kernel)
                if prefix is None or prefix + " {" not in line:
                    continue
                if kernel not in kernel_source:
                    with urlopen(kernel) as response:
                        data = response.read().decode("utf-8").split("\n")
                    # The first line of the kernel file is expected to
                    # define "#translation = ..."; inline that value into
                    # line 10 and drop the first two and last three lines.
                    translation_info = data[0].split("#translation = ")[1].strip()
                    data[10] = data[10].replace("#translation", translation_info)
                    kernel_source[kernel] = "\n".join(data[2:-3])
                final_ir.append(kernel_source[kernel])
        final_ir.append(line)

    print("tk kernels added")
    return final_ir
# NOTE: argparse's ``type=bool`` calls ``bool()`` on the raw string, so any
# non-empty value -- including "False" -- parses as True. Parse the text
# explicitly instead. ``nargs="?"`` with ``const=True`` also lets the option
# be given bare (``--add_tk_kernels``), while the existing
# ``--add_tk_kernels True`` form keeps working.
p.add_argument(
    "--add_tk_kernels",
    type=lambda value: str(value).strip().lower()
    in ("1", "true", "t", "yes", "y", "on"),
    nargs="?",
    const=True,
    default=False,
    help="Flag to add compiled tk kernels.",
)
-355,6 +361,7 @@ class CompiledUnet(CompiledModule): return_path=True, attn_spec=attn_spec, flagset_keywords=["punet"] if use_punet else [], + add_tk_kernels=add_tk_kernels, ) if exit_on_vmfb: exit() @@ -393,6 +400,7 @@ class CompiledUnet(CompiledModule): args.decomp_attn, attn_spec=args.attn_spec, input_mlir=args.input_mlir, + add_tk_kernels=args.add_tk_kernels, ) if args.input_mlir: exit()