no speed up #19

Open
liujianzhi opened this issue Apr 10, 2023 · 13 comments

@liujianzhi

I have the same issue as #6. I tested ToMe without xformers on a 3090, and the inference speed is the same as without ToMe. I applied ToMe to all three components (attention, cross-attention, and MLP). Using 512x512 images, the result is 7036s with ToMe and 6827s without ToMe for generating 2000 images. Why is that?

@dbolya
Owner

dbolya commented Apr 10, 2023

If you're not using xformers, ToMe should increase inference speed even with 512x512 images, so something is going wrong here.

A couple of steps to help debug:

  1. Are you sure ToMe is being called? You can test this by adding a print statement in the bipartite matching function on this line (see the sketch at the end of this comment for a way to do that without editing the library source).
  2. Are you using the system for other things while you run your benchmark? If so, then that might skew the results.
  3. You might need to use a bigger batch size to see the full effect. I used a batch size of 3 for my benchmarks.

Also, what stable diffusion code base are you using? Is it one of the supported ones?
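
For step 1, here is a minimal sketch of one way to confirm the matching function is being called without editing the library source. It assumes the relevant function is bipartite_soft_matching_random2d in tomesd.merge; depending on how your tomesd version binds that function internally, a monkey-patch like this may not take effect, in which case adding the print directly to the library source (as suggested above) is the reliable fallback.

import tomesd.merge as tome_merge

# The function name below is an assumption -- check tomesd/merge.py in your install.
_original_matching = tome_merge.bipartite_soft_matching_random2d

def verbose_matching(*args, **kwargs):
    print("ToMe bipartite matching called")  # expect many prints per generated image
    return _original_matching(*args, **kwargs)

tome_merge.bipartite_soft_matching_random2d = verbose_matching
# Then apply the ToMe patch, generate one image, and watch stdout.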

@liujianzhi
Author

Thank you for your comment!

  1. I added a print statement on that line, and I can see it printing while running.
  2. I am running StableDiffusionPipeline from diffusers==0.11.1. I am running the ablation study (w/o ToMe, r=50% ToMe, and r=90% ToMe) at the same time on different 3090s.
  3. I think the stable diffusion pipeline only accepts a batch size of 1 as input.
    My code is here:
import time

import torch
import tomesd
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Apply ToMe to self-attention, cross-attention, and MLP blocks
tomesd.apply_patch(pipe, merge_attn=True, merge_crossattn=True, merge_mlp=True, ratio=0.9)

a = time.time()

for i in range(2000):
    image = pipe("cat").images[0]

print(time.time() - a)

@ethansmith2000

ethansmith2000 commented Apr 13, 2023

@liujianzhi @dbolya I am having a similar issue running on an A100.

My baseline time is 25 it/s at float16.

At a 50% ratio, my speed only gets up to about 25.5 it/s on average.
At 75%, it reaches 26.2 it/s, but the image quality is pretty demolished.

Anything below that, the speed actually gets slower; at a 10% ratio the speed caps out at 22 it/s.

Additionally, increasing max_downsample slows it down further.

I imagine the merging/unmerging process is causing much of the latency?

--
However, trying higher batch sizes and resolutions, I do start to see some benefits.

512x512, batch size 8:
baseline: 5.34 it/s
ratio 50%, downsample=1: 7.21 it/s
ratio 25%, downsample=1: 5.92 it/s
ratio 50%, downsample=8: 7.18 it/s
ratio 25%, downsample=8: 5.80 it/s

1024x1024, batch size 1 (these were more exciting :)):
baseline: 4.39 it/s
ratio 50%, downsample=1: 8.10 it/s
ratio 25%, downsample=1: 5.94 it/s
ratio 50%, downsample=2: 8.40 it/s
ratio 50%, downsample=4: 8.35 it/s
ratio 50%, downsample=4: 8.31 it/s

The benefit going from downsample 1 to 2, but then the decrease at higher values, seems to support the idea that the merging process itself adds latency, since I imagine token merging isn't as useful when the hidden states are already quite small.

1024x1024, batch size 4:
baseline: OUT OF MEMORY
ratio 50%, downsample=1: 2.35 it/s
ratio 50%, downsample=2: 2.49 it/s
ratio 75%, downsample=2: 3.58 it/s

1024x1024, batch size 8:
baseline: OUT OF MEMORY
ratio 50%, downsample=1: OUT OF MEMORY
ratio 50%, downsample=2: OUT OF MEMORY
ratio 75%, downsample=2: OUT OF MEMORY

@dbolya
Owner

dbolya commented Apr 14, 2023

@liujianzhi
Hmm, if you're not using xformers, you should see a speed-up in your case. I wonder if tomesd is broken for diffusers 0.11.1? I've been using diffusers==0.14.0. If possible, could you try upgrading?

@ethansmith2000
This is expected behavior when using xformers (which I now realize you might not have been using given your edit, but the rest still applies with bs=1 on a beefy card like an A100). That's because computing the merge for ToMe comes at a cost, which normally is small compared to the rest of the network. However, with xformers speeding up the network, this cost is very noticeable with small images. Hence my recommendation to only use ToMe + xformers for large images.

Similarly, that's why the default downsample is at 1, where there are the most tokens. Even without xformers, there's not much benefit of applying ToMe deeper into the network where there are fewer tokens (see the ablations in the paper).
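
To make the token-count intuition concrete, here is a rough back-of-the-envelope sketch. The per-level resolutions are an assumption based on the SD 1.x UNet at 512x512 (64x64 latent); tomesd's max_downsample controls how deep into the UNet the merging is applied.

# Rough token counts per UNet resolution level for a 512x512 image (64x64 latent).
# Most of the tokens (and hence most of ToMe's savings) live at downsample=1,
# which is why pushing max_downsample deeper gives diminishing returns.
latent_size = 64  # 512 / 8 for the SD 1.x VAE
for downsample in (1, 2, 4, 8):
    tokens = (latent_size // downsample) ** 2
    merged = tokens // 2  # what a ratio of 0.5 would remove at this level
    print(f"downsample={downsample}: {tokens:5d} tokens, {merged:5d} merged at ratio 0.5")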

@ethansmith2000

Thank you Daniel, makes sense!

@liujianzhi
Author

Sorry. I upgraded to diffusers==0.14.0, and I still get 5948.799s (w/o ToMe) vs 5623.201s (r=90% ToMe) when generating 2000 images.

@dbolya
Owner

dbolya commented Apr 14, 2023

Sorry. I upgraded to diffusers==0.14.0, and I still get 5948.799s (w/o ToMe) vs 5623.201s (r=90% ToMe) when generating 2000 images.

I just tried to reproduce using your code above:

  1. Without ToMe I get ~26 it/s (~4000s total for 2k images).
  2. With ToMe and ratio=0.5 (no other arguments) I get ~32 it/s (~3100s total).
  3. With ToMe and the options you used (merge all, ratio=0.75) I get ~35 it/s (~2800s total). Note that with your settings the max ratio is 0.75, not 0.9, so it gets rounded down (with the default 2x2 stride, one token in every four is kept as a merge destination, so at most 75% of the tokens can be merged).

Now these results aren't stellar, probably because of the batch size of 1, but they at least improve. Can you try running with a higher image / batch size? The way to do that is to pass arguments when you call pipe. For instance:

for i in range(2000//5):
    images = pipe("cat", num_images_per_prompt=5).images

For a batch size of 5. Similarly, you can use height=1024, width=1024 for bigger images.

Running with a batch size of 5, I get ~2000s for ToMe with setting 3 from above and ~3300s without (i.e., setting 1 from above).
That's closer to the speed-up I'd expect, but still off. Maybe huggingface diffusers starts with more optimizations applied than the original stable diffusion repo, or maybe there's more overhead elsewhere in the pipeline. But even still, you shouldn't be seeing slowdowns from using ToMe on a 3090 (assuming you're not using xformers).
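
For reference, combining both suggestions in one call looks like this (the argument names are the standard StableDiffusionPipeline call arguments):

for i in range(2000//5):
    # 5 images per prompt at 1024x1024
    images = pipe("cat", num_images_per_prompt=5, height=1024, width=1024).images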

@alex-bene

Hey there, I'm reviving this issue as I was facing a similar problem, and I think I found part of what was going on in my case. I am using a single A10 GPU, and I'm trying to reproduce @dbolya's results.

It turns out that pytorch==2.0.1 has a built-in efficient attention implementation, like xformers, and diffusers==0.16.1 uses it by default. As a result, when trying to reproduce the numbers from the paper, I initially got disappointing results.

However, after resetting to the default (non-optimized) attention using pipe.unet.set_default_attn_processor(), the results are much better, but still far from the ones in the paper! @dbolya do you think this might be due to overhead in the pipeline and/or general improvements in the model that make the ToMe contribution less pronounced than in the original codebase? Am I missing something?

Main Packages Versions

python==3.10.11
diffusers==0.16.1
pytorch==2.0.1

Minimal code example to reproduce

import time
from collections import defaultdict

import tomesd
import torch
from diffusers import StableDiffusionPipeline
from tqdm import trange

def test(batch_size=1, total_number_of_images=20, tome_ratio=0):
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    if tome_ratio > 0:
        tomesd.apply_patch(
            pipe, ratio=tome_ratio
        )  # Can also use pipe.unet in place of pipe here

    pipe.set_progress_bar_config(disable=True)
    # skip efficient torch.nn.functional.scaled_dot_product_attention based attention
    pipe.unet.set_default_attn_processor()

    start = time.time()
    for i in trange(total_number_of_images//batch_size):
        images = pipe(
            ["A photo of a dog riding a bicycle"],
            num_images_per_prompt=batch_size,
            generator=torch.manual_seed(4251142),
        ).images

    return (time.time()-start)/total_number_of_images

def ddict():
    return defaultdict(ddict)

runtimes = ddict()

## Run Experiments
for batch_size in [1, 2, 4, 5, 10]:
    print(f"Batch size: {batch_size}")
    for tome_ratio in [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]:
        print(f"\tToMe ratio: {tome_ratio}")
        runtimes[batch_size][tome_ratio] = test(
            batch_size=batch_size, total_number_of_images=20, tome_ratio=tome_ratio
        )
        print(f"\t\tRuntime: {runtimes[batch_size][tome_ratio]:.3f}")

## Print Results
for batch_size, runtimes_batch in runtimes.items():
    no_tome = runtimes_batch[0]
    print(f"Batch size: {batch_size}.")
    for tome_ratio, runtimes_batch_ratio in runtimes_batch.items():
        time_perc = 100*(no_tome - runtimes_batch_ratio)/no_tome
        speed_perc = 100*(no_tome - runtimes_batch_ratio)/runtimes_batch_ratio
        print(f"  ToMe ratio: {tome_ratio:.1f} -- runtime reduction: {time_perc:5.2f}% -- speed increase: {speed_perc:5.2f}%")

Results

Results with no memory-efficient attention

Batch size: 1.
  ToMe ratio: 0.0 -- runtime reduction:  0.00% -- speed increase:  0.00%
  ToMe ratio: 0.1 -- runtime reduction: -2.75% -- speed increase: -2.67%
  ToMe ratio: 0.2 -- runtime reduction:  4.09% -- speed increase:  4.27%
  ToMe ratio: 0.3 -- runtime reduction: 10.43% -- speed increase: 11.64%
  ToMe ratio: 0.4 -- runtime reduction: 16.16% -- speed increase: 19.27%
  ToMe ratio: 0.5 -- runtime reduction: 21.43% -- speed increase: 27.28%
  ToMe ratio: 0.6 -- runtime reduction: 23.97% -- speed increase: 31.53%
Batch size: 2.
  ToMe ratio: 0.0 -- runtime reduction:  0.00% -- speed increase:  0.00%
  ToMe ratio: 0.1 -- runtime reduction: -4.79% -- speed increase: -4.57%
  ToMe ratio: 0.2 -- runtime reduction:  3.11% -- speed increase:  3.21%
  ToMe ratio: 0.3 -- runtime reduction:  9.73% -- speed increase: 10.78%
  ToMe ratio: 0.4 -- runtime reduction: 16.22% -- speed increase: 19.36%
  ToMe ratio: 0.5 -- runtime reduction: 22.89% -- speed increase: 29.68%
  ToMe ratio: 0.6 -- runtime reduction: 25.23% -- speed increase: 33.74%
Batch size: 4.
  ToMe ratio: 0.0 -- runtime reduction:  0.00% -- speed increase:  0.00%
  ToMe ratio: 0.1 -- runtime reduction: -6.14% -- speed increase: -5.79%
  ToMe ratio: 0.2 -- runtime reduction:  3.58% -- speed increase:  3.71%
  ToMe ratio: 0.3 -- runtime reduction: 10.18% -- speed increase: 11.34%
  ToMe ratio: 0.4 -- runtime reduction: 16.55% -- speed increase: 19.83%
  ToMe ratio: 0.5 -- runtime reduction: 24.01% -- speed increase: 31.60%
  ToMe ratio: 0.6 -- runtime reduction: 26.53% -- speed increase: 36.10%
Batch size: 5.
  ToMe ratio: 0.0 -- runtime reduction:  0.00% -- speed increase:  0.00%
  ToMe ratio: 0.1 -- runtime reduction: -5.18% -- speed increase: -4.92%
  ToMe ratio: 0.2 -- runtime reduction:  3.88% -- speed increase:  4.03%
  ToMe ratio: 0.3 -- runtime reduction: 10.39% -- speed increase: 11.59%
  ToMe ratio: 0.4 -- runtime reduction: 17.37% -- speed increase: 21.03%
  ToMe ratio: 0.5 -- runtime reduction: 24.37% -- speed increase: 32.22%
  ToMe ratio: 0.6 -- runtime reduction: 27.11% -- speed increase: 37.19%
Batch size: 10.
  ToMe ratio: 0.0 -- OOM
  ToMe ratio: 0.1 -- runtime reduction:  0.00% -- speed increase:  0.00%
  ToMe ratio: 0.2 -- runtime reduction:  6.22% -- speed increase:  6.63%
  ToMe ratio: 0.3 -- runtime reduction: 11.22% -- speed increase: 12.64%
  ToMe ratio: 0.4 -- runtime reduction: 17.42% -- speed increase: 21.10%
  ToMe ratio: 0.5 -- runtime reduction: 25.86% -- speed increase: 34.88%
  ToMe ratio: 0.6 -- runtime reduction: 27.91% -- speed increase: 38.71%

Results with memory-efficient attention

For the sake of completeness, here are the results using the efficient torch.nn.functional.scaled_dot_product_attention-based attention, i.e., with the line pipe.unet.set_default_attn_processor() commented out.

Batch size: 1.
  ToMe ratio: 0.0 -- runtime reduction:  0.00% -- speed increase:  0.00%
  ToMe ratio: 0.1 -- runtime reduction: -3.11% -- speed increase: -3.02%
  ToMe ratio: 0.2 -- runtime reduction: -2.54% -- speed increase: -2.48%
  ToMe ratio: 0.3 -- runtime reduction: -0.66% -- speed increase: -0.65%
  ToMe ratio: 0.4 -- runtime reduction:  1.40% -- speed increase:  1.42%
  ToMe ratio: 0.5 -- runtime reduction:  3.15% -- speed increase:  3.26%
  ToMe ratio: 0.6 -- runtime reduction:  4.69% -- speed increase:  4.92%
Batch size: 2.
  ToMe ratio: 0.0 -- runtime reduction:  0.00% -- speed increase:  0.00%
  ToMe ratio: 0.1 -- runtime reduction: -0.25% -- speed increase: -0.25%
  ToMe ratio: 0.2 -- runtime reduction:  2.37% -- speed increase:  2.43%
  ToMe ratio: 0.3 -- runtime reduction:  5.41% -- speed increase:  5.72%
  ToMe ratio: 0.4 -- runtime reduction:  7.43% -- speed increase:  8.02%
  ToMe ratio: 0.5 -- runtime reduction:  9.64% -- speed increase: 10.67%
  ToMe ratio: 0.6 -- runtime reduction: 11.24% -- speed increase: 12.67%
Batch size: 4.
  ToMe ratio: 0.0 -- runtime reduction:  0.00% -- speed increase:  0.00%
  ToMe ratio: 0.1 -- runtime reduction: -0.45% -- speed increase: -0.44%
  ToMe ratio: 0.2 -- runtime reduction:  2.36% -- speed increase:  2.42%
  ToMe ratio: 0.3 -- runtime reduction:  4.53% -- speed increase:  4.74%
  ToMe ratio: 0.4 -- runtime reduction:  7.20% -- speed increase:  7.76%
  ToMe ratio: 0.5 -- runtime reduction:  9.43% -- speed increase: 10.41%
  ToMe ratio: 0.6 -- runtime reduction: 10.90% -- speed increase: 12.24%
Batch size: 5.
  ToMe ratio: 0.0 -- runtime reduction:  0.00% -- speed increase:  0.00%
  ToMe ratio: 0.1 -- runtime reduction: -1.03% -- speed increase: -1.02%
  ToMe ratio: 0.2 -- runtime reduction:  2.72% -- speed increase:  2.79%
  ToMe ratio: 0.3 -- runtime reduction:  4.72% -- speed increase:  4.96%
  ToMe ratio: 0.4 -- runtime reduction:  7.64% -- speed increase:  8.27%
  ToMe ratio: 0.5 -- runtime reduction:  9.98% -- speed increase: 11.09%
  ToMe ratio: 0.6 -- runtime reduction: 10.58% -- speed increase: 11.83%
Batch size: 10.
  ToMe ratio: 0.0 -- runtime reduction:  0.00% -- speed increase:  0.00%
  ToMe ratio: 0.1 -- runtime reduction: -0.73% -- speed increase: -0.72%
  ToMe ratio: 0.2 -- runtime reduction:  2.53% -- speed increase:  2.60%
  ToMe ratio: 0.3 -- runtime reduction:  5.22% -- speed increase:  5.51%
  ToMe ratio: 0.4 -- runtime reduction:  7.68% -- speed increase:  8.32%
  ToMe ratio: 0.5 -- runtime reduction: 10.23% -- speed increase: 11.39%
  ToMe ratio: 0.6 -- runtime reduction: 11.83% -- speed increase: 13.42%

Discussion Points

  • The results still do not agree with the ones in the paper.
  • From the experiment above, it seems that, regardless of the batch size (except for a batch size of 1), the maximum achievable speed increase (up to a 0.6 compression ratio) for 512x512 images does not match the 2x performance gains in the paper.
  • Also, regardless of the batch size, at lower ratios (~0.1) we observe a decrease in speed!

@dbolya Is this the expected performance boost or am I missing something here? Does anyone else have similar results?

@dbolya
Owner

dbolya commented Jun 16, 2023

Hi @alex-bene, thanks for the detailed write-up. The experiments in the paper were performed in the original stable diffusion repo (namely the runway-ml one). I think users have consistently found that the diffusers implementation doesn't give them the same speed-up. Perhaps diffusers does a bunch of extra things like different memory management? Even still, 38% does seem low. I was getting at least 60% or higher with 0.5 reduction using diffusers (using a 4090). Unsure why you would be getting such low results.

As for the torch SDPA: in performance it should be equivalent to "xformers" or "flash attn", which I already have a disclaimer about in the readme / paper. For small images, that means not much extra speed-up when using ToMe on top. But for bigger images, ToMe still leads to a large speed-up. The ToMe + xFormers figure in the paper used a 2048px image, for instance.
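
If it helps with debugging, here is a small sketch for making the attention backend explicit when benchmarking, so it is unambiguous whether ToMe is being measured against the plain attention or the fused SDPA path. The class names are from diffusers ~0.16; check your installed version.

from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

def set_attention_backend(pipe, use_sdpa: bool):
    # AttnProcessor2_0 routes through torch.nn.functional.scaled_dot_product_attention;
    # AttnProcessor is the plain, unfused implementation.
    pipe.unet.set_attn_processor(AttnProcessor2_0() if use_sdpa else AttnProcessor())

Each ToMe ratio can then be benchmarked once per backend, as in the tables above.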

@alex-bene

Hey @dbolya and thanks for the quick response.

  • The SDPA results are there mostly as a reference for other people who get similar results and can't figure out why. I (eventually) figured out that it was being used automatically and was the main cause of the minimal improvement.
  • Regarding higher-resolution images, I can confirm that resizing the 512x512 output of plain SD to 2048x2048 and passing it through an img2img SD pipeline yields 2.9x faster generation (for the img2img) when using ToMe, which is amazing! (This is an extra reason why I'm trying to understand why I can't reproduce a more significant improvement with the simple 512x512 generation.) A rough sketch of that setup follows this list.
  • Regarding the pipeline with SDPA disabled, do you have any idea what might be going on and how I could fix it? Do you think you could try to reproduce the results with the latest versions of diffusers and pytorch that I referenced in the previous comment, just for the sake of verifying the experimental results?
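
For reference, here is a minimal sketch of the upscale-and-refine setup from the second bullet (the model id, ratio, and strength values are illustrative; the img2img argument names are the standard diffusers ones):

import torch
import tomesd
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline

txt2img = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# ToMe pays off most at high resolution, so patch the img2img pipeline.
tomesd.apply_patch(img2img, ratio=0.5)

prompt = "A photo of a dog riding a bicycle"
low_res = txt2img(prompt).images[0]          # 512x512 base image
upscaled = low_res.resize((2048, 2048))      # naive PIL upscale
refined = img2img(prompt, image=upscaled, strength=0.5).images[0]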

@dbolya
Owner

dbolya commented Jun 16, 2023

Regarding the pipeline with SDPA disabled, do you have any idea what might be going on and how I could fix it? Do you think you could try to reproduce the results with the latest versions of diffusers and pytorch that I referenced in the previous comment, just for the sake of verifying the experimental results?

To be honest, I'm not sure. I didn't even write the diffusers implementation of ToMe, and it was added well after release. Do you think you could try using the runway-ml repo that I used for the paper? I am actually travelling to CVPR right now, so I don't have access to the machine I did the original testing on.

@alex-bene

Hey @dbolya, hope CVPR went well! Unfortunately, I haven't found the time to test this yet. If by any chance you have this already set up (the runway-ml environment and/or a "diffusers" environment to cross-check my results), I'd much appreciate the help.

@aiXia121

aiXia121 commented Feb 21, 2024

Without xformers, and with tomesd 0.1.3:

Batch size: 1.
ToMe ratio: 0.0 -- runtime reduction: 0.00% -- speed increase: 0.00%
ToMe ratio: 0.1 -- runtime reduction: -14.88% -- speed increase: -12.95%
ToMe ratio: 0.2 -- runtime reduction: -6.85% -- speed increase: -6.41%
ToMe ratio: 0.3 -- runtime reduction: -0.11% -- speed increase: -0.11%
ToMe ratio: 0.4 -- runtime reduction: 6.07% -- speed increase: 6.46%
ToMe ratio: 0.5 -- runtime reduction: 8.70% -- speed increase: 9.53%
ToMe ratio: 0.6 -- runtime reduction: 6.90% -- speed increase: 7.41%
Batch size: 2.
ToMe ratio: 0.0 -- runtime reduction: 0.00% -- speed increase: 0.00%
ToMe ratio: 0.1 -- runtime reduction: -18.48% -- speed increase: -15.60%
ToMe ratio: 0.2 -- runtime reduction: -8.47% -- speed increase: -7.81%
ToMe ratio: 0.3 -- runtime reduction: 1.00% -- speed increase: 1.01%
ToMe ratio: 0.4 -- runtime reduction: 8.28% -- speed increase: 9.03%
ToMe ratio: 0.5 -- runtime reduction: 18.84% -- speed increase: 23.22%
ToMe ratio: 0.6 -- runtime reduction: 17.59% -- speed increase: 21.34%
Batch size: 4.
ToMe ratio: 0.0 -- runtime reduction: 0.00% -- speed increase: 0.00%
ToMe ratio: 0.1 -- runtime reduction: -18.92% -- speed increase: -15.91%
ToMe ratio: 0.2 -- runtime reduction: -7.53% -- speed increase: -7.00%
ToMe ratio: 0.3 -- runtime reduction: 2.47% -- speed increase: 2.54%
ToMe ratio: 0.4 -- runtime reduction: 10.39% -- speed increase: 11.60%
ToMe ratio: 0.5 -- runtime reduction: 22.23% -- speed increase: 28.58%
ToMe ratio: 0.6 -- runtime reduction: 21.26% -- speed increase: 27.01%
Batch size: 5.
ToMe ratio: 0.0 -- runtime reduction: 0.00% -- speed increase: 0.00%
ToMe ratio: 0.1 -- runtime reduction: -19.30% -- speed increase: -16.18%
ToMe ratio: 0.2 -- runtime reduction: -8.20% -- speed increase: -7.58%
ToMe ratio: 0.3 -- runtime reduction: 2.85% -- speed increase: 2.94%
ToMe ratio: 0.4 -- runtime reduction: 10.43% -- speed increase: 11.65%
ToMe ratio: 0.5 -- runtime reduction: 23.36% -- speed increase: 30.48%
ToMe ratio: 0.6 -- runtime reduction: 22.01% -- speed increase: 28.22%
Batch size: 10.
ToMe ratio: 0.0 -- runtime reduction: 0.00% -- speed increase: 0.00%
ToMe ratio: 0.1 -- runtime reduction: -19.46% -- speed increase: -16.29%
ToMe ratio: 0.2 -- runtime reduction: -7.60% -- speed increase: -7.06%
ToMe ratio: 0.3 -- runtime reduction: 3.38% -- speed increase: 3.49%
ToMe ratio: 0.4 -- runtime reduction: 11.91% -- speed increase: 13.52%
ToMe ratio: 0.5 -- runtime reduction: 24.69% -- speed increase: 32.79%
ToMe ratio: 0.6 -- runtime reduction: 23.72% -- speed increase: 31.09%
