Low performance from unnecessary permutations #936

jonpryai opened this issue Sep 27, 2023 · 18 comments

@jonpryai

I'm using fx2ait to load an ONNX graph. After optimization, the results are poor: FX2AIT is slower than PyTorch eager.

```
BS: 1, PT Eager time per iter: 0.01654841552734375ms, PT Eager QPS: 60.43, FX2AIT time per iter: 0.024108586425781252ms, FX2AIT Eager QPS: 41.48, Speedup: 0.69
```

And that's before comparing against TensorRT. I profiled the optimized graph and found that a permute kernel dominates GPU time:

```
Time (%)  Total Time (ns)  Instances  Avg (ns)   Med (ns)   Min (ns)  Max (ns)   StdDev (ns)  Name
61.9      11,952,884,520   64,800     184,458.1  123,999.0  18,112    1,017,919  216,853.6    void ::PermuteKernel<(unsigned long)4, (unsigned long)2, int>(::PermuteKernelPara…
```

Analyzing the trace in nsys, I see that the graph repeatedly does:

```
permute -> elementwise addition -> permute
```

These permutations accomplish nothing, because an elementwise operator produces the same result regardless of memory layout, as long as all of its inputs share that layout.
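A quick way to convince yourself: the pure-Python sketch below applies NHWC→NCHW, a ReLU, and NCHW→NHWC to a small tensor and gets exactly the same buffer as applying the ReLU directly. The `permute` helper is hypothetical (a flat row-major buffer following the `torch.Tensor.permute` convention), not AIT code.

```python
import itertools


def flat_offset(idx, shape):
    # Row-major offset of a multi-index
    off = 0
    for i, s in zip(idx, shape):
        off = off * s + i
    return off


def permute(data, shape, perm):
    """Permute a row-major flat tensor; output dim i is input dim perm[i]
    (the same convention as torch.Tensor.permute)."""
    new_shape = [shape[p] for p in perm]
    out = [None] * len(data)
    for idx in itertools.product(*(range(s) for s in shape)):
        new_idx = [idx[p] for p in perm]
        out[flat_offset(new_idx, new_shape)] = data[flat_offset(idx, shape)]
    return out, new_shape


def relu(data):
    # Elementwise: each output depends only on the element at the same position
    return [max(v, 0) for v in data]


NHWC_TO_NCHW = (0, 3, 1, 2)
NCHW_TO_NHWC = (0, 2, 3, 1)

shape = [1, 3, 4, 2]  # N, H, W, C
data = [(-1) ** i * i for i in range(24)]

# permute -> relu -> inverse permute ...
nchw, nchw_shape = permute(data, shape, NHWC_TO_NCHW)
back, back_shape = permute(relu(nchw), nchw_shape, NCHW_TO_NHWC)
# ... matches relu applied directly, so both permutes are pure overhead
print(back == relu(data))
```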

How to fix?

@jonpryai jonpryai changed the title Low performance from unnessary permutations Low performance from unnecessary permutations Sep 27, 2023
@ColinPeppler
Contributor

Hi @jonpryai, thanks for flagging this. It does seem like at least one of the permutes could be redundant. But without a minimal repro, it's hard to determine whether they should be removed and whether we need a pass to handle this case.

Do you mind sharing details on how to reproduce this? Thanks!

@jonpryai
Author

I use this to compile:

```python
import onnx
import torch
from onnx2torch import convert
from fx2ait.example.benchmark_utils import benchmark_function

batch_size = 1


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Convert the ONNX graph into a PyTorch module
        onnx_model = onnx.load("test.onnx")
        self.mod = convert(onnx_model)

    def forward(self, x):
        return self.mod(x)


model = TestModule().cuda().half()
inputs = [torch.randn(batch_size, 16, 224, 224).half().cuda()]
# Lower with fx2ait and compare against PyTorch eager
benchmark_function(
    TestModule.__name__,
    100,
    model,
    inputs,
)
```

Profiling the script above is difficult because each run also includes compilation and kernel profiling. Instead, I locate the generated test.so in /tmp, copy it to the current directory, and then run:

```python
import os

import torch
from aitemplate.compiler import Model


def benchmark(model_name, batch_size, mod=None, graph_mode=True):
    # Load the compiled model (test.so, copied from /tmp)
    if mod is None:
        model_name = f"{model_name}_{batch_size}"
        mod = Model(os.path.join("./", "test.so"))

    # Prepare contiguous fp16 input/output tensors
    x_input = torch.randn([batch_size, 16, 224, 224]).cuda().half().contiguous()
    y_output = torch.zeros([batch_size, 64, 56, 56]).cuda().half().contiguous()

    # Warm up
    mod.benchmark_with_tensors(
        [x_input],
        [y_output],
        count=100,
        repeat=4,
        graph_mode=graph_mode,
    )
    # Benchmark
    t, _, _ = mod.benchmark_with_tensors(
        [x_input],
        [y_output],
        count=100,
        repeat=4,
        graph_mode=graph_mode,
    )
    print(f"batch_size: {batch_size}, latency: {t}")
    dev_flag = os.environ.get("HIP_VISIBLE_DEVICES", "-1")
    dev_flag = dev_flag.replace(",", "_")
    with open(f"resnet50_ait_benchmark_dev_{dev_flag}.txt", "a") as f:
        f.write(f"batch_size: {batch_size}, latency: {t}\n")


if __name__ == "__main__":
    benchmark("", 1)
```

@jonpryai
Author

Example ONNX section:

[test.zip](https://github.com/facebookincubator/AITemplate/files/12752757/test.zip)

@ColinPeppler
Contributor

For security reasons, I'm unable to download external files. I hope you understand.

It'll be easier to reproduce if you can share the model graph that AIT dumps automatically via `dump_graph_debug_str_to_file`. You'll need to set the environment variable `LOGLEVEL=DEBUG` when you compile the model; the files will appear in your workdir.

Once you do that, could you share the contents of `memory_planning_pseudo_code.txt`?

@jonpryai
Author

jonpryai commented Oct 4, 2023

```
(Tensor(name=permute_0_0, shape=[1, 224, 224, 16]))
= permute()(
Tensor(name=x, shape=[1, 16, 224, 224]))

# conv2d_bias_1
(Tensor(name=conv2d_bias_1_0, shape=[1, 112, 112, 32]))
= conv2d_bias(dilate=1, group=1, pad=1, stride=2)(
Tensor(name=permute_0_0, shape=[1, 224, 224, 16]), Tensor(name=mod_level1_level1_0_Conv_weight, shape=[32, 3, 3, 16], data=(9216 bytes)), Tensor(name=mod_level1_level1_0_Conv_bias, shape=[32], data=(64 bytes)))

# permute_2
(Tensor(name=permute_2_0, shape=[1, 32, 112, 112]))
= permute()(
Tensor(name=conv2d_bias_1_0, shape=[1, 112, 112, 32]))

# fused_elementwise_19
(Tensor(name=elementwise_3_0, shape=[1, 32, 112, 112]))
= fused_elementwise(func=[<FuncEnum.RELU: 18>])(
Tensor(name=permute_2_0, shape=[1, 32, 112, 112]))

# permute_4
(Tensor(name=permute_4_0, shape=[1, 112, 112, 32]))
= permute()(
Tensor(name=elementwise_3_0, shape=[1, 32, 112, 112]))

# permute_4
(Tensor(name=permute_5_0, shape=[1, 112, 112, 32]))
= permute()(
Tensor(name=elementwise_3_0, shape=[1, 32, 112, 112]))

# conv2d_bias_6
(Tensor(name=conv2d_bias_6_0, shape=[1, 56, 56, 64]))
= conv2d_bias(dilate=1, group=1, pad=1, stride=2)(
Tensor(name=permute_5_0, shape=[1, 112, 112, 32]), Tensor(name=mod_level2_tree1_conv1_Conv_weight, shape=[64, 3, 3, 32], data=(36864 bytes)), Tensor(name=mod_level2_tree1_conv1_Conv_bias, shape=[64], data=(128 bytes)))

# permute_7
(Tensor(name=permute_7_0, shape=[1, 64, 56, 56]))
= permute()(
Tensor(name=conv2d_bias_6_0, shape=[1, 56, 56, 64]))

# fused_elementwise_20
(Tensor(name=elementwise_8_0, shape=[1, 64, 56, 56]))
= fused_elementwise(func=[<FuncEnum.RELU: 18>])(
Tensor(name=permute_7_0, shape=[1, 64, 56, 56]))

# permute_9
(Tensor(name=permute_9_0, shape=[1, 56, 56, 64]))
= permute()(
Tensor(name=elementwise_8_0, shape=[1, 64, 56, 56]))

# conv2d_bias_10
(Tensor(name=conv2d_bias_10_0, shape=[1, 56, 56, 64]))
= conv2d_bias(dilate=1, group=1, pad=1, stride=1)(
Tensor(name=permute_9_0, shape=[1, 56, 56, 64]), Tensor(name=mod_level2_tree1_conv2_Conv_weight, shape=[64, 3, 3, 64], data=(73728 bytes)), Tensor(name=mod_level2_tree1_conv2_Conv_bias, shape=[64], data=(128 bytes)))

# permute_7
(Tensor(name=permute_11_0, shape=[1, 64, 56, 56]))
= permute()(
Tensor(name=conv2d_bias_10_0, shape=[1, 56, 56, 64]))

# max_pool2d_12
(Tensor(name=max_pool2d_12_0, shape=[1, 56, 56, 32]))
= max_pool2d(stride=2, pad=0, kernel_size=2, reduce_func=max)(
Tensor(name=permute_4_0, shape=[1, 112, 112, 32]))

# conv2d_bias_15
(Tensor(name=conv2d_bias_15_0, shape=[1, 56, 56, 64]))
= conv2d_bias(dilate=1, group=1, pad=0, stride=1)(
Tensor(name=max_pool2d_12_0, shape=[1, 56, 56, 32]), Tensor(name=mod_level2_project_project_0_Conv_weight, shape=[64, 1, 1, 32], data=(4096 bytes)), Tensor(name=mod_level2_project_project_0_Conv_bias, shape=[64], data=(128 bytes)))

# permute_7
(Tensor(name=permute_16_0, shape=[1, 64, 56, 56]))
= permute()(
Tensor(name=conv2d_bias_15_0, shape=[1, 56, 56, 64]))

# fused_elementwise_21
(Tensor(name=output_0, shape=[1, 64, 56, 56]))
= fused_elementwise(func=[<FuncEnum.ADD: 1>, <FuncEnum.RELU: 18>])(
Tensor(name=permute_11_0, shape=[1, 64, 56, 56]), Tensor(name=permute_16_0, shape=[1, 64, 56, 56]))
```

@jonpryai
Author

jonpryai commented Oct 5, 2023

This image of the network might be helpful.

[Image: screenshot of the network graph (Screenshot from 2023-10-05 09-44-11)]

@ColinPeppler
Contributor

It does seem like either permute_2 or permute_4 can be removed here. It'll be easier to remove permute_2, IMO.

And sorry for the delay, but this is what I believe we need:

  1. Find the conditions for removing permute_2.
    • We can make the conditions specific for your graph (i.e. only when the middle op is an elementwise-relu/gelu/etc.).
  2. Remove the first permute.
  3. Take the first permute's input (conv2d_bias_1_0) and make it the new input for the middle op (fused_elementwise_19).
  4. Confirm the shapes are correct for middle op and the remaining permute.
  5. Write a test case and confirm its accuracy.

Here are some pointers:

Let me know if there are any questions.
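To make steps 1-4 concrete, here is a minimal sketch of such a pass on a toy graph IR. The `Node` class, `remove_permute_sandwiches`, and the set of layout-agnostic ops are hypothetical and are not AIT's actual pass API. It only handles unary elementwise ops (the relu case in the dump above), and it removes every trailing inverse permute, so it also covers the duplicated permute_4/permute_5 pair.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

# Hypothetical: unary ops whose result does not depend on memory layout
UNARY_ELEMENTWISE = {"relu", "gelu", "sigmoid"}


@dataclass(eq=False)  # eq=False: nodes compare by identity
class Node:
    name: str
    op: str
    inputs: List["Node"] = field(default_factory=list)
    shape: Tuple[int, ...] = ()
    perm: Optional[Tuple[int, ...]] = None  # set only when op == "permute"


def compose(p, q):
    # x.permute(p).permute(q) == x.permute(compose(p, q))
    return tuple(p[i] for i in q)


def is_identity(perm):
    return perm == tuple(range(len(perm)))


def users(graph, node):
    return [n for n in graph if any(i is node for i in n.inputs)]


def remove_permute_sandwiches(graph):
    """Rewrite permute -> unary elementwise -> inverse permute(s) into just the
    elementwise op. `graph` is a topologically sorted list of Nodes."""
    changed = True
    while changed:
        changed = False
        for first in graph:
            if first.op != "permute":
                continue
            # Step 1: the permute feeds exactly one unary elementwise op,
            # and every user of that op undoes the permute.
            mids = users(graph, first)
            if len(mids) != 1 or mids[0].op not in UNARY_ELEMENTWISE:
                continue
            mid = mids[0]
            tails = users(graph, mid)
            if not tails or not all(
                t.op == "permute" and is_identity(compose(first.perm, t.perm))
                for t in tails
            ):
                continue
            # Steps 2-3: drop the first permute and feed its input to mid.
            src = first.inputs[0]
            mid.inputs = [src]
            mid.shape = src.shape
            # Drop the trailing permutes and reroute their users to mid.
            for t in tails:
                assert t.shape == src.shape  # Step 4: shapes still line up
                for n in users(graph, t):
                    n.inputs = [mid if i is t else i for i in n.inputs]
            graph = [
                n for n in graph
                if n is not first and all(n is not t for t in tails)
            ]
            changed = True
            break
    return graph
```

On the conv2d_bias_1 → permute_2 → relu → permute_4/permute_5 section of the dump, the pass leaves conv2d_bias_1 feeding the relu directly. A binary op such as the final ADD would need both of its inputs checked before any permutes could be moved.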

@jonpryai
Author

jonpryai commented Oct 6, 2023

I'm not very familiar with the code, so I could be wrong. But my first impression is that, while the optimizer can consider different layouts (NHWC and NCHW) for conv2d, for some reason it is locked into NCHW for the elementwise ops, and it may not take the permutation cost into account.

I think both permute_2 and permute_4 can be removed. There are also two copies of permute_4 (producing permute_4_0 and permute_5_0) that yield exactly the same tensor. What is happening here is:

```
conv2d(NHWC) -> toNCHW -> elementWise -> toNHWC
                                      -> toNHWC
```

which is the same thing as

```
conv2d(NHWC) -> elementWise
```
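The duplicated permute_4 could be folded with plain common-subexpression elimination: two nodes that apply the same op, with the same attributes, to the same inputs produce the same tensor. A minimal sketch on a hypothetical tuple-based graph (not AIT's IR), assuming every op is pure:

```python
def dedupe_ops(graph):
    """graph: list of (name, op, input_names, attrs) in topological order,
    with hashable attrs. Merges nodes that apply the same op with the same
    attrs to the same inputs; returns (new_graph, alias_map)."""
    seen = {}   # (op, inputs, attrs) -> name of the surviving node
    alias = {}  # removed node name -> surviving node name
    out = []
    for name, op, inputs, attrs in graph:
        # Rewire inputs that pointed at a removed duplicate
        inputs = tuple(alias.get(i, i) for i in inputs)
        key = (op, inputs, attrs)
        if op != "input" and key in seen:
            alias[name] = seen[key]
            continue
        seen[key] = name
        out.append((name, op, inputs, attrs))
    return out, alias
```

Applied to the elementwise_3 → permute_4/permute_5 section, permute_5 disappears and conv2d_bias_6 reads permute_4 instead.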

@ColinPeppler
Copy link
Contributor

Ah I see, both permutes can definitely be removed in that case. And I'm not sure which pass introduces them in the first place.

Do you still have the dumped graphs in your directory? We can see which pass adds the permutes by looking at the {passname}_pseudo_code.txt.
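One way to locate the pass is to scan the dumps in the order they were written and report the first one that already contains a permute. A minimal sketch, assuming the dumps (written when compiling with `LOGLEVEL=DEBUG`) sit somewhere under the workdir and that file modification time reflects pass order; `first_pass_containing` is a hypothetical helper, not part of AIT.

```python
import glob
import os

SUFFIX = "_pseudo_code.txt"


def first_pass_containing(workdir, needle="= permute()("):
    """Return the pass name of the earliest *_pseudo_code.txt dump that
    contains `needle`, or None if no dump does.

    Assumption: dumps are written as each pass finishes, so sorting by file
    modification time recovers the pass order.
    """
    pattern = os.path.join(workdir, "**", "*" + SUFFIX)
    for path in sorted(glob.glob(pattern, recursive=True), key=os.path.getmtime):
        with open(path) as f:
            if needle in f.read():
                # "toposort_pseudo_code.txt" -> "toposort"
                return os.path.basename(path)[: -len(SUFFIX)]
    return None
```

If the answer is the very first dump, the permutes were already in the graph before any AIT pass ran.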

@jonpryai
Copy link
Author

They are present in every dump except toposort_pseudo_code.txt. So is the bind_constants pass causing it?

@jonpryai
Copy link
Author

Actually, that's not true. The permutes are already in the toposort dump; the nodes just haven't been named yet:

```
(Tensor(name=None, shape=[1, 224, 224, 16]))
= permute()(
Tensor(name=x, shape=[1, 16, 224, 224]))

# None
(Tensor(name=None, shape=[1, 112, 112, 32]))
= conv2d_bias(dilate=1, group=1, pad=1, stride=2)(
Tensor(name=None, shape=[1, 224, 224, 16]), Tensor(name=mod_level1_level1_0_Conv_weight, shape=[32, 3, 3, 16], data=(9216 bytes)), Tensor(name=mod_level1_level1_0_Conv_bias, shape=[32], data=(64 bytes)))

# None
(Tensor(name=None, shape=[1, 32, 112, 112]))
= permute()(
Tensor(name=None, shape=[1, 112, 112, 32]))

# None
(Tensor(name=None, shape=[1, 32, 112, 112]))
= elementwise(func=FuncEnum.RELU)(
Tensor(name=None, shape=[1, 32, 112, 112]))

# None
(Tensor(name=None, shape=[1, 112, 112, 32]))
= permute()(
Tensor(name=None, shape=[1, 32, 112, 112]))

# None
(Tensor(name=None, shape=[1, 112, 112, 32]))
= permute()(
Tensor(name=None, shape=[1, 32, 112, 112]))

# None
(Tensor(name=None, shape=[1, 56, 56, 64]))
= conv2d_bias(dilate=1, group=1, pad=1, stride=2)(
Tensor(name=None, shape=[1, 112, 112, 32]), Tensor(name=mod_level2_tree1_conv1_Conv_weight, shape=[64, 3, 3, 32], data=(36864 bytes)), Tensor(name=mod_level2_tree1_conv1_Conv_bias, shape=[64], data=(128 bytes)))

# None
(Tensor(name=None, shape=[1, 64, 56, 56]))
= permute()(
Tensor(name=None, shape=[1, 56, 56, 64]))

# None
(Tensor(name=None, shape=[1, 64, 56, 56]))
= elementwise(func=FuncEnum.RELU)(
Tensor(name=None, shape=[1, 64, 56, 56]))

# None
(Tensor(name=None, shape=[1, 56, 56, 64]))
= permute()(
Tensor(name=None, shape=[1, 64, 56, 56]))

# None
(Tensor(name=None, shape=[1, 56, 56, 64]))
= conv2d_bias(dilate=1, group=1, pad=1, stride=1)(
Tensor(name=None, shape=[1, 56, 56, 64]), Tensor(name=mod_level2_tree1_conv2_Conv_weight, shape=[64, 3, 3, 64], data=(73728 bytes)), Tensor(name=mod_level2_tree1_conv2_Conv_bias, shape=[64], data=(128 bytes)))

# None
(Tensor(name=None, shape=[1, 64, 56, 56]))
= permute()(
Tensor(name=None, shape=[1, 56, 56, 64]))

# None
(Tensor(name=None, shape=[1, 56, 56, 32]))
= max_pool2d(stride=2, pad=0, kernel_size=2, reduce_func=max)(
Tensor(name=None, shape=[1, 112, 112, 32]))

# None
(Tensor(name=None, shape=[1, 32, 56, 56]))
= permute()(
Tensor(name=None, shape=[1, 56, 56, 32]))

# None
(Tensor(name=None, shape=[1, 56, 56, 32]))
= permute()(
Tensor(name=None, shape=[1, 32, 56, 56]))

# None
(Tensor(name=None, shape=[1, 56, 56, 64]))
= conv2d_bias(dilate=1, group=1, pad=0, stride=1)(
Tensor(name=None, shape=[1, 56, 56, 32]), Tensor(name=mod_level2_project_project_0_Conv_weight, shape=[64, 1, 1, 32], data=(4096 bytes)), Tensor(name=mod_level2_project_project_0_Conv_bias, shape=[64], data=(128 bytes)))

# None
(Tensor(name=None, shape=[1, 64, 56, 56]))
= permute()(
Tensor(name=None, shape=[1, 56, 56, 64]))

# None
(Tensor(name=None, shape=[1, 64, 56, 56]))
= elementwise(func=FuncEnum.ADD)(
Tensor(name=None, shape=[1, 64, 56, 56]), Tensor(name=None, shape=[1, 64, 56, 56]))

# None
(Tensor(name=output_0, shape=[1, 64, 56, 56]))
= elementwise(func=FuncEnum.RELU)(
Tensor(name=None, shape=[1, 64, 56, 56]))
```

Is it possible these nodes are being inserted by fx2ait?

@ColinPeppler
Copy link
Contributor

It could be fx2ait but it may also be onnx2torch.

I'm curious whether replicating the model in PyTorch and then lowering it with fx2ait gives us the same graph. If not, then I'd assume it's onnx2torch.

@jonpryai
Copy link
Author

[Image: model gv (graph of the converted PyTorch model)]

The permutes do not appear in the converted PyTorch model, but they are present in the AITModel after tracing.

@ColinPeppler
Copy link
Contributor

You're right, the permutes are being added in fx2ait. The result of each conv2d is permuted via `ait_nhwc2nchw` (here).

AIT does that because PyTorch uses channel-first tensors for conv, maxpool, etc., whereas AIT uses channel-last tensors.
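To see why these conversions are not free, compare the flat offset of the same logical element in the two layouts. The helpers below are an illustrative pure-Python sketch with hypothetical names: converting between layouts means moving every element to a new address, which is the work each PermuteKernel launch does on the GPU.

```python
def nchw_offset(n, c, h, w, C, H, W):
    # Channel-first (PyTorch's default): w varies fastest, then h, then c
    return ((n * C + c) * H + h) * W + w


def nhwc_offset(n, c, h, w, C, H, W):
    # Channel-last (the layout AIT's conv expects): c varies fastest
    return ((n * H + h) * W + w) * C + c


def nchw_to_nhwc(buf, N, C, H, W):
    """Copy every element of an NCHW buffer to its NHWC address."""
    out = [None] * len(buf)
    for n in range(N):
        for c in range(C):
            for h in range(H):
                for w in range(W):
                    src = nchw_offset(n, c, h, w, C, H, W)
                    out[nhwc_offset(n, c, h, w, C, H, W)] = buf[src]
    return out


# Element (n=0, c=1, h=0, w=0) of a 1x2x2x3 tensor lives at different addresses
print(nchw_offset(0, 1, 0, 0, 2, 2, 3), nhwc_offset(0, 1, 0, 0, 2, 2, 3))  # 6 1
```

Each NCHW↔NHWC permute therefore reads and writes the whole activation once, which is consistent with 64,800 PermuteKernel instances taking 61.9% of GPU time in the profile at the top of this issue.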

A potential workaround is to add a permute after each Conv2D? cc: @chenyang78

@jonpryai
Copy link
Author

Would it be possible to just make all the elementwise ops also do the permutation? Then we would end up with a graph like

```
toNHWC -> conv2d -> toNCHW -> toNHWC -> elementWise -> toNCHW
```

and the remove-permutations pass would then find the redundant permutes.
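A toy model of this idea, assuming (as described above) that layout-sensitive ops are wrapped as NCHW→NHWC, op, NHWC→NCHW. `lower` and `cancel_inverse_permutes` are hypothetical helpers, not fx2ait functions. Wrapping only conv2d leaves the permutes stranded around the relu; wrapping the elementwise ops as well makes every inner pair adjacent, so a simple peephole pass can cancel them.

```python
NCHW_TO_NHWC = (0, 2, 3, 1)
NHWC_TO_NCHW = (0, 3, 1, 2)
LAYOUT_SENSITIVE = {"conv2d", "max_pool2d"}


def lower(ops, wrap_elementwise):
    """Lower a chain of NCHW ops into (op, perm) pairs, wrapping channel-last
    ops in permutes; optionally wrap every other op the same way."""
    out = []
    for op in ops:
        if op in LAYOUT_SENSITIVE or wrap_elementwise:
            out += [("permute", NCHW_TO_NHWC), (op, None), ("permute", NHWC_TO_NCHW)]
        else:
            out.append((op, None))
    return out


def compose(p, q):
    # x.permute(p).permute(q) == x.permute(compose(p, q))
    return tuple(p[i] for i in q)


def cancel_inverse_permutes(seq):
    """Peephole pass: drop adjacent permute pairs that compose to identity.
    Using a stack lets newly adjacent pairs cancel as well."""
    out = []
    for op, perm in seq:
        if (
            op == "permute"
            and out
            and out[-1][0] == "permute"
            and compose(out[-1][1], perm) == tuple(range(len(perm)))
        ):
            out.pop()
        else:
            out.append((op, perm))
    return out


chain = ["conv2d", "relu", "conv2d"]
for wrap in (False, True):
    ops = cancel_inverse_permutes(lower(chain, wrap_elementwise=wrap))
    print(wrap, [op for op, _ in ops])
```

For conv2d → relu → conv2d this leaves two permutes (one at each graph boundary) instead of four.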

@ColinPeppler
Copy link
Contributor

It sounds like that could work.

But would it be possible to try this?

  1. Permute your tensor so it's channel-last
  2. Set set_tensor_layout_policy(false) before lowering your model -- this avoids the permutes after conv2d

@xmfbit
Copy link
Contributor

xmfbit commented Nov 17, 2023

@jonpryai hi, have you solved the problem?

@jonpryai
Copy link
Author

jonpryai commented Nov 17, 2023

@xmfbit No, not really. I'm just trying to quickly see what the inference performance of a model would be with AITemplate. I'm wondering whether an FX graph would work correctly instead of an ONNX model? Otherwise, it may actually be easier to write code that builds the AITemplate model directly than to try to fix fx2ait.

Importing a typical dla34 model is a good example of these issues.
