Fusion Segmenter asserts #1947
I can probably minify the repro or get it into a cpp test instead of going through nvprim. cc'ing @shmsong in case this rings any bell~~
Here's a code snippet I extracted from a torchbench model:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.nvfuser_executor import nvfuser_execute

# Input shapes:
# arg.shape 0: torch.Size([8, 197, 1])
# arg.shape 1: torch.Size([8, 197, 1])
# arg.shape 2: torch.Size([384])
# arg.shape 3: torch.Size([8, 197, 384])
# arg.shape 4: torch.Size([8, 197, 384])
# arg.shape 5: torch.Size([8, 197, 384])
arg263_1 = torch.randn(8, 197, 1, dtype=torch.float32, device='cuda')
arg202_1 = torch.randn(8, 197, 1, dtype=torch.float32, device='cuda')
arg145_1 = torch.randn(384, dtype=torch.float32, device='cuda')
arg291_1 = torch.randn(8, 197, 384, dtype=torch.float32, device='cuda')
_reshape_alias_3 = torch.randn(8, 197, 384, dtype=torch.float32, device='cuda')
mul_5 = torch.randn(8, 197, 384, dtype=torch.float32, device='cuda')

# FX-style function built from the extracted nvprims ops.
def forward(self, arg263_1, arg202_1, arg145_1, arg291_1, _reshape_alias_3, mul_5):
    div_1 = torch.ops.nvprims.div.default(arg263_1, 384.0)
    broadcast_in_dim_11 = torch.ops.nvprims.broadcast_in_dim.default(arg202_1, [8, 197, 384], [0, 1, 2]); arg202_1 = None
    broadcast_in_dim_12 = torch.ops.nvprims.broadcast_in_dim.default(arg263_1, [8, 197, 384], [0, 1, 2]); arg263_1 = None
    broadcast_in_dim_13 = torch.ops.nvprims.broadcast_in_dim.default(arg145_1, [8, 197, 384], [2]); arg145_1 = None
    broadcast_in_dim_18 = torch.ops.nvprims.broadcast_in_dim.default(div_1, [8, 197, 384], [0, 1, 2]); div_1 = None
    sub_3 = torch.ops.nvprims.sub.default(arg291_1, broadcast_in_dim_11); arg291_1 = broadcast_in_dim_11 = None
    mul_15 = torch.ops.nvprims.mul.default(_reshape_alias_3, broadcast_in_dim_13); broadcast_in_dim_13 = None
    mul_14 = torch.ops.nvprims.mul.default(sub_3, broadcast_in_dim_12); sub_3 = broadcast_in_dim_12 = None
    mul_16 = torch.ops.nvprims.mul.default(mul_15, 384.0)
    convert_element_type_7 = torch.ops.nvprims.convert_element_type.default(mul_15, torch.float32)
    mul_17 = torch.ops.nvprims.mul.default(mul_15, mul_14); mul_15 = None
    mul_20 = torch.ops.nvprims.mul.default(_reshape_alias_3, mul_14); _reshape_alias_3 = None
    sum_8 = torch.ops.nvprims.sum.default(convert_element_type_7, [2]); convert_element_type_7 = None
    convert_element_type_8 = torch.ops.nvprims.convert_element_type.default(mul_17, torch.float32); mul_17 = None
    convert_element_type_9 = torch.ops.nvprims.convert_element_type.default(mul_20, torch.float32); mul_20 = None
    broadcast_in_dim_14 = torch.ops.nvprims.broadcast_in_dim.default(sum_8, [8, 197, 1], [0, 1]); sum_8 = None
    sum_9 = torch.ops.nvprims.sum.default(convert_element_type_8, [2]); convert_element_type_8 = None
    sum_10 = torch.ops.nvprims.sum.default(convert_element_type_9, [0, 1]); convert_element_type_9 = None
    broadcast_in_dim_17 = torch.ops.nvprims.broadcast_in_dim.default(broadcast_in_dim_14, [8, 197, 384], [0, 1, 2]); broadcast_in_dim_14 = None
    broadcast_in_dim_15 = torch.ops.nvprims.broadcast_in_dim.default(sum_9, [8, 197, 1], [0, 1]); sum_9 = None
    sub_4 = torch.ops.nvprims.sub.default(mul_16, broadcast_in_dim_17); mul_16 = broadcast_in_dim_17 = None
    broadcast_in_dim_16 = torch.ops.nvprims.broadcast_in_dim.default(broadcast_in_dim_15, [8, 197, 384], [0, 1, 2]); broadcast_in_dim_15 = None
    mul_18 = torch.ops.nvprims.mul.default(mul_14, broadcast_in_dim_16); mul_14 = broadcast_in_dim_16 = None
    sub_5 = torch.ops.nvprims.sub.default(sub_4, mul_18); sub_4 = mul_18 = None
    mul_19 = torch.ops.nvprims.mul.default(broadcast_in_dim_18, sub_5); broadcast_in_dim_18 = sub_5 = None
    add_2 = torch.ops.nvprims.add.default(mul_5, mul_19); mul_5 = mul_19 = None
    return (sum_10, add_2)

# Trace the function with real CUDA inputs, then run the traced graph
# through the nvFuser executor; this is where the segmenter assert fires.
gm = make_fx(forward)(None, arg263_1, arg202_1, arg145_1, arg291_1, _reshape_alias_3, mul_5)
try:
    nv_output = nvfuser_execute(gm, None, arg263_1, arg202_1, arg145_1, arg291_1, _reshape_alias_3, mul_5)
except RuntimeError as e:
    print(e)
    print("nvfuser_execute failed")
```

I tried to run it through
#1949 is a WAR (workaround) if this needs unblocking. @jjsjann123, could you help confirm whether it's enough to fix the issue?
#1950 is the root cause; moving the discussion there.
Long story short, the root cause is that the root domain mapping gives different results, and that definitely needs fixing.
Thx for the quick turnaround~~ Looks like Ivan already verified the fix in #1949. We see quite a lot of benchmarks hitting the same failure, and most of the patterns look very similar. We might want to merge #1949 as-is and kick off a benchmark run to see if all the fusion segmenter issues go away.
Fixed in #1952
Fixes csarofeen/pytorch#1947. Cherry-picked patch for the torchbench issues where the fusion segmenter asserts in nvFuser:
1. Test that the groups come in the same order as they are merged.
2. Fix detection of un-mappable root domains: ComputeAtRootDomainMap flags domains that should not be mapped due to reductions. Previously, checking whether a domain potentially causes an invalid mapping was done with only one domain in each group of domains found to be mappable so far. That is not sufficient, since the unmappable domain set is created just once, with no root mapping information. The fix is to check all consumer domains of a producer tensor. A small additional fix addresses a different problem discovered after the first fix.
Pull Request resolved: pytorch/pytorch#85620. Approved by: https://github.com/csarofeen, https://github.com/davidberard98
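To make point 2 above concrete, here is a deliberately simplified, hypothetical Python sketch of the behaviour difference. It is not the actual ComputeAtRootDomainMap C++ code and every name in it is invented; it only illustrates why checking a single representative domain per group can let an invalid mapping through, while checking every consumer domain of the producer rejects it.

```python
# Toy model, not nvFuser code: domains are plain strings, and `unmappable`
# holds (producer, consumer) domain pairs that must never be mapped,
# e.g. because a reduction sits between them.
unmappable = {("p0", "c2")}

def can_merge_buggy(producer_domain, mappable_group):
    # Old behaviour (simplified): only one representative of the already
    # mappable group is checked, so a problematic consumer elsewhere in
    # the group is never seen.
    representative = mappable_group[0]
    return (producer_domain, representative) not in unmappable

def can_merge_fixed(producer_domain, consumer_domains):
    # Fixed behaviour (simplified): every consumer domain of the producer
    # tensor is checked against the unmappable set.
    return all((producer_domain, c) not in unmappable for c in consumer_domains)

group = ["c0", "c1", "c2"]           # "c2" is the consumer that must not be mapped
print(can_merge_buggy("p0", group))  # True  -> invalid merge slips through
print(can_merge_fixed("p0", group))  # False -> the merge is correctly rejected
```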
🐛 Describe the bug
The same assert is hit on a few Hugging Face benchmarks.
Repro on devel
Versions
You'll need top-of-tree (ToT) devel, since some nvprim changes were merged in the last master pull. Here's the commit where I got the repro:
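Since the exact commit matters here, a quick way (my own suggestion, not from the original report) to confirm which commit a build came from, and that the nvFuser executor entry point used by the repro is present, is:

```python
import torch

# git_version is the commit hash the build was produced from; compare it
# against ToT devel to make sure the recent nvprim changes are included.
print(torch.__version__, torch.version.cuda, torch.version.git_version)

# The repro imports this entry point; if the import fails, the build does
# not include the nvprims/nvFuser executor pieces the snippet relies on.
from torch._prims.nvfuser_executor import nvfuser_execute  # noqa: F401
print("nvfuser_execute is importable")
```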