
[Feature Request] Optimize the VAE module #667

Closed
strint opened this issue Feb 23, 2024 · 3 comments

@strint
Collaborator

strint commented Feb 23, 2024

No description provided.

@strint strint added the Request-bug Something isn't working label Feb 23, 2024
@strint strint self-assigned this Feb 23, 2024
@strint strint added Request-new_feature Request for a new feature Rsp-triaged and removed Request-bug Something isn't working labels Feb 23, 2024
@strint strint added this to the v0.12.1 milestone Feb 23, 2024
@aifartist

Benchmark of no compiler, onediff, and the stable-fast compiler.
1000 executions of the VAE in a tight loop, using the latent output from the UNet saved to a file. The latent is from 1-step sd-turbo at 512x512 with batch size 12; the VAE is TinyVAE. Numbers are the total time in seconds for the 1000 executions, after 5 warmup VAE decode calls.

25.40 no compiler
19.36 onediff
11.94 stable-fast
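
A rough sketch of the kind of timing harness described above (the saved-latent path, the latent shape, and the taesd model id here are assumptions, not the exact script used):

import time
import torch
from diffusers import AutoencoderTiny

# Latent saved from a 1-step sd-turbo UNet run: batch size 12 at 512x512
# corresponds to a [12, 4, 64, 64] latent tensor (path and shape assumed).
latents = torch.load("latents.pt").to("cuda", torch.float16)
vae = AutoencoderTiny.from_pretrained(
    "madebyollin/taesd", torch_dtype=torch.float16
).to("cuda")

# The VAE could be compiled here with onediff or stable-fast before timing.

with torch.no_grad():
    for _ in range(5):            # 5 warmup decode calls
        vae.decode(latents)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(1000):         # 1000 executions in a tight loop
        vae.decode(latents)
    torch.cuda.synchronize()
    print(f"total decode time: {time.time() - start:.2f} s")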

Because of this, to hit my goal of 200 images per second with 1-step sd-turbo, I compile the UNet with onediff and the VAE with stable-fast. Doing that, plus my own perf optimizations, I can average 5 milliseconds per image at batch size 12 on my 4090.
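
For the mixed setup, compiling the two submodules with different backends could look roughly like this (a sketch only; it assumes onediff's oneflow_compile for per-module compilation and a compile_vae helper in sfast.compilers.diffusion_pipeline_compiler, so check the APIs of your installed versions):

import torch
from diffusers import AutoPipelineForText2Image, AutoencoderTiny
from onediff.infer_compiler import oneflow_compile
from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig, compile_vae

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sd-turbo", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe.to("cuda")

pipe.unet = oneflow_compile(pipe.unet)                          # UNet via onediff
pipe.vae = compile_vae(pipe.vae, CompilationConfig.Default())   # VAE via stable-fast

# 1-step sd-turbo generation, batch size 12, 512x512
images = pipe(
    prompt=["a photo of a cat"] * 12,
    num_inference_steps=1,
    guidance_scale=0.0,
    height=512,
    width=512,
).images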

@aifartist

While my 200 images/sec is purely a tech demo, TinyVAE really is needed for real-time video generation. Using 4-step LCM and TinyVAE I can generate single-frame images, no batching, at 512x512 in about 37 ms. That gets me to 27 fps, which is above the 24 fps minimum for relatively smooth video.

NOTE: I've been busy with this other exploration and have yet to try your video optimizations on my camera -> LCM sd1.5 -> video demo. It'll be interesting to see whether your UNet optimizations get me to 30 fps.

@strint strint modified the milestones: v0.12.1, v0.13.0 Mar 1, 2024
@strint strint added Response-need_weeks This issue need some weeks to be solved and removed Response-triaged labels Mar 9, 2024
@clackhan
Contributor

clackhan commented Mar 15, 2024

@aifartist

Tiny VAE optimization has been completed.

Dependency:

With the following test code on an A100-PCIE-40GB, and with the --fuse-conv-bias-add-act flag set, the execution time of Tiny VAE is reduced by approximately 40%.

The test example:

import os
import argparse
from onediffx import compile_pipe, compiler_config
from diffusers import LatentConsistencyModelImg2ImgPipeline, AutoencoderTiny
import torch
from PIL import Image
import time

def parse_args():
    parser = argparse.ArgumentParser(description="Simple demo of image generation.")
    parser.add_argument(
        "--model-id",
        type=str,
        default="SimianLuo/LCM_Dreamshaper_v7",
    )
    parser.add_argument("--bs", type=int, default=3)
    parser.add_argument("--warmup", type=int, default=1)
    parser.add_argument("--inference-num", type=int, default=3)
    parser.add_argument(
        "--vae-id", type=str, default="madebyollin/taesd", choices=["none", "madebyollin/taesd", "madebyollin/taesdxl"]
    )
    parser.add_argument("--fuse-conv-bias-add-act", action="store_true")
    parser.add_argument("--nsys", action="store_true")
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--sfast", action="store_true")
    cmd_args = parser.parse_args()
    return cmd_args

args = parse_args()

def sfast_compile_model(model):
    from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig
    config = CompilationConfig.Default()
    # xformers and Triton are suggested for achieving best performance.
    # It might be slow for Triton to generate, compile and fine-tune kernels.
    try:
        import xformers
        config.enable_xformers = True
    except ImportError:
        print("xformers not installed, skip")
    # NOTE:
    # When GPU VRAM is insufficient or the architecture is too old, Triton might be slow.
    # Disable Triton if you encounter this problem.
    try:
        import triton
        config.enable_triton = True
    except ImportError:
        print("Triton not installed, skip")
    # CUDA Graphs reduce launch overhead; turn them off when profiling with nsys.
    config.enable_cuda_graph = True
    if args.nsys:
        config.enable_cuda_graph = False
    model = compile(model, config)
    return model


# Enable OneFlow's fused conv + bias-add + activation kernel for the onediff path.
if args.fuse_conv_bias_add_act:
    os.environ["ONEFLOW_CONVOLUTION_BIAS_ADD_ACT_FUSION"] = "1"

WIDTH = 512
HEIGHT = 512


# Load the LCM img2img pipeline and optionally swap in Tiny VAE (taesd).
pipe = LatentConsistencyModelImg2ImgPipeline.from_pretrained(
    args.model_id,
    safety_checker=None,
)
if args.vae_id != "none":
    pipe.vae = AutoencoderTiny.from_pretrained(args.vae_id, torch_dtype=torch.float16)

pipe = pipe.to("cuda", torch.float16)

if args.sfast:
    print("With stable fast Compile")
    pipe = sfast_compile_model(pipe)
else:
    print("With OneDiff Compile")
    pipe = compile_pipe(pipe)


print("batch size = ", args.bs)
input_image = [Image.new("RGB", (WIDTH, HEIGHT))] * args.bs
prompt_embeds = [torch.randn((1, 77, 768))] * args.bs
prompt_embeds = torch.cat(prompt_embeds, dim=0)
# Warmup runs (compilation and kernel tuning happen on the first calls).
for _ in range(args.warmup):
    _ = pipe(
        prompt_embeds=prompt_embeds,
        generator=None,
        image=input_image,
        strength=0.5,
        num_inference_steps=8,
        guidance_scale=1.,
        width=WIDTH,
        height=HEIGHT,
        original_inference_steps=1000,
        output_type="np",
    ).images[0]
# Timed runs: report wall-clock time per pipeline call.
for _ in range(args.inference_num):
    inf_start = time.time()
    _ = pipe(
        prompt_embeds=prompt_embeds,
        generator=None,
        image=input_image,
        strength=0.5,
        num_inference_steps=8,
        guidance_scale=1.,
        width=WIDTH,
        height=HEIGHT,
        original_inference_steps=1000,
        output_type="np",
    ).images[0]
    t_inf = time.time() - inf_start
    print(f"t_inf = {t_inf}")
