Fusion Segmenter asserts #1947

Closed · jjsjann123 opened this issue Aug 31, 2022 · 7 comments
@jjsjann123 (Collaborator)

🐛 Describe the bug

The same assert fires on a few Hugging Face benchmarks:

RuntimeError: h.has_value() INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp":2613, please report a bug to PyTorch.

Repro on devel

import torch

from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

t0 = torch.randn(2, 512, 128, device="cuda")
t1 = torch.randn(128, device="cuda")
t2 = torch.randn(2, 512, 1, device="cuda")  # alternatively (2, 512, 128)
t3 = torch.randn(2, 512, 128, device="cuda")
t4 = torch.randn(2, 512, 1, device="cuda")
t5 = torch.randn(2, 512, 128, device="cuda")
t6 = torch.randn(2, 512, 128, device="cuda")
t7 = torch.randn(2, 512, 128, device="cuda")
t8 = torch.randn(2, 512, 128, device="cuda")

def forward(arg331_1, arg329_1, arg213_1, arg346_1, arg211_1, _reshape_alias_default_2, arg199_1, arg348_1, arg226_1):
    mul_default_10 = torch.ops.nvprims.mul.default(arg331_1, arg331_1);  arg331_1 = None
    broadcast_in_dim_default_5 = torch.ops.nvprims.broadcast_in_dim.default(arg329_1, [2, 512, 128], [2]);  arg329_1 = None
    broadcast_in_dim_default_4 = torch.ops.nvprims.broadcast_in_dim.default(arg213_1, [2, 512, 128], [0, 1, 2])
    mul_default_14 = torch.ops.nvprims.mul.default(arg346_1, arg346_1);  arg346_1 = None
    broadcast_in_dim_default_3 = torch.ops.nvprims.broadcast_in_dim.default(arg211_1, [2, 512, 128], [0, 1, 2]);  arg211_1 = None
    div_default = torch.ops.nvprims.div.default(arg213_1, 128.0);  arg213_1 = None
    sub_default_4 = torch.ops.nvprims.sub.default(1.0, mul_default_10);  mul_default_10 = None
    mul_default_2 = torch.ops.nvprims.mul.default(_reshape_alias_default_2, broadcast_in_dim_default_5);  broadcast_in_dim_default_5 = None
    mul_default_15 = torch.ops.nvprims.mul.default(mul_default_14, 3.0);  mul_default_14 = None
    sub_default_1 = torch.ops.nvprims.sub.default(arg199_1, broadcast_in_dim_default_3);  arg199_1 = broadcast_in_dim_default_3 = None
    broadcast_in_dim_default_10 = torch.ops.nvprims.broadcast_in_dim.default(div_default, [2, 512, 128], [0, 1, 2]);  div_default = None
    mul_default_3 = torch.ops.nvprims.mul.default(mul_default_2, 128.0)
    convert_element_type_default_2 = torch.ops.nvprims.convert_element_type.default(mul_default_2, torch.float32)
    mul_default_1 = torch.ops.nvprims.mul.default(sub_default_1, broadcast_in_dim_default_4);  sub_default_1 = broadcast_in_dim_default_4 = None
    sum_default_2 = torch.ops.nvprims.sum.default(convert_element_type_default_2, [2]);  convert_element_type_default_2 = None
    mul_default_4 = torch.ops.nvprims.mul.default(mul_default_2, mul_default_1);  mul_default_2 = None
    mul_default_7 = torch.ops.nvprims.mul.default(_reshape_alias_default_2, mul_default_1);  _reshape_alias_default_2 = None
    broadcast_in_dim_default_6 = torch.ops.nvprims.broadcast_in_dim.default(sum_default_2, [2, 512, 1], [0, 1]);  sum_default_2 = None
    convert_element_type_default_3 = torch.ops.nvprims.convert_element_type.default(mul_default_4, torch.float32);  mul_default_4 = None
    convert_element_type_default_4 = torch.ops.nvprims.convert_element_type.default(mul_default_7, torch.float32);  mul_default_7 = None
    broadcast_in_dim_default_9 = torch.ops.nvprims.broadcast_in_dim.default(broadcast_in_dim_default_6, [2, 512, 128], [0, 1, 2]);  broadcast_in_dim_default_6 = None
    sum_default_3 = torch.ops.nvprims.sum.default(convert_element_type_default_3, [2]);  convert_element_type_default_3 = None
    sum_default_4 = torch.ops.nvprims.sum.default(convert_element_type_default_4, [0, 1]);  convert_element_type_default_4 = None
    sub_default_2 = torch.ops.nvprims.sub.default(mul_default_3, broadcast_in_dim_default_9);  mul_default_3 = broadcast_in_dim_default_9 = None
    broadcast_in_dim_default_7 = torch.ops.nvprims.broadcast_in_dim.default(sum_default_3, [2, 512, 1], [0, 1]);  sum_default_3 = None
    broadcast_in_dim_default_8 = torch.ops.nvprims.broadcast_in_dim.default(broadcast_in_dim_default_7, [2, 512, 128], [0, 1, 2]);  broadcast_in_dim_default_7 = None
    mul_default_5 = torch.ops.nvprims.mul.default(mul_default_1, broadcast_in_dim_default_8);  mul_default_1 = broadcast_in_dim_default_8 = None
    sub_default_3 = torch.ops.nvprims.sub.default(sub_default_2, mul_default_5);  sub_default_2 = mul_default_5 = None
    mul_default_6 = torch.ops.nvprims.mul.default(broadcast_in_dim_default_10, sub_default_3);  broadcast_in_dim_default_10 = sub_default_3 = None
    mul_default_8 = torch.ops.nvprims.mul.default(mul_default_6, arg348_1);  arg348_1 = None
    mul_default_9 = torch.ops.nvprims.mul.default(mul_default_6, arg226_1);  mul_default_6 = arg226_1 = None
    mul_default_11 = torch.ops.nvprims.mul.default(mul_default_8, sub_default_4);  mul_default_8 = sub_default_4 = None
    mul_default_17 = torch.ops.nvprims.mul.default(mul_default_9, 0.5);  mul_default_9 = None
    mul_default_12 = torch.ops.nvprims.mul.default(mul_default_11, 0.7978845608028654);  mul_default_11 = None
    mul_default_13 = torch.ops.nvprims.mul.default(mul_default_12, 0.044715)
    mul_default_16 = torch.ops.nvprims.mul.default(mul_default_13, mul_default_15);  mul_default_13 = mul_default_15 = None
    add_default_1 = torch.ops.nvprims.add.default(mul_default_12, mul_default_16);  mul_default_12 = mul_default_16 = None
    add_default_2 = torch.ops.nvprims.add.default(add_default_1, mul_default_17);  add_default_1 = mul_default_17 = None
    return (sum_default_4, add_default_2)

gm = make_fx(forward)(t0, t1, t2, t3, t4, t5, t6, t7, t8)
print(gm.graph)
execute(gm, t0, t1, t2, t3, t4, t5, t6, t7, t8, executor="nvfuser")
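
To confirm the assert is specific to the nvfuser executor rather than the traced graph itself, one could run the same graph through the reference executor. A minimal sanity-check sketch, assuming this build's execute() also accepts executor="aten" (the default reference executor name around this time):

# Assumption: executor="aten" runs the graph through the eager/reference path,
# so no fusion segmenter is involved. If this succeeds while the nvfuser call
# above asserts, the problem is isolated to the nvfuser executor.
ref_out = execute(gm, t0, t1, t2, t3, t4, t5, t6, t7, t8, executor="aten")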

Versions

You'll need ToT (tip-of-tree) devel, since some nvprims changes were merged in the last master pull. Here's the commit where I got the repro:

commit ac4de38c6ee53b366e85fdfe408c3642d32b57df (HEAD, origin/devel, origin/HEAD)
Merge: 631094891a aab10bce45
Author: Christian Sarofeen <csarofeen@nvidia.com>
Date:   Tue Aug 30 15:44:39 2022 -0400

    Merge pull request #1945 from csarofeen/master_merge_0828
    
    Master merge 0828
@jjsjann123 (Collaborator, Author)

I can probably minify the repro or turn it into a C++ test instead of going through nvprims.

cc'ing @shmsong in case this rings a bell.

@IvanYashchuk (Collaborator) commented Aug 31, 2022

Here's a code snippet I extracted from torchbench's timm_vision_transformer with the same assert:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.nvfuser_executor import nvfuser_execute

# arg.shape 0: torch.Size([8, 197, 1])
# arg.shape 1: torch.Size([8, 197, 1])
# arg.shape 2: torch.Size([384])
# arg.shape 3: torch.Size([8, 197, 384])
# arg.shape 4: torch.Size([8, 197, 384])
# arg.shape 5: torch.Size([8, 197, 384])
arg263_1 = torch.randn(8, 197, 1, dtype=torch.float32, device='cuda')
arg202_1 = torch.randn(8, 197, 1, dtype=torch.float32, device='cuda')
arg145_1 = torch.randn(384, dtype=torch.float32, device='cuda')
arg291_1 = torch.randn(8, 197, 384, dtype=torch.float32, device='cuda')
_reshape_alias_3 = torch.randn(8, 197, 384, dtype=torch.float32, device='cuda')
mul_5 = torch.randn(8, 197, 384, dtype=torch.float32, device='cuda')

def forward(self, arg263_1, arg202_1, arg145_1, arg291_1, _reshape_alias_3, mul_5):
    div_1 = torch.ops.nvprims.div.default(arg263_1, 384.0)
    broadcast_in_dim_11 = torch.ops.nvprims.broadcast_in_dim.default(arg202_1, [8, 197, 384], [0, 1, 2]);  arg202_1 = None
    broadcast_in_dim_12 = torch.ops.nvprims.broadcast_in_dim.default(arg263_1, [8, 197, 384], [0, 1, 2]);  arg263_1 = None
    broadcast_in_dim_13 = torch.ops.nvprims.broadcast_in_dim.default(arg145_1, [8, 197, 384], [2]);  arg145_1 = None
    broadcast_in_dim_18 = torch.ops.nvprims.broadcast_in_dim.default(div_1, [8, 197, 384], [0, 1, 2]);  div_1 = None
    sub_3 = torch.ops.nvprims.sub.default(arg291_1, broadcast_in_dim_11);  arg291_1 = broadcast_in_dim_11 = None
    mul_15 = torch.ops.nvprims.mul.default(_reshape_alias_3, broadcast_in_dim_13);  broadcast_in_dim_13 = None
    mul_14 = torch.ops.nvprims.mul.default(sub_3, broadcast_in_dim_12);  sub_3 = broadcast_in_dim_12 = None
    mul_16 = torch.ops.nvprims.mul.default(mul_15, 384.0)
    convert_element_type_7 = torch.ops.nvprims.convert_element_type.default(mul_15, torch.float32)
    mul_17 = torch.ops.nvprims.mul.default(mul_15, mul_14);  mul_15 = None
    mul_20 = torch.ops.nvprims.mul.default(_reshape_alias_3, mul_14);  _reshape_alias_3 = None
    sum_8 = torch.ops.nvprims.sum.default(convert_element_type_7, [2]);  convert_element_type_7 = None
    convert_element_type_8 = torch.ops.nvprims.convert_element_type.default(mul_17, torch.float32);  mul_17 = None
    convert_element_type_9 = torch.ops.nvprims.convert_element_type.default(mul_20, torch.float32);  mul_20 = None
    broadcast_in_dim_14 = torch.ops.nvprims.broadcast_in_dim.default(sum_8, [8, 197, 1], [0, 1]);  sum_8 = None
    sum_9 = torch.ops.nvprims.sum.default(convert_element_type_8, [2]);  convert_element_type_8 = None
    sum_10 = torch.ops.nvprims.sum.default(convert_element_type_9, [0, 1]);  convert_element_type_9 = None
    broadcast_in_dim_17 = torch.ops.nvprims.broadcast_in_dim.default(broadcast_in_dim_14, [8, 197, 384], [0, 1, 2]);  broadcast_in_dim_14 = None
    broadcast_in_dim_15 = torch.ops.nvprims.broadcast_in_dim.default(sum_9, [8, 197, 1], [0, 1]);  sum_9 = None
    sub_4 = torch.ops.nvprims.sub.default(mul_16, broadcast_in_dim_17);  mul_16 = broadcast_in_dim_17 = None
    broadcast_in_dim_16 = torch.ops.nvprims.broadcast_in_dim.default(broadcast_in_dim_15, [8, 197, 384], [0, 1, 2]);  broadcast_in_dim_15 = None
    mul_18 = torch.ops.nvprims.mul.default(mul_14, broadcast_in_dim_16);  mul_14 = broadcast_in_dim_16 = None
    sub_5 = torch.ops.nvprims.sub.default(sub_4, mul_18);  sub_4 = mul_18 = None
    mul_19 = torch.ops.nvprims.mul.default(broadcast_in_dim_18, sub_5);  broadcast_in_dim_18 = sub_5 = None
    add_2 = torch.ops.nvprims.add.default(mul_5, mul_19);  mul_5 = mul_19 = None
    return (sum_10, add_2)

gm = make_fx(forward)(None, arg263_1, arg202_1, arg145_1, arg291_1, _reshape_alias_3, mul_5)

try:
    nv_output = nvfuser_execute(gm, None, arg263_1, arg202_1, arg145_1, arg291_1, _reshape_alias_3, mul_5)
except RuntimeError as e:
    print(e)
    print("nvfuser_execute failed")

I tried to run it through functorch.compile.minifier but its "minimal" reproducer was exactly the same.
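
For context, a rough sketch of how such a minifier run could be wired up, assuming the functorch.compile.minifier(fail_f, inps, module_fails) interface exposed around this PyTorch version (the exact signature may differ in this build, and the leading None "self" placeholder may need special handling):

# Hypothetical sketch, not verified against this exact build.
from functorch.compile import minifier

def module_fails(fx_g, example_inps):
    # Treat a graph as "failing" when the nvfuser executor raises the segmenter assert.
    try:
        nvfuser_execute(fx_g, *example_inps)
        return False
    except RuntimeError:
        return True

minifier(gm, [None, arg263_1, arg202_1, arg145_1, arg291_1, _reshape_alias_3, mul_5], module_fails)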

@shmsong commented Sep 1, 2022

#1949 is a WAR (workaround) if this needs unblocking. @jjsjann123 could you help confirm whether it's enough to fix this?

@shmsong commented Sep 1, 2022

#1950 tracks the root cause; moving the discussion there.

@shmsong commented Sep 1, 2022

Long story short: canSchedule gives different results for the same fusion because the output ordering differs between runs. The WAR essentially keeps the output ordering the same throughout the segmenter pass, which isn't a real fix for the underlying issue but should help avoid it.

The real root cause, the root domain mapping giving different results, still definitely needs fixing.
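
To make the failure mode concrete, a toy illustration in plain Python (illustrative names only, not nvfuser internals) of how an order-sensitive check can accept or reject the same logical set of outputs, and how fixing the iteration order sidesteps it, which is roughly what the WAR does for the segmenter:

# Toy example only; not nvfuser code.
def can_schedule(outputs):
    # Pretend the heuristic keys off whichever output it happens to inspect first.
    return not outputs[0]["is_reduction"]

outs_a = [{"name": "t2", "is_reduction": False}, {"name": "t1", "is_reduction": True}]
outs_b = list(reversed(outs_a))  # same outputs, different order

print(can_schedule(outs_a))  # True
print(can_schedule(outs_b))  # False -> same fusion, different verdict

# WAR: impose a deterministic order before asking the question.
print(can_schedule(sorted(outs_a, key=lambda o: o["name"])))  # stable answer
print(can_schedule(sorted(outs_b, key=lambda o: o["name"])))  # same stable answer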

@jjsjann123 (Collaborator, Author)

Thanks for the quick turnaround. Looks like Ivan already verified the fix in #1949.

We see quite a lot of benchmarks hitting the same failure, though most of the patterns look very similar. We might want to merge #1949 as-is and kick off a benchmark run to see whether all the fusion segmenter issues go away.

@zasdfgbnm mentioned this issue Sep 1, 2022
@csarofeen (Owner)

Fixed in #1952

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Sep 27, 2022
Fixes csarofeen#1947

Cherry-picked patch for torchbench issues where fusion segmenter asserts in nvfuser:
1. Test that the groups come in the same order as they are merged.
2. Fix detection of un-mappable root domains:
    ComputeAtRootDomainMap flags domains that should not be mapped due to
    reductions. Previously, checking whether a domain potentially causes an
    invalid mapping was only done with one domain in each group of domains
    found to be mappable so far. That's not actually sufficient, as the
    unmappable domain set is created just once with no root mapping
    information. The fix is to check all consumer domains of a producer
    tensor. Another small fix addresses a different problem discovered
    after the first fix.

Pull Request resolved: #85620
Approved by: https://github.com/csarofeen, https://github.com/davidberard98
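
A rough paraphrase of point 2 above as a toy example (hypothetical names, not the actual ComputeAtRootDomainMap code): checking only one representative domain per group can miss a domain that was flagged as unmappable, whereas checking every domain catches it.

# Toy example only; shows why one representative per group is not enough.
unmappable = {"r_dom"}  # domains flagged because of reductions

# Groups of domains considered mappable so far; the flagged domain hides
# behind the representative of the first group.
groups = [["i_dom", "r_dom"], ["j_dom"]]

old_ok = all(g[0] not in unmappable for g in groups)          # representatives only
new_ok = all(d not in unmappable for g in groups for d in g)  # every domain

print(old_ok, new_ok)  # True False -> the old check wrongly accepts the mapping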
mehtanirav pushed a commit to pytorch/pytorch that referenced this issue Oct 4, 2022 (same patch as above)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this issue Oct 29, 2022 (same patch as above)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this issue Nov 10, 2022 (same patch as above)