Missing parallel broadcast #757

@naoyam

Description

Whether blockBroadcast is used is decided by looking at the parallel type of an IterDomain of the output Fusion TensorView, but that IterDomain may not carry the correct parallel type when the tensor is computed at another tensor. The decision should consult the ComputeAtMap rather than just the IterDomain of the tensor itself.
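
For concreteness, the faulty decision amounts to a per-tensor check along the lines of the sketch below. This is illustrative only (the real logic lives in getParallelBroadcastDomains, discussed at the end): because tv2 itself is never explicitly parallelized, its broadcast axis still reports ParallelType::Serial and the check concludes that no blockBroadcast is needed.

  // Illustrative sketch, not the actual lowering code: a check based only on
  // the tensor's own IterDomain misses parallelization inherited via computeAt.
  bool usesBlockBroadcast(TensorView* bcast_tv, int axis) {
    // For tv2 this returns Serial, even though the axis is effectively
    // bound to TIDx through its computeAt relationship with tv4.
    ParallelType pt = bcast_tv->axis(axis)->getParallelType();
    return pt == ParallelType::TIDx || pt == ParallelType::TIDy ||
        pt == ParallelType::TIDz;
  }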

Repro:

  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  auto tv1 = sum(tv0, {1});
  auto tv2 = broadcast(tv1, {false, true});
  auto tv3 = makeSymbolicTensor(2);
  fusion.addInput(tv3);
  auto tv4 = add(tv2, tv3);
  fusion.addOutput(tv4);

  tv1->computeAt(tv4, -1);

  tv4->axis(-1)->parallelize(ParallelType::TIDx);
  tv1->axis(-1)->parallelize(ParallelType::TIDx);

  // Uncomment below, and the test passes. This explicit
  // parallelization of tv2 should not be necessary.
  // tv2->axis(-1)->parallelize(ParallelType::TIDx);

  fusion.printMath();
  fusion.printKernel();

The problem happens with tv2. Its computeAt position is 2, so it should inherit the parallelization of tv4 for both axes, and lowering should therefore generate a call to blockBroadcast. However, here is what is generated:

__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T3, Tensor<float, 2> T4) {
  alignas(4) extern __shared__ char array[];
  void* shared_mem = array;
  #pragma unroll 1
  for(size_t ki28 = 0; ki28 < T0.size[0]; ++ki28) {
    float T1[1];
    T1[0] = 0;
    blockReduce<true, false, false>(
      T1[0],
      T0[(ki28 * T0.stride[0]) + (threadIdx.x * T0.stride[1])],
      [](float &a, float b) { a = a + b; },
      threadIdx,
      blockDim,
      static_cast<float*>(shared_mem),
      true,
      float(0));
    float T2[1];
    if (((threadIdx.x == 0))) {
      T2[0]
         = T1[0];
    }
    if (((threadIdx.x == 0))) {
      T4[(ki28 * (1 * blockDim.x)) + (threadIdx.x * 1)]
        = T2[0]
        + T3[(ki28 * T3.stride[0]) + (threadIdx.x * T3.stride[1])];
    }
  }
}

Notice that T2 does not use blockBroadcast.
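
With the inherited parallelization taken into account, the broadcast of T1 into T2 should instead be lowered to a blockBroadcast call, roughly as sketched below. This is a hand-written approximation of the expected output; the exact runtime signature and namespace of blockBroadcast are assumptions, and the predicates on threadIdx.x would disappear because every thread receives the reduction result and writes its own element of T4.

    float T2[1];
    // Sketch only; the actual generated call may differ in signature.
    broadcast::blockBroadcast<true, false, false>(
      T2[0],
      T1[0],
      static_cast<float*>(shared_mem));
    T4[(ki28 * (1 * blockDim.x)) + (threadIdx.x * 1)]
      = T2[0]
      + T3[(ki28 * T3.stride[0]) + (threadIdx.x * T3.stride[1])];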

The problem can be worked around by explicitly parallelizing the second axis of T2 as well.

The bug resides in getParallelBroadcastDomains, where we simply look at a Fusion IterDomain of the Fusion TensorView itself instead of consulting the ComputeAtMap.
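
The direction of the fix would be to map each IterDomain through the ComputeAtMap before inspecting its parallel type, so that tv2's broadcast axis resolves to the TIDx-parallelized axis of tv4. A minimal sketch, assuming a ComputeAtMap query along the lines of getConcreteMappedID (the exact interface, and how it is plumbed into getParallelBroadcastDomains, may differ):

  // Hedged sketch of the fix, not an actual patch: resolve the axis through
  // the ComputeAtMap and read the parallel type of the mapped IterDomain.
  bool usesBlockBroadcast(
      TensorView* bcast_tv,
      int axis,
      const ComputeAtMap& ca_map) {
    // Assumed query: for tv2's inner broadcast axis this would yield
    // tv4's TIDx-parallelized IterDomain.
    IterDomain* mapped_id = ca_map.getConcreteMappedID(bcast_tv->axis(axis));
    ParallelType pt = mapped_id->getParallelType();
    return pt == ParallelType::TIDx || pt == ParallelType::TIDy ||
        pt == ParallelType::TIDz;
  }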
