[runtime] Spike in memory usage when running VAE ("segmind/SSD-1B", "stabilityai/stable-diffusion-2") #5924

Closed
jon-chuang opened this issue Nov 24, 2023 · 11 comments
Labels: bug (Something isn't working), stale (Issues that haven't received updates)

Comments

jon-chuang commented Nov 24, 2023

Describe the bug

When running inference, the VAE decoders for SD2 and SSD-1B are only about 1/10th the size of the UNet, yet decoding causes a sharp GPU memory spike (an increase of 2 GB+), which leads to an OOM on my device.

I would be happy to help investigate if given some pointers. I suspect there is a point where some unused memory could be released early.

Reproduction

Watch the GPU memory usage when the VAE decode runs for:
"stabilityai/stable-diffusion-2"
"segmind/SSD-1B"

Logs

File "/home/jonch/Desktop/sdpa.py", line 2228, in do_POST
    output = get_model_instance()(post_data)
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/jonch/.local/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 1096, in __call__
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
  File "/home/jonch/.local/lib/python3.10/site-packages/diffusers/models/autoencoder_kl.py", line 318, in decode
    decoded = self._decode(z).sample
  File "/home/jonch/.local/lib/python3.10/site-packages/diffusers/models/autoencoder_kl.py", line 289, in _decode
    dec = self.decoder(z)
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jonch/.local/lib/python3.10/site-packages/diffusers/models/vae.py", line 321, in forward
    sample = up_block(sample, latent_embeds)
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jonch/.local/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py", line 2535, in forward
    hidden_states = upsampler(hidden_states)
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jonch/.local/lib/python3.10/site-packages/diffusers/models/resnet.py", line 199, in forward
    hidden_states = self.conv(hidden_states, scale)
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jonch/.local/lib/python3.10/site-packages/diffusers/models/lora.py", line 228, in forward
    return F.conv2d(

System Info

GPU: laptop RTX 4080
torch: nightly
diffusers: 0.23.1

Who can help?

cc: @sayakpaul @DN6 @yiyixuxu @patrickvonplaten for VAE expertise

jon-chuang added the bug label Nov 24, 2023
@jon-chuang (Author)

Maybe related: #5594 #5761


jon-chuang commented Nov 25, 2023

Here is my memory history, captured as described in https://pytorch.org/docs/stable/torch_cuda_memory.html.

It seems like an issue with how cuDNN allocates memory for convolutions.

[Screenshot: CUDA memory history trace, 2023-11-25]
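
For reference, a trace like the one above can be captured with the (private) hooks documented on that page; a minimal sketch, with an arbitrary output filename:

    import torch

    # Start recording every CUDA allocation/free together with Python stack traces.
    torch.cuda.memory._record_memory_history()

    # ... run the pipeline / VAE decode here ...

    # Dump a snapshot that can be loaded into https://pytorch.org/memory_viz
    torch.cuda.memory._dump_snapshot("vae_decode_snapshot.pickle")
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording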

@jon-chuang (Author)

The difference seems to lie in the implementation of diffusers' unet_2d_blocks (higher memory usage) vs. unet_2d_condition (lower memory usage).

Investigation underway.


jon-chuang commented Nov 25, 2023

This seems to be due to upsampling, which makes the conv activations extremely large (300-600 MB).

EDIT: Found the culprit: activations of size 256 * 768 * 768 * 2 bytes ≈ 300 MB (256 channels for an image of size 768 * 768 at fp16 precision).
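
Back-of-the-envelope check of that figure, using the numbers from the comment above:

    # 256 channels, 768x768 spatial resolution, 2 bytes per element at fp16.
    channels, height, width, bytes_per_elem = 256, 768, 768, 2
    activation_bytes = channels * height * width * bytes_per_elem
    print(activation_bytes / 1024**2)  # ~288 MiB (~302 MB), i.e. the "300MB" above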


jon-chuang commented Nov 25, 2023

The sources of the largest memory spikes are create_out, cudnn_convolution_forward, and run_conv_plan (the latter allocates some function of the input and output channels, possibly a scratch buffer, but it is basically in_bytes + out_bytes).

Not sure why there are two allocations; honestly, the create_out allocation seems unnecessary, as only the second one is actually used 🤔

Also, it's odd to me that the upsample block has 256 channels while the next resblock run has only 128 channels...

Maybe relevant references:

  1. https://arxiv.org/pdf/1610.03618.pdf
  2. https://zdevito.github.io/2022/08/04/cuda-caching-allocator.html

@sayakpaul (Member)

Thanks for the super cool investigation, @jon-chuang. However, I don't have any concrete suggestions for reducing the memory usage here, as the blocks that seem to be causing the spikes are known to be memory-intensive.

Does it help to use FP16 precision if not being used already?


jon-chuang commented Nov 27, 2023

Yep, I was just wondering if there is a way to reduce the peak memory allocated, e.g. by configuring cuDNN or the generated Triton code to prefer using less memory.

I'm actually already running fp16.

See also pytorch/pytorch#31500, pytorch/pytorch#49207
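
Not something discussed further in the thread, but as an aside for readers hitting the same OOM: diffusers pipelines already expose VAE slicing and tiling, which lower the decode peak by splitting the work; a sketch, assuming an SDXL-style pipeline such as SSD-1B:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "segmind/SSD-1B", torch_dtype=torch.float16
    ).to("cuda")

    # Decode the latents in spatial tiles instead of one full-resolution pass;
    # trades a little speed (and potential seams) for a much lower peak.
    pipe.enable_vae_tiling()

    # For batched generation, decode one image at a time.
    pipe.enable_vae_slicing()

    image = pipe("a photo of an astronaut riding a horse").images[0]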

@sayakpaul (Member)

Thanks for sharing your findings. I don't really see a concrete action item for us here. But I will keep the issue open in case someone else stumbles upon it.

@jon-chuang (Author)

Sure. It was unintuitive that the activations could use so much memory, as I was used to thinking that model weights dominate the memory cost, but I guess for vision models this is expected.

github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label Dec 26, 2023
github-actions bot closed this as completed Jan 3, 2024
@feifeibear

Hi, we solved this problem with sequence parallelism across multiple devices and chunked input.

https://github.com/xdit-project/DistVAE
