[Topi] Allow relay.nn.dense to support an arbitrary number of dimensions #8412

Closed
AndrewZhaoLuo opened this issue Jul 6, 2021 · 8 comments

@AndrewZhaoLuo
Contributor

Right now the topi schedule assumes the input tensor has exactly two dimensions, even though the documentation suggests otherwise:
https://tvm.apache.org/docs/api/python/relay/nn.html#tvm.relay.nn.dense

Alternatively, we could just adjust the documentation to stipulate that relay.nn.dense only accepts a two-dimensional input tensor. However, other frameworks allow their dense layers to take inputs with an arbitrary number of dimensions, e.g. https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear

For forum discussion see: https://discuss.tvm.apache.org/t/relay-nn-does-relay-nn-dense-supports-multi-dimensional-input/10343/5
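
For reference, a minimal sketch (made-up shapes, not from the issue) of what models effectively have to do today: collapse the leading dimensions before nn.dense and restore them afterwards, which is exactly the reshape pair that blocks fusion later in this thread.

  import tvm
  from tvm import relay

  # Illustrative workaround for the 2D-only schedule: flatten the leading dims
  # of a 3D input before nn.dense, then restore them after.
  data = relay.var("data", shape=(8, 128, 1024), dtype="float32")
  weight = relay.var("weight", shape=(4096, 1024), dtype="float32")

  flat = relay.reshape(data, newshape=(-1, 1024))       # (1024, 1024)
  out_2d = relay.nn.dense(flat, weight)                 # (1024, 4096)
  out = relay.reshape(out_2d, newshape=(8, 128, 4096))  # restore batch dims

  mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
  print(mod)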

@masahi
Member

masahi commented Oct 14, 2021

See https://discuss.tvm.apache.org/t/relay-nn-does-relay-nn-dense-supports-multi-dimensional-input/10343/7. I realized that supporting inputs with more than two dimensions in dense is extremely important for BERT-like model performance. In particular, without fixing this issue, I cannot demonstrate the performance advantage of cutlass BYOC #9261 over cuBLAS offload on transformer models. cc @comaniac @Laurawly

It would also improve TVM numbers for bert_large in #8294 (comment), because those numbers were obtained without dense + activation fusion (also no fusion for batch_matmul).

@AndrewZhaoLuo Are you going to work on this one? If not, I need to do it anyway before or after #9261 lands.

@masahi
Member

masahi commented Oct 14, 2021

I think a good way to add N-D input support to topi.dense is to apply topi.reshape to the input before the dense compute, and inside the dense schedules call compute_inline(...) on the reshape stage. This way we can reuse the existing schedules that are hardcoded for 2D input with no performance overhead (modulo some indexing math from the reshape). See my PR #9207 that does something similar.
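
A rough sketch of the idea (my own names and shapes, not from an actual patch): the reshape becomes its own compute stage, and inlining it in the schedule turns it into pure index math inside the dense stage.

  import tvm
  from tvm import te, topi

  data = te.placeholder((8, 128, 1024), name="data")
  weight = te.placeholder((4096, 1024), name="weight")

  # Flatten the (8, 128) batch dims so the existing 2D dense compute is reused.
  data_2d = topi.reshape(data, (8 * 128, 1024))
  out = topi.nn.dense(data_2d, weight)

  s = te.create_schedule(out.op)
  s[data_2d].compute_inline()   # the reshape stage disappears into dense
  print(tvm.lower(s, [data, weight, out], simple_mode=True))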

@AndrewZhaoLuo
Contributor Author

Hey @masahi, you can go ahead and tackle this one. I'll probably have time next week if this is super high priority.

@comaniac
Contributor

I have a concern that we have more than one backend, and we need to make sure all backend implementations have the same semantics/spec. For example, I previously encountered an issue (#7730) where the cuBLAS batch_matmul implementation doesn't support implicit broadcasting, which results in errors when lowering batch_matmul to cuBLAS.

If we allow arbitrary dimensions in relay.nn.dense, which effectively means an implicit reshape, we need to make sure all supported implementations follow this semantic. Considering that a Relay graph should stay as general as possible so it remains compatible with arbitrary backends, keeping the operator semantics simple and making extra behaviors as explicit as possible seems more developer friendly. On the other hand, if supporting arbitrary dimensions is the common convention for dense, then we could definitely go with it.
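
For readers who haven't hit #7730: the broadcasting mismatch looks roughly like the following (a numpy illustration of the semantics, not the actual cuBLAS path). A batch dim of 1 broadcasts in the Relay/numpy semantics, while a backend batched GEMM that requires identical batch sizes cannot express that directly.

  import numpy as np

  # Implicit broadcasting in batched matmul: the batch dim of 1 in `b` is
  # broadcast against the batch dim of 8 in `a`.
  a = np.random.rand(8, 128, 64).astype("float32")
  b = np.random.rand(1, 64, 256).astype("float32")
  c = a @ b                          # broadcasts to shape (8, 128, 256)
  assert c.shape == (8, 128, 256)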

@masahi
Member

masahi commented Oct 14, 2021

Hey @masahi you can go ahead and tackle this one

I'm happy to take this, but I also want to continue working on cutlass BYOC #9261. Moreover, things look more complicated than I thought yesterday, as pointed out by @comaniac and discussed below. Making sure that N-D input dense is supported by external backends is a really good point. It is kind of odd that I didn't think about it yesterday, because I was trying to fuse 3-D dense + activation for cutlass BYOC. Now I realize that I don't even know whether cutlass supports such an N-D dense operation.

@comaniac N-D dense input is common in all frameworks, so we should definitely support it. My hope is that since those extra batch dimensions can be conceptually fused into a single 1D batch dim, external libs can work directly on N-D input, i.e. whether or not we explicitly do the reshape doesn't change the underlying memory addresses. I guess PyTorch simply uses tensor.view(...) before sending inputs to cuBLAS. I think we can do a trick similar to the one you used to support implicit broadcasting in the past.
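
To make the "reshape is free" point concrete, a small numpy sketch of the memory-layout argument (not TVM code):

  import numpy as np

  # For a contiguous row-major tensor, collapsing the leading dims into one
  # batch dim does not move any data: the reshape only changes indexing
  # metadata, so an external 2D GEMM can consume the N-D input directly.
  x = np.random.rand(8, 128, 1024).astype("float32")
  x_2d = x.reshape(-1, 1024)                  # (1024, 1024), a view
  assert np.shares_memory(x, x_2d)            # same underlying buffer
  w = np.random.rand(4096, 1024).astype("float32")
  y = (x_2d @ w.T).reshape(8, 128, 4096)      # 2D GEMM + metadata-only reshape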

Other related points I realized now:

  • batch_matmul has the same problem. Our batch_matmul only supports 3-D inputs, while other frameworks support an arbitrary number of batch dims as long as the corresponding dims are compatible. Huggingface bert-large has 48 batch_matmul ops, none of which are fused with the following elemwise ops due to reshape.
  • We should also update qnn.dense and qnn.batch_matmul.
  • Our tensorcore schedules have to check the size of the batch dim to decide whether they are applicable. If there are multiple batch dims, we need to take them into account by multiplying them (see the sketch after this list). The change is trivial, but we have to touch a lot of code, e.g. this includes the tensorcore legalize utility that tries to pad the batch dim to a multiple of 8.
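
A hypothetical helper for the last point, just to make the check concrete (8 is only an example multiple; the actual requirements are whatever the existing tensorcore checks use):

  import math

  # Fold all leading batch dims into a single batch size before applying the
  # usual tensorcore applicability check.
  def flattened_batch(shape):
      return math.prod(shape[:-1])   # everything but the innermost dim is batch

  assert flattened_batch((8, 128, 1024)) == 8 * 128
  assert flattened_batch((8, 128, 1024)) % 8 == 0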

@masahi
Member

masahi commented Oct 14, 2021

Another workaround that enables dense + activation fusion without supporting N-D dense is to swap the order of the reshape and the activations. Right now we are doing dense -> reshape -> activations, but we can do dense -> activations -> reshape. We need to reshape the bias, but that's a compile-time op. We are still left with an annoying "naked" reshape at the end, but this might be a quicker solution for demonstrating cutlass BYOC.
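
A sketch of the swapped pattern in Relay (illustrative shapes, with relu standing in for the actual activation):

  import tvm
  from tvm import relay

  # dense -> add(bias) -> activation -> reshape: the activation now sees the
  # 2D dense output and can fuse with it; only a lone reshape remains at the
  # end. A bias that arrived shaped for the 3D output would get a compile-time
  # reshape to (4096,) so it broadcasts against the 2D output.
  x = relay.var("x", shape=(1024, 1024), dtype="float16")
  w = relay.var("w", shape=(4096, 1024), dtype="float16")
  bias = relay.var("bias", shape=(4096,), dtype="float16")

  out = relay.nn.dense(x, w)                           # (1024, 4096)
  out = relay.add(out, bias)
  out = relay.nn.relu(out)                             # fuses with dense
  out = relay.reshape(out, newshape=(8, 128, 4096))    # the remaining "naked" reshape
  mod = tvm.IRModule.from_expr(relay.Function([x, w, bias], out))
  print(mod)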

@comaniac
Contributor

Thanks for the analysis; I strongly agree with it, as we (together with @icemelon) are also running into the same issue. Making the pattern dense -> activations is definitely more straightforward, and reshaping the bias is basically a free lunch. We may start with this solution to unblock CUTLASS, and we can still work on N-D dense/batch_matmul support in parallel.

@masahi
Member

masahi commented Oct 15, 2021

I found something interesting wrt the reshape swapping approach. If there are back-to-back dense ops, currently we end up with something like

  %2367 = fn (%p06: Tensor[(1024, 1024), float16], %p16: Tensor[(4096, 1024), float16], Primitive=1, hash="c13735290dc46bbc") -> Tensor[(1024, 4096), float16] {
    nn.dense(%p06, %p16, units=None, out_dtype="float16") /* ty=Tensor[(1024, 4096), float16] */
  };
  %2368 = %2367(%2366, meta[relay.Constant][430] /* ty=Tensor[(4096, 1024), float16] */) /* ty=Tensor[(1024, 4096), float16] */;
  %2369 = fn (%p05: Tensor[(1024, 4096), float16], %p15: Tensor[(4096), float16], %p24: float16, Primitive=1, hash="ab37ab7bd1a05f99") -> Tensor[(1024, 4096), float16] {
    %13 = reshape(%p05, newshape=[8, 128, 4096]) /* ty=Tensor[(8, 128, 4096), float16] */;
    %14 = add(%13, %p15) /* ty=Tensor[(8, 128, 4096), float16] */;
    %15 = multiply(%14, %p24) /* ty=Tensor[(8, 128, 4096), float16] */;
    %16 = cast(%15, dtype="float32") /* ty=Tensor[(8, 128, 4096), float32] */;
    %17 = erf(%16) /* ty=Tensor[(8, 128, 4096), float32] */;
    %18 = multiply(%17, 0.5f /* ty=float32 */) /* ty=Tensor[(8, 128, 4096), float32] */;
    %19 = cast(%14, dtype="float32") /* ty=Tensor[(8, 128, 4096), float32] */;
    %20 = add(0.5f /* ty=float32 */, %18) /* ty=Tensor[(8, 128, 4096), float32] */;
    %21 = multiply(%19, %20) /* ty=Tensor[(8, 128, 4096), float32] */;
    %22 = reshape(%21, newshape=[-1, 4096]) /* ty=Tensor[(1024, 4096), float32] */;
    cast(%22, dtype="float16") /* ty=Tensor[(1024, 4096), float16] */
  };
  %2370 = %2369(%2368, meta[relay.Constant][431] /* ty=Tensor[(4096), float16] */, meta[relay.Constant][432] /* ty=float16 */) /* ty=Tensor[(1024, 4096), float16] */;
  %2371 = fn (%p04: Tensor[(1024, 4096), float16], %p14: Tensor[(1024, 4096), float16], Primitive=1, hash="adf706330ace9ece") -> Tensor[(1024, 1024), float16] {
    nn.dense(%p04, %p14, units=None, out_dtype="float16") /* ty=Tensor[(1024, 1024), float16] */
  };

Note that there are two reshapes in the middle, one for 2D -> 3D and another for 3D -> 2D. If we swap the activation with the reshape, the two reshapes become back-to-back and hence cancel each other out.
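
A quick way to see the cancellation in isolation (my own minimal example; I believe SimplifyExpr's reshape rule folds the pair once nothing sits between them):

  import tvm
  from tvm import relay

  # 2D -> 3D immediately followed by 3D -> 2D: once the activation is moved out
  # from between them, the pair is equivalent to an identity mapping.
  x = relay.var("x", shape=(1024, 4096), dtype="float16")
  y = relay.reshape(x, newshape=(8, 128, 4096))
  z = relay.reshape(y, newshape=(-1, 4096))
  mod = tvm.IRModule.from_expr(relay.Function([x], z))
  mod = relay.transform.InferType()(mod)
  mod = relay.transform.SimplifyExpr()(mod)   # folds consecutive reshapes
  print(mod)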

@masahi self-assigned this on Jan 9, 2022
@areusch added the needs-triage label on Oct 19, 2022
@Lunderberg added the topi (python/tvm/topi) label and removed the needs-triage label on Oct 28, 2022
@tqchen closed this as completed on Sep 20, 2024