
Issue with broadcasting (3, 1) and (3,) tensors #1788

Closed
IvanYashchuk opened this issue Jun 30, 2022 · 16 comments · Fixed by #1794

@IvanYashchuk
Collaborator

IvanYashchuk commented Jun 30, 2022

🐛 Describe the bug

Is this a bug in nvFuser, or is the code below invalid? I think the code is valid: I translated the trace of torch._refs.add into nvFuser Python API calls. There's no error with ATen execution.
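For reference, eager mode broadcasts the (3, 1) and (3,) operands to a common (3, 3) shape before adding. A quick sketch with NumPy, which follows the same broadcasting rules as PyTorch:

```python
import numpy as np

# (3, 1) + (3,): trailing dimensions align, so the size-1 column of `a`
# and the implicit leading dimension of `b` are both broadcast to 3,
# giving a (3, 3) result.
a = np.ones((3, 1))
b = np.ones((3,))
out = a + b

print(out.shape)  # (3, 3), every element 2.0
```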

import torch

from torch._C._nvfuser import Fusion, FusionDefinition

# Construct and Define Fusion
fusion = Fusion()

with FusionDefinition(fusion) as fd :
    t0 = fd.define_tensor(2)
    t1 = fd.define_tensor(1)

    fd.add_input(t0)
    fd.add_input(t1)

    t0_b = fd.Ops.broadcast_in_dim(t0, [3, 3], [0, 1])
    t1_b = fd.Ops.broadcast_in_dim(t1, [3, 3], [1])
    t2 = fd.Ops.add(t0_b, t1_b)

    fd.add_output(t2)

fusion.print_ir()

# Execute Fusion
input1 = torch.ones(3, 1, device='cuda')
input2 = torch.ones(3, device='cuda')

fusion.execute([input1, input2])

The fusion.execute call raises

RuntimeError: Attempting to bind T0.size[1] to 3but it's already set to 1

Versions

.

@jjsjann123
Collaborator

Oops, this should work.

Expand to a concrete size was added pretty recently; we can double-check expand as well as broadcast_in_dim first.

@IvanYashchuk
Collaborator Author

IvanYashchuk commented Jun 30, 2022

I just realized that broadcast_in_dim is a Python-bindings-only thing, but I was able to get the same error with a plain broadcast call. I will post the C++ variant of the code shortly.

@IvanYashchuk
Collaborator Author

Here's the Python script using broadcast instead of broadcast_in_dim:

import torch

from torch._C._nvfuser import Fusion, FusionDefinition

# Construct and Define Fusion
fusion = Fusion()

with FusionDefinition(fusion) as fd :
    t0 = fd.define_tensor(2)
    t1 = fd.define_tensor(1)

    fd.add_input(t0)
    fd.add_input(t1)

    t0_b = fd.Ops.broadcast(t0, [False, False]) # using broadcast instead of broadcast_in_dim
    t1_b = fd.Ops.broadcast(t1, [True, False])
    t2 = fd.Ops.add(t0_b, t1_b)

    fd.add_output(t2)

fusion.print_ir()

# Execute Fusion
input1 = torch.ones(3, 1, device='cuda')
input2 = torch.ones(3, device='cuda')

fusion.execute([input1, input2])

It fails with the same error RuntimeError: Attempting to bind T0.size[1] to 3but it's already set to 1. The C++ test, however, fails with a different error message:

C++ exception with description "false INTERNAL ASSERT FAILED at "/home/iyashchuk/dev/pytorch/master/torch/csrc/jit/codegen/cuda/executor.cpp":236, please report a bug to PyTorch. Allocations must be based on constant integers for local memory. However, found: T3_l[ iS15{T0.size[0]}, bS6{1} ], T2_l[ iS11{T0.size[0]}, iS12{T0.size[1]} ],  have dynamic allocations but are placed in local memory.
Exception raised from compileFusion at /home/iyashchuk/dev/pytorch/master/torch/csrc/jit/codegen/cuda/executor.cpp:236 (most recent call first):
TEST_F(NVFuserTest, FusionBroadcastVectors_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* t0 = makeSymbolicTensor(2);
  TensorView* t1 = makeSymbolicTensor(1);
  fusion.addInput(t0);
  fusion.addInput(t1);
  TensorView* t0_b = broadcast(t0, {false, false});
  TensorView* t1_b = broadcast(t1, {true, false});
  TensorView* t2 = add(t0_b, t1_b);
  fusion.addOutput(t2);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto input0 = at::randn({3, 1}, options);
  auto input1 = at::randn({3}, options);
  auto aten_output = at::add(input0, input1);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input0, input1});
  auto cg_outputs = fe.runFusion({input0, input1});

  testValidate(&fusion, cg_outputs, {input0, input1}, {aten_output}, __LINE__, __FILE__);
}

@jjsjann123
Collaborator

Broadcasting on t0 doesn't look right to me... to repro the original problem, you might need to use expand instead. Let me take a quick look there as well.

@IvanYashchuk
Collaborator Author

Okay, broadcasting on t0 isn't actually needed. The error is also raised with t2 = add(t0, t1_b).

@shmsong

shmsong commented Jun 30, 2022

Looks like this logic is flipped. What are the intended semantics of broadcast_in_dim? Does [1] mean that axis 1 is broadcast, or that it is not?

https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp#L632
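For what it's worth, my mental model of broadcast_in_dim (borrowed from jax.lax.broadcast_in_dim, which the binding appears to mirror; this is an assumption, not a statement about the nvFuser implementation) is that broadcast_dimensions lists, for each input axis, the output axis it maps to, and every unlisted output axis is broadcast. A hypothetical NumPy sketch of that semantic:

```python
import numpy as np

def broadcast_in_dim(x, output_shape, broadcast_dimensions):
    # broadcast_dimensions[i] names the output axis that input axis i
    # maps to; every output axis not named becomes a new size-1 axis,
    # which is then expanded to the concrete output size.
    shape = [1] * len(output_shape)
    for in_axis, out_axis in enumerate(broadcast_dimensions):
        shape[out_axis] = x.shape[in_axis]
    return np.broadcast_to(x.reshape(shape), output_shape)

# [1] maps the single input axis to output axis 1, so axis 0 is broadcast:
print(broadcast_in_dim(np.ones(3), (3, 3), [1]).shape)  # (3, 3)
```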

@shmsong

shmsong commented Jun 30, 2022

This worked for me on TOT for example:

import torch

from torch._C._nvfuser import Fusion, FusionDefinition

# Construct and Define Fusion
fusion = Fusion()

with FusionDefinition(fusion) as fd :
    t0 = fd.define_tensor(2)
    t1 = fd.define_tensor(1)

    fd.add_input(t0)
    fd.add_input(t1)

    t1_b = fd.Ops.broadcast_in_dim(t1, [3, 3], [0]) # 1 -> 0
    t2 = fd.Ops.add(t0, t1_b)

    fd.add_output(t2)

fusion.print_ir()

# Execute Fusion
input1 = torch.ones(3, 1, device='cuda')
input2 = torch.ones(3, device='cuda')

fusion.execute([input1, input2])

@shmsong

shmsong commented Jun 30, 2022

I believe eager mode interpreted this as an implicit broadcast on axis 1 of
input1 = torch.ones(3, 1, device='cuda')
and, if that was the original intention, it should somehow be reflected in our fusion definition without seeing the actual input.

@jjsjann123
Collaborator

Something is wrong with broadcast_in_dim... The sample Python code runs fine without error, but the output tensor it returns is not expanded to the right size, as print(out[0].shape) shows:

Inputs:
  T0_g[ iS0{i0}, iS1{i2} ], float
  T1_g[ iS2{i3} ], float
Outputs:
  T3_g[ iS5{i0}, iS6{i2} ], float

%kernel_math {
T2_l[ iS3{i3}, bS4{1} ]
   = broadcast( T1_g[ iS2{i3} ] )
T3_g[ iS5{i0}, iS6{i2} ]
   = T0_g[ iS0{i0}, iS1{i2} ]
   + T2_l[ iS3{i3}, bS4{1} ];
}

torch.Size([3, 1])
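The gap here is the difference between broadcast, which only inserts size-1 axes, and expand, which materializes a size-1 axis to a concrete size. A rough NumPy analogy (not the nvFuser API):

```python
import numpy as np

t1 = np.ones(3)

# broadcast: insert a size-1 axis; nothing is materialized yet
broadcasted = t1[:, np.newaxis]
print(broadcasted.shape)  # (3, 1)

# expand: grow the size-1 axis to the concrete output size
expanded = np.broadcast_to(broadcasted, (3, 3))
print(expanded.shape)  # (3, 3)
```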

@shmsong

shmsong commented Jun 30, 2022

So we want output size to be [3,3]?

In that case, I believe we need to make T0 [I, B] instead of T0 [I, I]; currently the fusion definition says T0 is a concrete tensor of size [3, 1].

The iS1 in T0_g[ iS0{i0}, iS1{i2} ] is symbolically shaped, so there's no indication that it could be a broadcast.

Do we currently support creating tensors with broadcast axes via define_tensor?

@jjsjann123
Collaborator

t0 = fd.define_tensor([3, 1], [1, 1]) (for some reason, -1 for size is not supported yet).

Hmmm, even that doesn't work, since I don't see expand in the kernel math...

@jjsjann123
Collaborator

I guess this is expected, since broadcast_in_dim only does a broadcast at the moment and doesn't expand to the right size. https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp#L606-L635

@jjsjann123 jjsjann123 self-assigned this Jun 30, 2022
@jjsjann123
Collaborator

I'll patch this. Self-assigned.

@jjsjann123
Collaborator

Just got it patched.

@IvanYashchuk
Collaborator Author

Maybe there should be a squeeze operation first to remove the size-1 dimensions, followed by broadcast. This code works as expected without using expand:

import torch

from torch._C._nvfuser import Fusion, FusionDefinition

# Construct and Define Fusion
fusion = Fusion()

with FusionDefinition(fusion) as fd :
    t0 = fd.define_tensor(1)
    t1 = fd.define_tensor(1)

    fd.add_input(t0)
    fd.add_input(t1)

    t0_b = fd.Ops.broadcast(t0, [False, True])
    t1_b = fd.Ops.broadcast(t1, [True, False])
    t2 = fd.Ops.add(t0_b, t1_b)

    fd.add_output(t2)

fusion.print_ir()

# Execute Fusion
input1 = torch.ones(3, device='cuda')
input2 = torch.ones(3, device='cuda')

fusion.execute([input1, input2])

Inputs:
  T0_g[ iS0{i0} ], float
  T1_g[ iS1{i2} ], float
Outputs:
  T4_g[ iS6{i0}, iS7{i2} ], float

%kernel_math {
T2_l[ iS2{i0}, bS3{1} ]
   = broadcast( T0_g[ iS0{i0} ] )
T3_l[ bS4{1}, iS5{i2} ]
   = broadcast( T1_g[ iS1{i2} ] )
T4_g[ iS6{i0}, iS7{i2} ]
   = T2_l[ iS2{i0}, bS3{1} ]
   + T3_l[ bS4{1}, iS5{i2} ];
}

[tensor([[2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.]], device='cuda:0')]
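The squeeze-then-broadcast recipe above can be checked against plain NumPy (a sketch of the arithmetic, not the nvFuser API):

```python
import numpy as np

input1 = np.ones((3, 1))
input2 = np.ones(3)

# squeeze away the size-1 dimension so both operands are rank-1 ...
t0 = input1.squeeze(1)            # shape (3,)
# ... then broadcast each one into its own axis of the rank-2 result
t0_b = t0[:, np.newaxis]          # like broadcast(t0, [False, True]) -> (3, 1)
t1_b = input2[np.newaxis, :]      # like broadcast(t1, [True, False]) -> (1, 3)
out = t0_b + t1_b

print(out.shape)  # (3, 3), every element 2.0
```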

@jjsjann123
Collaborator

We shouldn't need to add squeeze; expand can work on a broadcast (size-1) dimension.

jjsjann123 added a commit that referenced this issue Jul 6, 2022
Fixes #1788

Added expand in broadcast_in_dim to support expanding to concrete size. Note that we are not supporting dynamic shape for concrete size at this moment.
naoyam added a commit that referenced this issue Jul 11, 2022
* Refactor TransformPropagator to allow specifying a position and propagating to part of the DAG (#1775)

`MaxInfoPropagator` is renamed to `MaxInfoSpanningTree`, it now only does path-finding, and the propagation is in a separate class `MaxInfoSpanningTree::Propagator`. Same for `MaxRootDomainInfoPropagator`.

`MaxInfoSpanningTree` and `MaxRootDomainInfoSpanningTree`  now allow specifying a selector, which controls which subgraph should be included in path-finding.

`MaxRootDomainInfoSpanningTree` also gets a few new constructors for convenience to use.

`TransformPropagator` is now a subclass of `MaxInfoSpanningTree::Propagator`, so the way to use it has changed.

Now `MaxInfoSpanningTree` and `MaxRootDomainInfoSpanningTree` will store the path after generation so that the same path can be traversed multiple times. This will be useful to support use cases like new `computeAt`. Pseudo-code:
```C++
void TensorView::computeAt(TensorView tv, int pos) {
  auto ComputeAtSubgraphSelector selector(this, tv);
  MaxRootDomainInfoSpanningTree path(tv, pos, &selector);
  TransformPropagator propagator(tv, pos);
  path.traverse(&propagator);
  ComputeAtPosPropagator ca_propagator(tv, pos);
  path.traverse(&ca_propagator);
}
```

* Revert scheduling changes. Cleanup only.

* Start drafting grid persistent kernels.

* Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)

Co-authored-by: Christian Sarofeen <csarofeen@nvidia.com>

* Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)

* Fix div(Val, TensorView) (#1778)

* Fix div(scalar, tensor)

* lintrunner: clang-format

* Adding sibling path for MaxInfoSpanningTree (#1776)

The sibling path is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector. For example, when the producer of a Welford is excluded from the propagation section. See test `FusionTransformPropagateSelectorSibling_CUDA` for a detailed example. Besides, since we know that siblings should be transformed exactly the same, the sibling path is a perfect next hop for preserving information.

If you want a spanning tree without a sibling path, you can override `allowSibling` to `return false` in your selector.

* Save.

* Disable register reuse across serial broadcast ops (#1787)

Disable memory aliasing for inner sharing across serial broadcast.

* Fix isIntegralType error msg (#1789)

* Transform propagator skip replay when possible (#1782)

This comment in the code describes what this PR is doing:

```C++
  // Note: [Using multiple TransformPropagators]
  // There are cases that we use multiple TransformPropagators along different
  // spanning trees with different references in the same fusion. Some of these
  // spanning trees could overlap. In cases when there are overlapping nodes,
  // TransformPropagator needs to respect the replay of others, because the
  // current TransformPropagator might not contain the most amount of
  // information on how to do the correct transformation. The logic below tells
  // TransformPropagator to skip the replay when not necessary.
```

* Output allocate patch (#1790)

Caching strides along with sizes. This is to support the current expand, which introduces non-contiguous output tensors.

* Add SpanningTreePrinter (#1786)

* New compute at interface (#1743)

Rewrite of the compute at pass to rely on the new propagation mechanisms.

* Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)

* Some further cleanup for the new computeAt interface (#1793)

Revert MaxProducerPosUpdater to old algo.

* Use TransformPropagatorWithCheck in many tests (#1795)

* validateDomain in TransformPropagator (#1796)

* InlinePropagator please don't replay (#1797)

This PR makes `InlinePropagator` just set compute-at positions. It will not replay any tensor. If you want to replay, please use `TransformPropagator` and friends to do so.

Currently, `InlinePropagator` already asserts no replay for standard and best-effort compute-at. So this PR is mostly about making most-inlined compute-at work as well.

This PR also does a lot of cleanups to remove the word "replay" from comments and variable and function names from `InlinePropagator`.

I also cleaned up `recordReplayedPos` and `retrieveReplayedPos`, now the logic is much easier to understand.

* Coding style cleanups (#1798)

Per offline discussion with @csarofeen, this PR does many renaming for better coding style: For all propagation-related things, I am now using the names `P2C` and `C2P` instead of `CasP` and `PasC`. Because "A as B" somewhat implies we want to replay A the same as B, but "B to A" sounds more general and is a better word for this case. Also, I modified the order of function arguments to match the order in its name. For example `PasC` should have `(producer, consumer)` or `(to, from)`, but not `(consumer, producer)` or `(from, to)`, and `C2P` should have `(consumer, producer)` or `(from, to)`, but not `(producer, consumer)` or `(to, from)`.

* Add parsing support for `_to_copy` to handle AMP casts. (#1756)

1. Add support for _to_copy() to support AMP casts.
2. refactored cast, accept none for dtype
3. python tests

Co-authored-by: jjsjann123 <jiej@nvidia.com>

* MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)

* Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)

Co-authored-by: Christian Sarofeen <csarofeen@nvidia.com>

* More cleanup on InlinePropagator (#1800)

I just realized that `InlinePropagator` can be further simplified because it no longer replays.

Since `InlinePropagator` is no longer doing replay, it is more like a "for each" problem rather than a propagation problem:

For each tensor `tv`, if we already know the max position of `tv` that is mapped to the reference tensor's selected outer dimensions (stored in `mapped_reference_pos_` in the code), setting the CA position is a very local operation, and is as simple as checking `tv` itself and all its consumers to determine the inline position.

`InlinePropagator` is not completely a "for each" problem only because the computation of `mapped_reference_pos_` is a propagation problem.

This cleanup reorganizes the code of `InlinePropagator` so it is clear that `InlinePropagator` is nothing but a two-step process:
Step 1: Do a propagation to find the `mapped_reference_pos_` for all tensors.
Step 2: For each tensor, check itself and its consumers to determine the CA position.

Conceptually, I would like to split step 1 with step 2. Because this split makes these concepts decoupled. Especially, this PR makes `mapped_reference_pos_` only contain info about the reference tensor, and is independent of the CA position (Currently, this is not true for best effort and most inlined computeAt without this PR). Now, in my view, `InlinePropagator` is conceptually very simple and easy to understand.

In terms of implementation, step 1 and step 2 can be interleaved, because we don't need to know the `mapped_reference_pos_` of `tv`'s consumers in order to compute the CA position of `tv`. So a one-pass traversal could do both steps altogether.

* Temporarily disable test requring large shared memory. (#1802)

* Grouping grid allreduces across iterations (#1755)

* Extend the grouped grid reduction kernel

The kernel itself should work with an arbitrary number of inputs, but
the underlying data structure, Tuple, still explicitly needs to be
specialized for the number of values, which is currently limited to 8.

* Check siblings in getMaxPosAll (#1805)

* remove dead indexing code (#1806)

* Broadcast in dim with expand (#1794)

Fixes #1788

Added expand in broadcast_in_dim to support expanding to concrete size. Note that we are not supporting dynamic shape for concrete size at this moment.

* spam nvrtc options (#1783)

TORCH_WARN on nvrtc debug option impacting performance.

Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
Co-authored-by: S. Song <41357537+shmsong@users.noreply.github.com>
Co-authored-by: Ivan Yashchuk <IvanYashchuk@users.noreply.github.com>
Co-authored-by: Sergey Lebedev <sergeyle@nvidia.com>
Co-authored-by: jjsjann123 <jiej@nvidia.com>
Co-authored-by: Kevin Stephano <kevin.stephano@gmail.com>
Co-authored-by: Naoya Maruyama <naoyam@users.noreply.github.com>