Silent wrong result on broadcasting with split and merge #1880
Generated code:

```cuda
__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 1> T1, Tensor<float, 2> T4) {
NVFUSER_DEFINE_MAGIC_ZERO
#pragma unroll 1
for(nvfuser_index_t i64 = 0; i64 < (ceilDiv((ceilDiv(T0.size[0], 32)), 1)); ++i64) {
if (((((((((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8))) && (((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8)))) && (((i64 * 32) + (((((ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8)) * 7) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x))) < T0.size[0])) && ((((ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8)) * 7) + ((nvfuser_index_t)threadIdx.y)) < (ceilDiv((ceilDiv(32, 1)), 32)))) && (((i64 * 32) + ((((((ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8)) * 7) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) / 32)) < T0.size[0])) && (((((ceilDiv(T0.size[1], 32)) - 1) * 32) + ((((((ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8)) * 7) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) % 32)) < T0.size[1]))) {
float T2[(8 * 1)];
#pragma unroll
for(nvfuser_index_t i55 = 0; i55 < 8; ++i55) {
T2[i55] = 0;
}
NVFUSER_UPDATE_MAGIC_ZERO
#pragma unroll
for(nvfuser_index_t i55 = 0; i55 < 8; ++i55) {
T2[i55]
= T1[(((i64 * 32) + (((((i55 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x))) * T1.stride[0])];
}
NVFUSER_UPDATE_MAGIC_ZERO
#pragma unroll 1
for(nvfuser_index_t i66 = 0; i66 < (ceilDiv(T0.size[1], 32)); ++i66) {
#pragma unroll
for(nvfuser_index_t i67 = 0; i67 < 8; ++i67) {
int64_t i120;
i120 = (i64 * 32) + ((((((i67 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) / 32);
int64_t i117;
i117 = (i66 * 32) + ((((((i67 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) % 32);
float T3[1];
T3[0]
= T2[i67];
T4[(i120 * T0.size[1]) + i117]
= T3[0]
+ T0[(i120 * T0.stride[0]) + (i117 * T0.stride[1])];
}
NVFUSER_UPDATE_MAGIC_ZERO
}
} else {
float T2[(8 * 1)];
#pragma unroll
for(nvfuser_index_t i55 = 0; i55 < 8; ++i55) {
if ((((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8)))) {
T2[i55] = 0;
}
}
NVFUSER_UPDATE_MAGIC_ZERO
#pragma unroll
for(nvfuser_index_t i55 = 0; i55 < 8; ++i55) {
int64_t i170;
i170 = (i64 * 32) + (((((i55 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x));
if ((((i170 < T0.size[0]) && ((((i55 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) < (ceilDiv((ceilDiv(32, 1)), 32)))) && (((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8))))) {
T2[i55]
= T1[(i170 * T1.stride[0])];
}
}
NVFUSER_UPDATE_MAGIC_ZERO
#pragma unroll 1
for(nvfuser_index_t i66 = 0; i66 < (ceilDiv(T0.size[1], 32)); ++i66) {
#pragma unroll
for(nvfuser_index_t i67 = 0; i67 < 8; ++i67) {
int64_t i201;
i201 = (i64 * 32) + ((((((i67 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) / 32);
int64_t i198;
i198 = (i66 * 32) + ((((((i67 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) % 32);
float T3[1];
T3[0]
= T2[i67];
if ((((i201 < T0.size[0]) && (i198 < T0.size[1])) && (((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8))))) {
T4[(i201 * T0.size[1]) + i198]
= T3[0]
+ T0[(i201 * T0.stride[0]) + (i198 * T0.stride[1])];
}
}
NVFUSER_UPDATE_MAGIC_ZERO
}
}
}
}
```

This predicate is wrong:

```cuda
(((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8)))
```

It should be …

And the index is also wrong:

```cuda
T1[(((i64 * 32) + (((((i55 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x))) * T1.stride[0])];
```

There should be a …
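To see why the two predicate forms in this kernel don't line up, here is the ceilDiv arithmetic. This is a minimal standalone sketch, with `ceilDiv` assumed to be `(a + b - 1) / b` as in the generated code:

```cpp
#include <cstdio>

// Assumed definition, matching the ceilDiv in the kernel above.
constexpr long ceilDiv(long a, long b) { return (a + b - 1) / b; }

int main() {
  // Bound used for the unrolled T1 load, derived from T2's local extent (32):
  printf("%ld\n", ceilDiv(ceilDiv(ceilDiv(32, 1), 32), 8));       // prints 1
  // Bound used for the merged 32*32 domain, e.g. in the T4 store predicate:
  printf("%ld\n", ceilDiv(ceilDiv(ceilDiv(32 * 32, 1), 32), 8));  // prints 4
  return 0;
}
```

If that reading is right, threads with `threadIdx.y` in `[1, 4)` skip both the zero-initialization and the load of T2 in the `else` branch, yet still pass the T4 store predicate, which is consistent with the silent wrong result.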
I am not very familiar with index and predicate calculation, but my guess is that, when unrolled, the code that reads T1 is generated based on T2. But neither T1 nor T2 has the full information about the underlying transformation: the broadcast dimension only appears at T3. So looking locally at the transformations of T1 and T2 when generating code is not sufficient; we should use something with more complete information for predicate and index calculation.
Just curious, does it fail as well if you put all the scheduling on tv4 instead of tv3?
I see, it's the unroll on the inner dimension that was the problem, so it would still fail, I guess.
Yes, it does.
Problem seems to come from the outer split. Could reproduce the failure with:

```cpp
auto tv0 = makeConcreteTensor({32, 2});
auto tv1 = makeConcreteTensor({32});
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv2 = set(tv1);
auto tv3 = broadcast(tv2, {false, true});
auto tv4 = add(tv3, tv0);
fusion.addOutput(tv4);
tv3->merge(0);
tv3->split(0, 8, false);
MaxRootDomainInfoSpanningTree tree(tv3);
TransformPropagator tp(tv3);
tree.traverse(&tp);
InlinePropagator inline_propagator(tv3, -2);
tree.traverse(&inline_propagator);
```

The IterDomains to the right of the CA axes after the outer split are implicitly concretized, so we would need to bind the concrete info, but they are not loop-mapped to other concretized loops.
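For reference, a sketch of how I read the two split flavors; the loop-nest shape and indexing below are my assumptions, matching how the generated kernels above compute indices:

```cpp
// Inner split by f:  [N] -> [ceilDiv(N, f), f]   (inner extent is the factor)
// Outer split by f:  [N] -> [f, ceilDiv(N, f)]   (outer extent is the factor)
constexpr long ceilDiv(long a, long b) { return (a + b - 1) / b; }

// In the repro, tv3->merge(0) flattens {32, 2} into extent 64, and
// tv3->split(0, 8, /*inner_split=*/false) gives loops [8, ceilDiv(64, 8)]:
void outerSplitLoopNest() {
  constexpr long N = 64, f = 8;
  for (long o = 0; o < f; ++o) {                // outer extent: the factor
    for (long i = 0; i < ceilDiv(N, f); ++i) {  // inner extent: ceilDiv(N, f)
      long idx = o * ceilDiv(N, f) + i;         // flattened index
      if (idx < N) { /* predicated body */ }    // needed when f doesn't divide N
    }
  }
}
```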
Looks like the outer split might not be the root cause of the problem. I removed the outer split, but the problem is still there:

```cpp
TEST_F(NVFuserTest, FusionBroadcastingIndexing_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeSymbolicTensor(2);
auto tv1 = makeSymbolicTensor(1);
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv2 = set(tv1);
auto tv3 = broadcast(tv2, {false, true});
auto tv4 = add(tv3, tv0);
fusion.addOutput(tv4);
tv3->split(0, 32);
tv3->reorder({{1, -1}});
tv3->split(1, 32);
tv3->reorder({{2, -1}});
tv3->merge(2);
tv3->split(2, 1);
tv3->split(2, 128);
tv3->axis(-2)->parallelize(ParallelType::TIDx);
MaxRootDomainInfoSpanningTree tree(tv3);
TransformPropagator tp(tv3);
tree.traverse(&tp);
scheduler_utils::parallelizeAllLike(tv3);
tv2->axis(-3)->parallelize(ParallelType::Unroll);
InlinePropagator inline_propagator(tv3, -1, ComputeAtMode::MostInlined);
tree.traverse(&inline_propagator);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input0 = at::arange(64, options).view({32, 2});
at::Tensor input1 = at::arange(32, options) * 0.01;
fusion.printMath();
fusion.print();
fusion.printKernel();
FusionExecutor fe;
fe.compileFusion(&fusion, {input0, input1});
auto outputs = fe.runFusion({input0, input1});
std::cout << outputs[0] << std::endl;
auto tv_ref = input0 + input1.unsqueeze(1);
testValidate(
&fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
}
```
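For reference, here is my hedged trace of tv3's leaf domain through the schedule above (treat it as a sketch; the axis bookkeeping is easy to get wrong, and `/` denotes ceilDiv):

```cpp
// root:            [I0, B1]
// split(0, 32):    [I0/32, 32, B1]
// reorder {1,-1}:  [I0/32, B1, 32]
// split(1, 32):    [I0/32, B1/32, 32b, 32]      (32b comes from the broadcast)
// reorder {2,-1}:  [I0/32, B1/32, 32, 32b]
// merge(2):        [I0/32, B1/32, 32*32b]       // where the ceilDiv((32 * 32), 1)-
// split(2, 1):     [I0/32, B1/32, (32*32b)/1, 1]   // style terms come from
// split(2, 128):   [I0/32, B1/32, ((32*32b)/1)/128, 128, 1]
//                                                ^ axis(-2), extent 128 -> TIDx
```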
Removing … — the difference seems to be on the ca_pos of T2:

vvv This one failed: …

vvv This one passed: …

Sorry, didn't paste the right printout earlier. Still looking for the root cause.
Looks like the inner split of an inner broadcast has a similar issue:

```cuda
// This allocation size wouldn't be safe; hard to know exactly how many
// elements of T2 we'd need without the concrete sizes.
float T2[((ceilDiv((ceilDiv(32, 1)), 128)) * 1)];
#pragma unroll
for(nvfuser_index_t i41 = 0; i41 < (ceilDiv((ceilDiv(32, 1)), 128)); ++i41) {
T2[i41] = 0;
}
NVFUSER_UPDATE_MAGIC_ZERO
#pragma unroll
for(nvfuser_index_t i41 = 0; i41 < (ceilDiv((ceilDiv(32, 1)), 128)); ++i41) {
int64_t i71;
i71 = (i51 * 32) + (((i41 + nvfuser_zero) * 64) + ((nvfuser_index_t)threadIdx.x));
if (((i71 < T0.size[0]) && ((((i41 + nvfuser_zero) * 64) + ((nvfuser_index_t)threadIdx.x)) < (ceilDiv(32, 1))))) {
T2[i41]
= T1[(i71 * T1.stride[0])];
}
}
NVFUSER_UPDATE_MAGIC_ZERO
float T3[((ceilDiv((ceilDiv((32 * 32), 1)), 64)) * 1)];
#pragma unroll
for(nvfuser_index_t i44 = 0; i44 < (ceilDiv((ceilDiv((32 * 32), 1)), 64)); ++i44) {
// Indexing of T2 would need to be fixed, but also depending on how T2 is allocated.
T3[i44]
= T2[i44];
}
```
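To make the allocation concern concrete, the extents work out as follows (same assumed `ceilDiv(a, b) = (a + b - 1) / b`):

```cpp
constexpr long ceilDiv(long a, long b) { return (a + b - 1) / b; }

// T2's allocation extent, derived from its own un-merged view of the domain:
static_assert(ceilDiv(ceilDiv(32, 1), 128) * 1 == 1,
              "T2 is allocated a single element");
// The T3 copy loop's trip count, derived from the merged 32*32 view:
static_assert(ceilDiv(ceilDiv(32 * 32, 1), 64) == 16,
              "the loop reads T2[0] through T2[15]");
// So "T3[i44] = T2[i44]" overruns a one-element buffer: the allocation and
// the indexing come from two inconsistent views of the same broadcast domain.
```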
Other related repros:

```cpp
TEST_F(NVFuserTest, FusionBroadcastingIndexingOuter_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = TensorViewBuilder().shape({6, 5}).dtype(DataType::Float).contiguity({true, true}).build();
auto tv1 = TensorViewBuilder().shape({6}).dtype(DataType::Float).contiguity({true}).build();
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv2 = set(tv1);
auto tv3 = broadcast(tv2, {false, true});
auto tv4 = add(tv3, tv0);
fusion.addOutput(tv4);
tv3->merge(0);
tv3->split(0, 4, false);
MaxRootDomainInfoSpanningTree tree(tv3);
TransformPropagator tp(tv3);
tree.traverse(&tp);
auto inline_propagator = InlinePropagator(tv3, 1, ComputeAtMode::BestEffort);
tree.traverse(&inline_propagator);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input0 = at::arange(6*5, options).view({6, 5});
at::Tensor input1 = at::arange(6, options) * 0.01;
fusion.printMath();
fusion.print();
fusion.printKernel();
FusionExecutor fe;
fe.compileFusion(&fusion, {input0, input1});
auto outputs = fe.runFusion({input0, input1});
std::cout << outputs[0] << std::endl;
auto tv_ref = input0 + input1.unsqueeze(1);
testValidate(
&fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
}
```

Setting T2 and T3 seems right, but accessing T3 seems wrong.

```cpp
TEST_F(NVFuserTest, FusionBroadcastingIndexing_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeSymbolicTensor(2);
auto tv1 = makeSymbolicTensor(1);
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv2 = set(tv1);
auto tv3 = broadcast(tv2, {false, true});
auto tv4 = add(tv3, tv0);
fusion.addOutput(tv4);
tv3->split(0, 32);
tv3->reorder({{1, -1}});
tv3->split(1, 32);
tv3->reorder({{2, -1}});
// [Io0, Bo1, Ii0(32), Bi1(32)]
tv3->merge(2);
tv3->split(2, 128);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
MaxRootDomainInfoSpanningTree tree(tv3);
TransformPropagator tp(tv3);
tree.traverse(&tp);
scheduler_utils::parallelizeAllLike(tv3);
InlinePropagator inline_propagator(tv3,1);
tree.traverse(&inline_propagator);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input0 = at::arange(64, options).view({32, 2});
at::Tensor input1 = at::arange(32, options) * 0.01;
fusion.printMath();
fusion.print();
fusion.printKernel();
FusionExecutor fe;
fe.compileFusion(&fusion, {input0, input1});
auto outputs = fe.runFusion({input0, input1});
auto tv_ref = input0 + input1.unsqueeze(1);
testValidate(
&fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
}
```

This won't fail if the TIDx binding is removed. The static-shaped version fails the parallelization check instead of silently producing wrong results.
@zasdfgbnm could you add a case with vectorization as mentioned in #1918? Or do you think there isn't anything specific to vectorization for this bug?
Some thoughts in this space:
So based on compute at …

So we don't have to generate that many values of …

The indexing math is annoying here, and I might have made minor mistakes, but that shouldn't impact the conclusion if so.

We know … I think this can get even stranger, where a tensor could look like its leaves are 2D, but we would really have to generate and use a 3D loop nest for it.
Just looked into this fusion. Looks like it's a parallelization problem rather than indexing. Here's the kernel math: …
Notice that the TIDx parallelization is done for all tensors, but those parallelized domains are not exactly mapped. This should be fine as long as the domains are in shared memory (or in global memory with a grid sync), but we are using local memory for …

The validation passes if the parallelization is removed. It also passes if …

So, the first problem is that the parallel validation fails to detect the invalid parallelization, which we need to fix. Another thing to consider is whether the current behavior of … Since parallel types and memory types are tightly related, it seems to make sense to update both with …

Any thoughts?
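To illustrate why the memory type matters here, a hedged toy kernel (not the generated code): values in local memory are private to the thread that wrote them, so a consumer whose parallelized domain is not exactly mapped to the producer's needs shared memory plus a sync:

```cuda
// Assumes a block of 32 threads; "other" stands in for a consumer whose
// TIDx-parallelized domain is only permissively mapped to the producer's.
__global__ void localVsShared(const float* in, float* out) {
  float t2_local;                  // local memory: private per thread
  __shared__ float t2_shared[32];  // shared memory: visible block-wide

  t2_local = in[threadIdx.x];
  t2_shared[threadIdx.x] = in[threadIdx.x];
  __syncthreads();

  int other = (threadIdx.x + 1) % 32;
  out[threadIdx.x] = t2_shared[other];  // fine: shared memory + sync
  // out[threadIdx.x] = t2_local;       // silently wrong whenever the
  //                                    // consumer needs another thread's value
}
```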
What does inlining a parallel dimension mean?
Sorry, what do you mean?
Aren't our validation/communication rules dependent on whether a parallel dimension is inlined or not? What happens if you most inline your example?
Yes, …
What's the difference between processing a parallel dimension as being …?
The problem here is to identify when a producer ID and a consumer ID have the same index. When they are exactly mapped, they are guaranteed to have the same index. It isn't generally the case when they are permissively mapped, but since the producer tensor is actually indexed using the consumer CA domains, the producer ID ends up using the same index in some cases. The question is when exactly that happens. For example, this is a simplified case based on the above repro by Xiang.
The problem is the forwarding of … Suppose the CA position of …

Overall, the problem here is when we can say two permissively mapped domains have the same index. By definition, they must have a forwarded merge, which can result in different indices. However, when either of the outputs of a forwarded merge is also used for indexing the producer through computeAt, they should end up using the same index. A similar analysis may be required for trivial reductions, as they are basically a reverse operation.
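A hedged sketch of the merge forwarding being described, reconstructed from the repros above (not the exact simplified case referenced in this comment):

```cpp
// Assumes the usual test setup (Fusion, FusionGuard) from the repros above.
void forwardedMergeSketch(Fusion& fusion) {
  auto tv1 = makeSymbolicTensor(1);          // [I0]
  fusion.addInput(tv1);
  auto tv2 = set(tv1);                       // [I0]
  auto tv3 = broadcast(tv2, {false, true});  // [I0, B1]
  tv3->merge(0);                             // leaf: [I0*B1]

  TransformPropagator tp(tv3);
  MaxRootDomainInfoSpanningTree(tv3).traverse(&tp);
  // tv2 has no axis corresponding to B1, so the merge is "forwarded":
  // tv2's leaf stays I0, which is only *permissively* mapped to tv3's
  // I0*B1 leaf. The two can produce different indices -- unless tv3's
  // merged output also indexes tv2 through computeAt, as described above.
}
```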
I think there are two issues in this thread (uncertain if 3 below is really in this thread or just a separate thought):
Extra note on 1: we can't demote broadcast dimensions if their merge is part of a reshape.
🐛 Describe the bug
Removing the `unroll` will fix the issue.

Versions

devel