Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove call to F.pad, improved calculation of memory_count #10620

Open
wants to merge 20 commits into
base: main
Choose a base branch
from

Conversation

bm-synth
Copy link

@bm-synth bm-synth commented Jan 21, 2025

  • remove one call to symmetric padding in F.pad when running with non-replicate pad mode, and instead let padding be done by Conv3d for a more efficient execution;
  • computation of memory_count doesn't extend dimensions to allow torch.compile to do a better optimisation (?) by @ic-synth

cc: @jamesbriggs-synth

@bm-synth bm-synth changed the title Inplace sums, remove call to F.pad and better memory count Inplace sums, remove call to F.pad, improved calculation of memory Jan 21, 2025
@bm-synth bm-synth changed the title Inplace sums, remove call to F.pad, improved calculation of memory Inplace sums, remove call to F.pad, improved calculation of memory_count Jan 21, 2025
@bm-synth bm-synth marked this pull request as ready for review January 21, 2025 12:01
@bm-synth bm-synth changed the title Inplace sums, remove call to F.pad, improved calculation of memory_count in-place sums, remove call to F.pad, improved calculation of memory_count Jan 21, 2025
@hlky
Copy link
Member

hlky commented Jan 22, 2025

Hi @bm-synth. Thanks for your contribution. Can you share some figures on the memory and performance improvements?

@brunomaga
Copy link

brunomaga commented Jan 24, 2025

Hi @hlky.

Running the following test_autoencoder.py

import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXCausalConv3d

torch.manual_seed(42)

def train(model: nn.Module, video_input: torch.Tensor):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    model.train()
    start_train = time.time()
    for iteration in range(100):  # Simulate 100 training iterations
        optimizer.zero_grad()
        output = model(video_input)[0]
        loss = F.mse_loss(output, output+iteration) # sum iteration to fake different grads per iteration
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
    train_time = time.time() - start_train
    print("train_time", train_time, "secs")
    return output.to("cpu")


def eval(model: nn.Module, video_input: torch.Tensor):
    model.eval()
    start_train = time.time()
    with torch.no_grad():
        for _ in range(300):  # Simulate 300 inference iterations
            model(video_input)
            torch.cuda.synchronize()
    eval_time = time.time() - start_train
    print("eval_time", eval_time, "secs")

calling with that input shape [1, 128, 8, 544, 960], on the main branch, gives:

$ PYTHONPATH=./diffusers_main/src/ python test_autoencoder.py
input size:  0.498046875 GBs
eval_time 33.06385564804077 secs
train_time 34.33984375 secs
Max memory 22.18018913269043 GBs

calling this PR branch gives:

$ PYTHONPATH=./diffusers_PR/src/ python test_autoencoder.py
input size:  0.498046875 GBs
eval_time 31.588099241256714 secs
train_time 34.1251916885376 secs
Max memory 22.17398452758789 GBs

on the shape (1, 3, 300, 544, 960), main branch:

$ PYTHONPATH=./diffusers_main/src/ python test_autoencoder.py
input size:  0.43773651123046875 GBs
eval_time 17.759469032287598 secs
train_time 96.50320744514465 secs
Max memory 16.353439331054688 GBs

and this PR:

$ PYTHONPATH=./diffusers_PR/src/ python test_autoencoder.py
input size:  0.43773651123046875 GBs
eval_time 16.8880774974823 secs
train_time 96.04004764556885 secs
Max memory 16.34803009033203 GBs

I'll try to test more dimensions.

@bm-synth bm-synth changed the title in-place sums, remove call to F.pad, improved calculation of memory_count remove call to F.pad, improved calculation of memory_count Jan 25, 2025
@hlky
Copy link
Member

hlky commented Jan 27, 2025

@bm-synth Great, thanks. Would it also be possible to verify numerical accuracy between the two versions? For a change like this we would expect between 0 to 1e-6 difference.

@brunomaga
Copy link

brunomaga commented Jan 27, 2025

@hlky I updated the code above to fix a seed (torch.manual_seed(42)) and save the tensor with the model output after 100 training iterations. Then I ran this to compare both output_*.pt files:

if __name__=='__main__':
    output_main: torch.Tensor = torch.load("output_main.pt")
    output_PR: torch.Tensor = torch.load("output_PR.pt")
    print("mean:", output_main.mean().item(), "vs", output_PR.mean().item())
    print("std:", output_main.std().item(), "vs", output_PR.std().item())
    print("max abs diff:", (output_PR-output_main).diff().abs().max().item())
    assert torch.allclose(output_main, output_PR)

output:

mean: -8.058547973632812e-05 vs -8.058547973632812e-05
std: 0.578125 vs 0.578125
max abs diff: 0.0

@bm-synth
Copy link
Author

@hlky ping?

@hlky
Copy link
Member

hlky commented Jan 31, 2025

Hi @bm-synth. We need to verify the accuracy of CogVideoXCausalConv3d, not the output from your trained model.

Code

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXCausalConv3d


class CogVideoXSafeConv3d_PR(nn.Conv3d):
    r"""
    A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        memory_count = torch.prod(torch.tensor(input.shape)) * 2 / 1024**3

        # Set to 2GB, suitable for CuDNN
        if memory_count > 2:
            kernel_size = self.kernel_size[0]
            part_num = int(memory_count / 2) + 1
            input_chunks = torch.chunk(input, part_num, dim=2)

            if kernel_size > 1:
                input_chunks = [input_chunks[0]] + [
                    torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
                    for i in range(1, len(input_chunks))
                ]

            output_chunks = []
            for input_chunk in input_chunks:
                output_chunks.append(super().forward(input_chunk))
            output = torch.cat(output_chunks, dim=2)
            return output
        else:
            return super().forward(input)


class CogVideoXCausalConv3d_PR(nn.Module):
    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.

    Args:
        in_channels (`int`): Number of channels in the input tensor.
        out_channels (`int`): Number of output channels produced by the convolution.
        kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
        stride (`int`, defaults to `1`): Stride of the convolution.
        dilation (`int`, defaults to `1`): Dilation rate of the convolution.
        pad_mode (`str`, defaults to `"constant"`): Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: int = 1,
        dilation: int = 1,
        pad_mode: str = "constant",
    ):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        # TODO(aryan): configure calculation based on stride and dilation in the future.
        # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
        time_pad = time_kernel_size - 1
        height_pad = (height_kernel_size - 1) // 2
        width_pad = (width_kernel_size - 1) // 2

        self.pad_mode = pad_mode
        self.height_pad = height_pad
        self.width_pad = width_pad
        self.time_pad = time_pad
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
        self.const_padding_conv3d = (0, self.width_pad, self.height_pad)

        self.temporal_dim = 2
        self.time_kernel_size = time_kernel_size

        stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = CogVideoXSafeConv3d_PR(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d,
            padding_mode="zeros",
        )

    def fake_context_parallel_forward(
        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if self.pad_mode == "replicate":
            inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
        else:
            kernel_size = self.time_kernel_size
            if kernel_size > 1:
                cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
                inputs = torch.cat(cached_inputs + [inputs], dim=2)
        return inputs

    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
        inputs = self.fake_context_parallel_forward(inputs, conv_cache)

        if self.pad_mode == "replicate":
            conv_cache = None
        else:
            conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()

        output = self.conv(inputs)
        return output, conv_cache

model = CogVideoXCausalConv3d(in_channels=128, out_channels=512, kernel_size=3).eval()
with torch.no_grad():
    output = model(torch.randn([1, 128, 8, 544, 960], generator=torch.Generator().manual_seed(0)))[0]

with torch.no_grad():
    output_2 = model(torch.randn([1, 128, 8, 544, 960], generator=torch.Generator().manual_seed(0)))[0]

torch.testing.assert_close(output, output_2)

print((output - output_2).abs().max())

model_pr = CogVideoXCausalConv3d_PR(in_channels=128, out_channels=512, kernel_size=3).eval()
with torch.no_grad():
    output_pr = model_pr(torch.randn([1, 128, 8, 544, 960], generator=torch.Generator().manual_seed(0)))[0]

torch.testing.assert_close(output, output_pr)
Mismatched elements: 2139073042 / 2139095040 (100.0%)
Greatest absolute difference: 5.3313703536987305 at index (0, 421, 5, 286, 946) (up to 1e-05 allowed)
Greatest relative difference: 2893981952.0 at index (0, 348, 3, 142, 869) (up to 1.3e-06 allowed)
print((output - output_pr).abs().max())
tensor(5.3314)

The first check here torch.testing.assert_close(output, output_2) shows that CogVideoXCausalConv3d is deterministic so we would only accept up to around 1e-6 difference but preferably less or no change.

Also note the code will run on CPU in float32, this is to avoid other source of non-determinism when testing although it will use a large amount of memory. Generally we choose the smallest possible shape and model configuration for tests.

Copy link
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Feb 25, 2025
@bm-synth
Copy link
Author

bm-synth commented Mar 1, 2025

@hlky your code has an issue. the random state before initializing each CogVideoXCausalConv3D is different. Try to call torch.manual_seed(42) before each initialization, i.e.:

if __name__=='__main__':

    print(f"diffusers version: {diffusers.__version__}")

    torch.manual_seed(42)  #### <---- added this
    model = CogVideoXCausalConv3d(in_channels=128, out_channels=512, kernel_size=3).eval()
    with torch.no_grad():
        output = model(torch.randn([1, 128, 8, 544, 960], generator=torch.Generator().manual_seed(0)))[0]

    with torch.no_grad():
        output_2 = model(torch.randn([1, 128, 8, 544, 960], generator=torch.Generator().manual_seed(0)))[0]

    torch.testing.assert_close(output, output_2)
    print("max abs difference (output, output_2):", (output - output_2).abs().max().item())
    print("number of different elements (output, output_2):", (output != output_2).sum().item())

    torch.manual_seed(42) ##### <---- added this
    model_pr = CogVideoXCausalConv3d_PR(in_channels=128, out_channels=512, kernel_size=3).eval()
    with torch.no_grad():
        output_pr = model_pr(torch.randn([1, 128, 8, 544, 960], generator=torch.Generator().manual_seed(0)))[0]
        
    torch.testing.assert_close(output, output_pr)
    print("max abs difference (output, output_pr):", (output - output_pr).abs().max().item())
    print("number of different elements (output, output_pr):", (output != output_pr).sum().item())

output:

diffusers version: 0.32.2
abs difference (output, output_2): 0.0
number of different elements (output, output_2): 0
abs difference (output, output_pr): 0.0
number of different elements (output, output_pr): 0

@github-actions github-actions bot removed the stale Issues that haven't received updates label Mar 2, 2025
Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @bm-synth
the refactor looks really nice.

let's merge once @hlky confirm the results are identical

@hlky
Copy link
Member

hlky commented Mar 4, 2025

I'm still concerned by the reproducibility, especially as it's with float32 on CPU. Using torch.manual_seed isn't something we normally require users to do yet with this change they would need to use torch.manual_seed to get the same result from VAE decode in addition to passing a torch.Generator to the pipeline to get the same generated latent, and the issue would be compounded by the expected sources of non-determinism i.e. CUDA.

@bm-synth
Copy link
Author

bm-synth commented Mar 4, 2025

@hlky (cc @yiyixuxu ) your reproducibility concerns are not related to this PR. They are an issue in the production code. This PR gives the same behaviour as production. If you run the code below, where you only call your current CogVideoX module twice, without setting the random seed, it already fails. See below:

if __name__=='__main__':

    print(f"diffusers version: {diffusers.__version__}")

    torch.manual_seed(42)
    model = CogVideoXCausalConv3d(in_channels=128, out_channels=512, kernel_size=3).eval()
    with torch.no_grad():
        output = model(torch.randn([1, 128, 8, 544, 960], generator=torch.Generator().manual_seed(0)))[0]

    with torch.no_grad():
        output_2 = model(torch.randn([1, 128, 8, 544, 960], generator=torch.Generator().manual_seed(0)))[0]

    torch.testing.assert_close(output, output_2)
    print("max abs difference (output, output_2):", (output - output_2).abs().max().item())
    print("number of different elements (output, output_2):", (output != output_2).sum().item())

    # torch.manual_seed(42)  ##### <--- THE SECOND CHECK ONLY PASSES IF YOU UNCOMMENT THIS!  
    model_main = CogVideoXCausalConv3d_PR(in_channels=128, out_channels=512, kernel_size=3).eval()
    model_main = CogVideoXCausalConv3d_PR(in_channels=128, out_channels=512, kernel_size=3).eval()
    with torch.no_grad():
        output_main = model_pr(torch.randn([1, 128, 8, 544, 960], generator=torch.Generator().manual_seed(0)))[0]
        
    torch.testing.assert_close(output, output_main)
    print("max abs difference (output, output_pr):", (output - output_main).abs().max().item())
    print("number of different elements (output, output_pr):", (output != output_main).sum().item())

@hlky
Copy link
Member

hlky commented Mar 4, 2025

We get the same output_2 as output with the same input (determined by torch.Generator().manual_seed(0)). We should expect the same output from PR version with the same input.

@bm-synth
Copy link
Author

bm-synth commented Mar 4, 2025

I'm still concerned by the reproducibility, especially as it's with float32 on CPU. Using torch.manual_seed isn't something we normally require users to do yet with this change they would need to use torch.manual_seed to get the same result from VAE decode in addition to passing a torch.Generator to the pipeline to get the same generated latent, and the issue would be compounded by the expected sources of non-determinism i.e. CUDA.

@hlky what i meant is:

  • both this PR and the main give the same output==output_2 in this check here
  • This PR's CogVideoX output matches the output of the CogVideoX in the current main ( see output==output_pr here), but you need to set the same random state when you initialise the module (ie set the seed) otherwise you have two modules with same architecture but different parameter state.
  • This is exactly the same behaviour as the main branch: it gives the same output == output_main here but only if you set the seed.

What am I missing here?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants