Whether `blockBroadcast` is used is decided by looking at the parallel type of an IterDomain of the output fusion TensorView, which may not have the correct parallel type when the tensor is computed at another tensor. The decision should instead go through the ComputeAtMap rather than just the IterDomain of the tensor itself.
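For illustration, here is a minimal sketch contrasting the two lookups. `getConcreteMappedID` is an assumed accessor name used for illustration only; the actual ComputeAtMap interface may differ:

```cpp
// Sketch only: contrasts reading a parallel type directly off a
// TensorView's IterDomain with resolving it through the ComputeAtMap.
// getConcreteMappedID is an assumed name, not necessarily the real API.
ParallelType rawParallelType(TensorView* tv, int axis) {
  // What the current code effectively does: the raw IterDomain can
  // still be ParallelType::Serial even when the tensor is computed at
  // a parallelized consumer.
  return tv->axis(axis)->getParallelType();
}

ParallelType mappedParallelType(
    TensorView* tv, int axis, const ComputeAtMap& ca_map) {
  // What it should do: map the IterDomain through the ComputeAtMap so
  // the parallelization inherited from the consumer (TIDx in the repro
  // below) becomes visible.
  return ca_map.getConcreteMappedID(tv->axis(axis))->getParallelType();
}
```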
Repro:
```cpp
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(2);
fusion.addInput(tv0);
auto tv1 = sum(tv0, {1});
auto tv2 = broadcast(tv1, {false, true});
auto tv3 = makeSymbolicTensor(2);
fusion.addInput(tv3);
auto tv4 = add(tv2, tv3);
fusion.addOutput(tv4);

tv1->computeAt(tv4, -1);
tv4->axis(-1)->parallelize(ParallelType::TIDx);
tv1->axis(-1)->parallelize(ParallelType::TIDx);
// Uncomment below, and the test passes. This explicit
// parallelization of tv2 should not be necessary.
// tv2->axis(-1)->parallelize(ParallelType::TIDx);

// Print after scheduling so the output matches the kernel shown below.
fusion.printMath();
fusion.printKernel();
```
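To actually run the repro, a harness along these lines can be used. This is a sketch based on the nvfuser test utilities; treat the exact `FusionExecutor` interface as an assumption:

```cpp
// Sketch of a test harness for the repro above; modeled on the nvfuser
// test utilities, so the exact names are assumptions.
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({10, 32}, options);
at::Tensor t3 = at::randn({10, 32}, options);

FusionExecutor fe;
fe.compileFusion(&fusion);
auto outputs = fe.runFusion({t0, t3});

// Reference: broadcast the row-wise sum of t0 and add t3.
auto ref = t0.sum({1}, /*keepdim=*/true) + t3;
TORCH_CHECK(outputs[0].allclose(ref));
```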
The problem happens with `tv2`. Its computeAt position is 2, so it should inherit the parallelization of `tv4` for both of its axes, and thus it should generate a call to `blockBroadcast`. However, here's what's generated:
```cpp
__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T3, Tensor<float, 2> T4) {
  alignas(4) extern __shared__ char array[];
  void* shared_mem = array;
  #pragma unroll 1
  for(size_t ki28 = 0; ki28 < T0.size[0]; ++ki28) {
    float T1[1];
    T1[0] = 0;
    blockReduce<true, false, false>(
      T1[0],
      T0[(ki28 * T0.stride[0]) + (threadIdx.x * T0.stride[1])],
      [](float &a, float b) { a = a + b; },
      threadIdx,
      blockDim,
      static_cast<float*>(shared_mem),
      true,
      float(0));
    float T2[1];
    if (((threadIdx.x == 0))) {
      T2[0]
        = T1[0];
    }
    if (((threadIdx.x == 0))) {
      T4[(ki28 * (1 * blockDim.x)) + (threadIdx.x * 1)]
        = T2[0]
        + T3[(ki28 * T3.stride[0]) + (threadIdx.x * T3.stride[1])];
    }
  }
}
```
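For comparison, with the workaround applied (or with the bug fixed), the copy into `T2` should instead be emitted as a block broadcast, roughly along these lines. This is a sketch; the exact `blockBroadcast` template parameters and signature in the generated code may differ:

```cpp
// Sketch of the expected codegen for T2; the precise blockBroadcast
// signature is an assumption based on the runtime broadcast helper.
float T2[1];
broadcast::blockBroadcast<true, false, false>(
    T2[0],
    T1[0],
    static_cast<float*>(shared_mem));
// The write to T4 would then no longer be predicated on
// threadIdx.x == 0, so every column of the output gets computed.
```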
Notice that `T2` does not use `blockBroadcast`: the reduction result is copied only when `threadIdx.x == 0`, and the final write to `T4` is predicated the same way, so the remaining columns of the output are never computed. The problem can be worked around by explicitly parallelizing the second axis of `T2` as well.
The bug resides in getParallelBroadcastDomains, where the parallel type is read directly from an IterDomain of the fusion TensorView instead of being resolved through the ComputeAtMap.
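The direction of the fix, sketched below, is to resolve each broadcast IterDomain through the ComputeAtMap before deciding whether it needs a block broadcast. All names here are illustrative rather than the actual code in getParallelBroadcastDomains:

```cpp
// Hypothetical sketch of the fix in getParallelBroadcastDomains.
// Instead of trusting tv's own IterDomain, ask the ComputeAtMap for the
// concrete mapped domain, which carries the parallelization inherited
// via computeAt. getConcreteMappedID is an assumed accessor name.
for (IterDomain* id : tv->domain()->domain()) {
  if (!id->isBroadcast()) {
    continue;
  }
  auto ptype = ca_map.getConcreteMappedID(id)->getParallelType();
  if (ptype == ParallelType::TIDx || ptype == ParallelType::TIDy ||
      ptype == ParallelType::TIDz) {
    // This broadcast is parallelized across threads and must be
    // lowered to a blockBroadcast call.
    markAsBlockBroadcast(id);  // hypothetical bookkeeping helper
  }
}
```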