
For-Loop optimized path's predication appears broken with multiple operations and a broadcast first #273

Closed
kevinstephano opened this issue Aug 6, 2020 · 3 comments

Collaborator

kevinstephano commented Aug 6, 2020

🐛 Bug

This issue arises when you have two tensors:

A: [4, X]
B: [1, X] <--- broadcast against A in the add
Reduction(A + B)
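
In other words, B's size-1 dimension is broadcast against A before the add, and the result is summed over that dimension. For reference, the intended computation is roughly the following (a minimal sketch with hypothetical names A, B, out; not the generated code):

for (int j = 0; j < X; ++j) {
  float acc = 0.0f;
  for (int i = 0; i < 4; ++i) {
    acc += A[i][j] + B[0][j];  // B's first dimension is broadcast
  }
  out[j] = acc;                // sum-reduction over dim 0
}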

A real example of the generated kernel (with broadcast):

__global__ void kernel1(Tensor<float, 1> T0, Tensor<float, 2> T1, Tensor<float, 1> T4){
  __shared__ float shared_mem[1024];
  if ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T4.size[0] ) ) {
    T4[ ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * T4.stride[0] ) ]
       = float(0);
  }
  float T5[1];
  T5[ 0 ]
     = float(0);
  for(size_t i51 = 0; i51 < ( ceilDiv(( ceilDiv(T1.size[0], 4) ), 4) ); ++i51 ) {
    float T2[1];
    float T3[4];
    if ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T4.size[0] ) ) {
      for(size_t i52 = 0; i52 < 4; ++i52 ) {
        T2[ 0 ]
           = T0[ ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
        T3[ i52 ]
           = T2[ 0 ]
           + T1[ ( ( ( ( ( i51 * 4 ) + i52 ) * 4 ) + threadIdx.y ) * T1.stride[0] ) + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * T1.stride[1] ) ];
        T5[ 0 ]
           = T5[ 0 ]
           + T3[ i52 ];
      }
    } else {
      for(size_t i52 = 0; i52 < 4; ++i52 ) {
        if ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T4.size[0] ) ) {
          T2[ 0 ]
             = T0[ ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
        }
        if ( ( ( ( ( ( ( i51 * 4 ) + i52 ) * 4 ) + threadIdx.y ) < T1.size[0] ) && ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T4.size[0] ) ) ) {
          T3[ i52 ]
             = T2[ 0 ]
             + T1[ ( ( ( ( ( i51 * 4 ) + i52 ) * 4 ) + threadIdx.y ) * T1.stride[0] ) + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * T1.stride[1] ) ];
        }
        if ( ( ( ( ( ( ( i51 * 4 ) + i52 ) * 4 ) + threadIdx.y ) < T1.size[0] ) && ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T4.size[0] ) ) ) {
          T5[ 0 ]
             = T5[ 0 ]
             + T3[ i52 ];
        }
      }
    }
  }
  if ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T4.size[0] ) ) {
    blockReduce< false, true, false > ( T4[ ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * T4.stride[0] ) ], T5[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
  }
}

Without broadcast:

__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 2> T1, Tensor<float, 1> T3){
  __shared__ float shared_mem[1024];
  if ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
    T3[ ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
       = float(0);
  }
  float T4[1];
  T4[ 0 ]
     = float(0);
  for(size_t i42 = 0; i42 < ( ceilDiv(( ceilDiv(T0.size[0], 4) ), 4) ); ++i42 ) {
    float T2[4];
    if ( ( ( ( ( ( ( i42 * 4 ) + ( 4 - 1 ) ) * 4 ) + threadIdx.y ) < T0.size[0] ) && ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T3.size[0] ) ) ) {
      for(size_t i43 = 0; i43 < 4; ++i43 ) {
        T2[ i43 ]
           = T0[ ( ( ( ( ( i42 * 4 ) + i43 ) * 4 ) + threadIdx.y ) * T0.stride[0] ) + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * T0.stride[1] ) ]
           + T1[ ( ( ( ( ( i42 * 4 ) + i43 ) * 4 ) + threadIdx.y ) * T1.stride[0] ) + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * T1.stride[1] ) ];
        T4[ 0 ]
           = T4[ 0 ]
           + T2[ i43 ];
      }
    } else {
      for(size_t i43 = 0; i43 < 4; ++i43 ) {
        if ( ( ( ( ( ( ( i42 * 4 ) + i43 ) * 4 ) + threadIdx.y ) < T0.size[0] ) && ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T3.size[0] ) ) ) {
          T2[ i43 ]
             = T0[ ( ( ( ( ( i42 * 4 ) + i43 ) * 4 ) + threadIdx.y ) * T0.stride[0] ) + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * T0.stride[1] ) ]
             + T1[ ( ( ( ( ( i42 * 4 ) + i43 ) * 4 ) + threadIdx.y ) * T1.stride[0] ) + ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * T1.stride[1] ) ];
        }
        if ( ( ( ( ( ( ( i42 * 4 ) + i43 ) * 4 ) + threadIdx.y ) < T0.size[0] ) && ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T3.size[0] ) ) ) {
          T4[ 0 ]
             = T4[ 0 ]
             + T2[ i43 ];
        }
      }
    }
  }
  if ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
    blockReduce< false, true, false > ( T3[ ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * T3.stride[0] ) ], T4[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
  }
}

With broadcast, the optimized (fast-path) for-loop is guarded only by this predicate:

 if ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T4.size[0] ) ) {

This only checks the thread index against the outer dimension of the output tensor T4. However, even though the outer for-loop executes exactly once (ceilDiv(ceilDiv(4, 4), 4) == 1 for the repro below), the unrolled inner loop together with threadIdx.y can still index past the end of the reduction dimension T1.size[0], so the fast path also needs the additional predication that shows up in the case without broadcast:

 if ( ( ( ( ( ( ( i42 * 4 ) + ( 4 - 1 ) ) * 4 ) + threadIdx.y ) < T0.size[0] ) && ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T3.size[0] ) ) ) {

My guess is that the broadcast tensor's predicate is being chosen for the entire optimized path because it happens to come first.
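
If that is the case, the fast path should presumably carry the same reduction-dimension check against T1.size[0] that the non-broadcast kernel has. A sketch of the guard I would expect (hypothetical, mirroring the non-broadcast predicate; not what the codegen currently emits):

 if ( ( ( ( ( ( ( i51 * 4 ) + ( 4 - 1 ) ) * 4 ) + threadIdx.y ) < T1.size[0] ) && ( ( ( blockIdx.x * 128 ) + threadIdx.x ) < T4.size[0] ) ) ) {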

To Reproduce

You will need the file changes found in this PR: #272

Test case with broadcast:

void testGPU_FusionReductionSchedulerReductionOnBroadcastDim() {
  constexpr int bid_x = 80;
  constexpr int tid_x = 4096;
  constexpr int red_dim = 0;

  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(1);
  TensorView* tv1 = makeDummyTensor(2);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  TensorView* tv2 =
      add(tv0, tv1);
  fusion.printMath();
  TensorView* tv3 =
      reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv2);
  fusion.addOutput(tv3);
  fusion.printMath();

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input0 = at::rand({bid_x*tid_x}, options);
  at::Tensor input1 = at::rand({4, bid_x*tid_x}, options);

  // Apply reduction heuristic
  const at::ArrayRef<c10::IValue> inputs({input0, input1});

  if (true) {
    TORCH_CHECK(
        cuda::scheduleReduction(&fusion, inputs, tv3),
        "Reduction schedule was not generated!");
  } else {
    cuda::scheduleFusion(&fusion, inputs);
  }
  fusion.printMath();

  cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  // no broadcasting needed, omitting the last optional argument;
  auto outputs = fe.runFusion({input0, input1});
  auto aten_output = input0.add(input1).sum({red_dim});

  TORCH_CHECK(
      aten_output.allclose(outputs[0]),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().max());
}

Test case without broadcast:

void testGPU_FusionReductionSchedulerReductionNoBroadcastDim() {
  constexpr int bid_x = 80;
  constexpr int tid_x = 4096;
  constexpr int red_dim = 0;

  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  TensorView* tv1 = makeDummyTensor(2);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  
  TensorView* tv2 =
      add(tv0, tv1);
  fusion.printMath();
  TensorView* tv3 =
      reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv2);
  fusion.addOutput(tv3);
  fusion.printMath();
  
  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input0 = at::rand({4, bid_x*tid_x}, options);
  at::Tensor input1 = at::rand({4, bid_x*tid_x}, options);
  
  // Apply reduction heuristic
  const at::ArrayRef<c10::IValue> inputs({input0, input1});
  
  if (true) {
    TORCH_CHECK(
        cuda::scheduleReduction(&fusion, inputs, tv3),
        "Reduction schedule was not generated!");
  } else {
    cuda::scheduleFusion(&fusion, inputs);
  }
  fusion.printMath();
  
  cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  // no broadcasting needed, omitting the last optional argument;
  auto outputs = fe.runFusion({input0, input1});
  auto aten_output = input0.add(input1).sum({red_dim});
  
  TORCH_CHECK(
      aten_output.allclose(outputs[0]),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().max());
}
@jjsjann123
Collaborator

Can we try this after #260 is merged? It's supposed to fix some indexing issues we've been fighting lately.

jjsjann123 added a commit that referenced this issue Aug 11, 2020
switched to scheduleReduction instead of naive scheduleFusion for reduction-fusion;
updated FusionExecutorCache to reuse kernels with ReductionParamsHash
Note:
It's failing the CI test due to #273, but luckily we have the other PR merged that disabled broadcasting, so CI is green.
@csarofeen
Owner

Kevin, could you try this again?

@kevinstephano
Collaborator Author

Looks fixed when trying the test case on TOT. Closing...
