diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 4535e97be85fc..e4aa6db768a32 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -1038,32 +1038,32 @@ void testGPU_FusionParser() { __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3){ float T2[4]; if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { - for(size_t i33 = 0; i33 < 4; ++i33 ) { - T2[ i33 ] - = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i33 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] - * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i33 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; + for(size_t i29 = 0; i29 < 4; ++i29 ) { + T2[ i29 ] + = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i29 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] + * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i29 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; } } else { - for(size_t i33 = 0; i33 < 4; ++i33 ) { - if ( ( ( ( ( ( blockIdx.x * 4 ) + i33 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { - T2[ i33 ] - = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i33 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] - * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i33 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; + for(size_t i29 = 0; i29 < 4; ++i29 ) { + if ( ( ( ( ( ( blockIdx.x * 4 ) + i29 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { + T2[ i29 ] + = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i29 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] + * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i29 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; } } } if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { - for(size_t i34 = 0; i34 < 4; ++i34 ) { - T3[ ( ( ( ( ( blockIdx.x * 4 ) + i34 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] - = T2[ i34 ] - * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i34 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; + for(size_t i30 = 0; i30 < 4; ++i30 ) { + T3[ ( ( ( ( ( blockIdx.x * 4 ) + i30 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] + = T2[ i30 ] + * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i30 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; } } else { - for(size_t i34 = 0; i34 < 4; ++i34 ) { - if ( ( ( ( ( ( blockIdx.x * 4 ) + i34 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { - T3[ ( ( ( ( ( blockIdx.x * 4 ) + i34 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] - = T2[ i34 ] - * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i34 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; + for(size_t i30 = 0; i30 < 4; ++i30 ) { + if ( ( ( ( ( ( blockIdx.x * 4 ) + i30 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { + T3[ ( ( ( ( ( blockIdx.x * 4 ) + i30 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] + = T2[ i30 ] + * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i30 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; } } } @@ -3055,79 +3055,277 @@ void testGPU_FusionSimpleGemm() { } } -// This test currently requires a combination of broadcast and reduction -// operations and parellelization strategy that is currently not supported. -// It is a goal to get this example working and this test is added so we -// can continue working on getting this example fixed. Right now it -// produces an incorrect result. Either we need to error coherently on the -// optimization strategy we don't support and set this test to one we do support -// or we need to get this schedule working correctly. -void testGPU_FusionSoftmax() { +// Softmax with a 1D tensor. Parallelized only with a single thread block. 
+void testGPU_FusionSoftmax1D() { torch::jit::fuser::cuda::CudaKernel prog; Fusion& fusion = *prog.fusion_; FusionGuard fg(&fusion); + const int tidx = 128; + const int dimx = 1000; + + // Set up your input tensor views + TensorView* input_tv0 = makeDummyTensor(1); + fusion.addInput(input_tv0); + + TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0); + TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); + TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {true}); + + // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be + // computed at sum_exp_rf_tv8. + TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0); + + TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); + + fusion.addOutput(output_tv4); + + sum_exp_tv2->split(-1, tidx); + TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2}); + + output_tv4->split(-1, tidx); + + exp_tv1->computeAt(sum_exp_rf_tv5, -1); + exp_tv1_copy->computeAt(output_tv4, -1); + + TensorView* tensors_to_parallelize[] = { + sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5}; + + for (auto tv : tensors_to_parallelize) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + prog.device_ = 0; + prog.grid(1, 1); + prog.block(tidx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({dimx}, options); + at::Tensor cg_output = at::empty({dimx}, options); + at::Tensor t3_output = at::empty_like(cg_output, options); + torch::jit::fuser::cuda::compileKernel(&prog); + + torch::jit::fuser::cuda::runTestKernel(&prog, {t0}, {cg_output}); + + auto t2 = at::_softmax(t0, -1, false); + TORCH_CHECK( + t2.allclose(cg_output, 1e-5, 1e-5), + "Error of: ", + t2.sub(cg_output).abs().max()); +} + +// Softmax with a 1D tensor with input normalization. +void testGPU_FusionSoftmax1DNormalized() { + torch::jit::fuser::cuda::CudaKernel prog; + Fusion& fusion = *prog.fusion_; + FusionGuard fg(&fusion); + + const int tidx = 128; + const int dimx = 1000; + + // Set up your input tensor views + TensorView* input_tv0 = makeDummyTensor(1); + fusion.addInput(input_tv0); + + // Normalize with the max value before computing exp. + TensorView* max_val_tv1 = + reductionOp(BinaryOpType::Max, {-1}, new Float(0), input_tv0); + TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {true}); + TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2); + TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3); + TensorView* sum_exp_tv5 = sum(exp_tv4, {-1}); + TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {true}); + + // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be + // computed at sum_exp_rf_tv8. 
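  // (In this test the rFactor tensor is sum_exp_rf_tv9, created below.)
  // sub_tv3, and with it exp_tv4, is computeAt'd into that reduction loop,
  // so the same TensorViews cannot also feed the final division; the
  // sub/exp chain is rebuilt as sub_tv3_copy/exp_tv4_copy and computeAt'd
  // at output_tv7 for the output path.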
+ TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2); + TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy); + + TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6); + + fusion.addOutput(output_tv7); + + max_val_tv1->split(-1, tidx); + TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2}); + + sum_exp_tv5->split(-1, tidx); + TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2}); + + output_tv7->split(-1, tidx); + + sub_tv3->computeAt(sum_exp_rf_tv9, -1); + sub_tv3_copy->computeAt(output_tv7, -1); + + TensorView* tensors_to_parallelize[] = {max_val_tv1, + bcast_max_tv2, + sum_exp_tv5, + bcast_sum_tv6, + output_tv7, + max_val_rf_tv8, + sum_exp_rf_tv9}; + + for (auto tv : tensors_to_parallelize) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + prog.device_ = 0; + prog.grid(1, 1); + prog.block(tidx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({dimx}, options); + at::Tensor cg_output = at::empty({dimx}, options); + at::Tensor t3_output = at::empty_like(cg_output, options); + torch::jit::fuser::cuda::compileKernel(&prog); + + torch::jit::fuser::cuda::runTestKernel(&prog, {t0}, {cg_output}); + + auto t2 = at::_softmax(t0, -1, false); + TORCH_CHECK( + t2.allclose(cg_output, 1e-5, 1e-5), + "Error of: ", + t2.sub(cg_output).abs().max()); +} + +// Softmax with a 3D tensor, where the inner-most 3rd dimension is +// normalized. Pallelized with multiple thread blocks. +void testGPU_FusionSoftmax3D() { + torch::jit::fuser::cuda::CudaKernel prog; + Fusion& fusion = *prog.fusion_; + FusionGuard fg(&fusion); + + const int tidx = 32; + const int dimx = 32; + const int dimy = 16; + const int dimz = 130; + + // Set up your input tensor views + TensorView* input_tv0 = makeDummyTensor(3); + fusion.addInput(input_tv0); + + TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0); + TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); + TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true}); + + // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be + // computed at sum_exp_rf_tv8. + TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0); + + TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); + + fusion.addOutput(output_tv4); + + sum_exp_tv2->split(-1, tidx); + TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2}); + + output_tv4->split(-1, tidx); + + exp_tv1->computeAt(sum_exp_rf_tv5, -1); + exp_tv1_copy->computeAt(output_tv4, -1); + + TensorView* tensors_to_parallelize[] = { + sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5}; + + for (auto tv : tensors_to_parallelize) { + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(1)->parallelize(ParallelType::BIDy); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + prog.device_ = 0; + prog.grid(dimx, dimy); + prog.block(tidx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({dimx, dimy, dimz}, options); + at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); + at::Tensor t3_output = at::empty_like(cg_output, options); + torch::jit::fuser::cuda::compileKernel(&prog); + + torch::jit::fuser::cuda::runTestKernel(&prog, {t0}, {cg_output}); + + auto t2 = at::_softmax(t0, -1, false); + TORCH_CHECK( + t2.allclose(cg_output, 1e-5, 1e-5), + "Error of: ", + t2.sub(cg_output).abs().max()); +} + +// Softmax with a 3D tensor with input normalization. 
+void testGPU_FusionSoftmax3DNormalized() { + torch::jit::fuser::cuda::CudaKernel prog; + Fusion& fusion = *prog.fusion_; + FusionGuard fg(&fusion); + + const int tidx = 32; + const int dimx = 32; + const int dimy = 16; + const int dimz = 130; + // Set up your input tensor views TensorView* input_tv0 = makeDummyTensor(3); fusion.addInput(input_tv0); + // Normalize with the max value before computing exp. TensorView* max_val_tv1 = - reductionOp(BinaryOpType::Max, {2}, new Float(0), input_tv0); + reductionOp(BinaryOpType::Max, {-1}, new Float(0), input_tv0); TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true}); - TensorView* exp_tv3 = sub(input_tv0, bcast_max_tv2); - TensorView* sum_exp_tv4 = - reductionOp(BinaryOpType::Add, {2}, new Float(0), exp_tv3); - TensorView* bcast_sum_tv5 = broadcast(sum_exp_tv4, {false, false, true}); - TensorView* output_tv6 = div(exp_tv3, bcast_sum_tv5); - - max_val_tv1->split(-1, 32); - TensorView* max_val_rf_tv7 = max_val_tv1->rFactor({-2}); - sum_exp_tv4->split(-1, 32); - TensorView* sum_exp_rf_tv8 = sum_exp_tv4->rFactor({-2}); - - exp_tv3->computeAt(sum_exp_rf_tv8, 2); - - max_val_rf_tv7->axis(0)->parallelize(ParallelType::BIDx); - max_val_tv1->axis(0)->parallelize(ParallelType::BIDx); - bcast_max_tv2->axis(0)->parallelize(ParallelType::BIDx); - sum_exp_rf_tv8->axis(0)->parallelize(ParallelType::BIDx); - sum_exp_tv4->axis(0)->parallelize(ParallelType::BIDx); - bcast_sum_tv5->axis(0)->parallelize(ParallelType::BIDx); - output_tv6->axis(0)->parallelize(ParallelType::BIDx); - - max_val_rf_tv7->axis(1)->parallelize(ParallelType::BIDy); - max_val_tv1->axis(1)->parallelize(ParallelType::BIDy); - bcast_max_tv2->axis(1)->parallelize(ParallelType::BIDy); - sum_exp_rf_tv8->axis(1)->parallelize(ParallelType::BIDy); - sum_exp_tv4->axis(1)->parallelize(ParallelType::BIDy); - bcast_sum_tv5->axis(1)->parallelize(ParallelType::BIDy); - output_tv6->axis(1)->parallelize(ParallelType::BIDy); - - max_val_rf_tv7->axis(-1)->parallelize(ParallelType::TIDx); - max_val_tv1->axis(-1)->parallelize(ParallelType::TIDx); - bcast_max_tv2->axis(-1)->parallelize(ParallelType::TIDx); - exp_tv3->axis(-1)->parallelize(ParallelType::TIDx); - sum_exp_rf_tv8->axis(-1)->parallelize(ParallelType::TIDx); - sum_exp_tv4->axis(-1)->parallelize(ParallelType::TIDx); - bcast_sum_tv5->axis(-1)->parallelize(ParallelType::TIDx); - output_tv6->axis(-1)->parallelize(ParallelType::TIDx); - - fusion.addOutput(output_tv6); + TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2); + TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3); + TensorView* sum_exp_tv5 = sum(exp_tv4, {-1}); + TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {false, false, true}); + + // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be + // computed at sum_exp_rf_tv8. 
+ TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2); + TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy); + + TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6); + + fusion.addOutput(output_tv7); + + max_val_tv1->split(-1, tidx); + TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2}); + + sum_exp_tv5->split(-1, tidx); + TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2}); + + output_tv7->split(-1, tidx); + + sub_tv3->computeAt(sum_exp_rf_tv9, -1); + sub_tv3_copy->computeAt(output_tv7, -1); + + TensorView* tensors_to_parallelize[] = {max_val_tv1, + bcast_max_tv2, + sum_exp_tv5, + bcast_sum_tv6, + output_tv7, + max_val_rf_tv8, + sum_exp_rf_tv9}; + + for (auto tv : tensors_to_parallelize) { + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(1)->parallelize(ParallelType::BIDy); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } prog.device_ = 0; - prog.grid(32, 32); - prog.block(32); + prog.grid(dimx, dimy); + prog.block(tidx); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({32, 32, 128}, options); - at::Tensor cg_output = at::empty({32, 32, 128}, options); + at::Tensor t0 = at::randn({dimx, dimy, dimz}, options); + at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); + at::Tensor t3_output = at::empty_like(cg_output, options); torch::jit::fuser::cuda::compileKernel(&prog); + torch::jit::fuser::cuda::runTestKernel(&prog, {t0}, {cg_output}); auto t2 = at::_softmax(t0, -1, false); - // TORCH_CHECK( - // t2.allclose(cg_output, 1e-5, 1e-5), - // "Error of: ", - // t2.sub(cg_output).abs().max()); + TORCH_CHECK( + t2.allclose(cg_output, 1e-5, 1e-5), + "Error of: ", + t2.sub(cg_output).abs().max()); } void testGPU_FusionSoftmaxComputeAt() { @@ -3915,6 +4113,61 @@ void testGPU_FusionZeroDimReduction() { aten_output.sub(output).abs().max()); } +void testGPU_FusionBCastAfterReduce() { + torch::jit::fuser::cuda::CudaKernel prog; + Fusion& fusion = *prog.fusion_; + FusionGuard fg(&fusion); + + const int tidx = 128; + + // Set up your input tensor views + TensorView* tv0 = makeDummyTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + + tv1->split(1, tidx); + auto tv3 = tv1->rFactor({-2}); + + TensorView* tv4 = makeDummyTensor(2); + fusion.addInput(tv4); + + auto tv5 = add(tv2, tv4); + fusion.addOutput(tv5); + tv5->split(1, tidx); + + tv3->computeAt(tv5, 1); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + + tv5->axis(0)->parallelize(ParallelType::BIDx); + + size_t x = 63, y = 200; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y}, options); + at::Tensor t4 = at::randn({x, y}, options); + + at::Tensor cg_output = at::empty({x, y}, options); + + prog.device_ = 0; + prog.grid(x); + prog.block(tidx); + torch::jit::fuser::cuda::compileKernel(&prog); + torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t4}, {cg_output}); + + auto t3 = t0.sum({1}).unsqueeze(-1).expand({x, y}); + auto t5 = t3.add(t4); + + // Error is larger than the default threshold + TORCH_CHECK(t5.allclose(cg_output, 1e-5, 1e-5)); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 1bf0a5a880a36..e680d8a1b37c5 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h 
@@ -97,76 +97,80 @@ namespace jit { _(FusionAliasing) #if defined(USE_CUDA) -#define TH_FORALL_TESTS_CUDA(_) \ - _(ArgumentSpec) \ - _(CompleteArgumentSpec) \ - _(Fusion) \ - _(GraphExecutor) \ - _(ModuleConversion) \ - _(Interp) \ - _(GPU_IrGraphGenerator) \ - _(GPU_FusionDispatch) \ - _(GPU_FusionClear) \ - _(GPU_FusionCopy) \ - _(GPU_FusionMove) \ - _(GPU_FusionSimpleArith) \ - _(GPU_FusionExprEvalConstants) \ - _(GPU_FusionExprEvalBindings) \ - _(GPU_FusionExprEvalBasic) \ - _(GPU_FusionExprEvalComplex) \ - _(GPU_FusionExprEvalPostLower) \ - _(GPU_FusionSimpleTypePromote) \ - _(GPU_FusionMutator) \ - _(GPU_FusionRegister) \ - _(GPU_FusionTopoSort) \ - _(GPU_FusionTensor) \ - _(GPU_FusionTVSplit) \ - _(GPU_FusionTVMerge) \ - _(GPU_FusionTVReorder) \ - _(GPU_FusionEquality) \ - _(GPU_FusionReplaceAll) \ - _(GPU_FusionParser) \ - _(GPU_FusionDependency) \ - _(GPU_FusionCodeGen) \ - _(GPU_FusionCodeGen2) \ - _(GPU_FusionSimplePWise) \ - _(GPU_FusionExecKernel) \ - _(GPU_FusionForLoop) \ - _(GPU_FusionLoopUnroll) \ - _(GPU_FusionUnaryOps) \ - _(GPU_FusionBinaryOps) \ - _(GPU_FusionTernaryOps) \ - _(GPU_FusionCompoundOps) \ - _(GPU_FusionCastOps) \ - _(GPU_FusionAdvancedComputeAt) \ - _(GPU_FusionScalarInputs) \ - _(GPU_FusionRFactorReplay) \ - _(GPU_FusionReduction) \ - _(GPU_FusionReduction2) \ - _(GPU_FusionReduction3) \ - _(GPU_FusionReduction4) \ - _(GPU_FusionReduction5) \ - _(GPU_FusionReductionTFT) \ - _(GPU_FusionSimpleBCast) \ - _(GPU_FusionSimpleGemm) \ - _(GPU_FusionSoftmax) \ - _(GPU_FusionSoftmaxComputeAt) \ - _(GPU_FusionGridReduction1) \ - _(GPU_FusionGridReduction2) \ - _(GPU_FusionGridReduction3dim1) \ - _(GPU_FusionGridReduction3dim0) \ - _(GPU_FusionGridReduction4) \ - _(GPU_FusionGridReduction5) \ - _(GPU_FusionGridReduction6) \ - _(GPU_FusionNonRedAxisBind) \ - _(GPU_FusionBCastInnerDim) \ - _(GPU_FusionBCastReduce) \ - _(GPU_FusionSplitBCast) \ - _(GPU_FusionComputeAtExprOrder) \ - _(GPU_FusionZeroDimComputeAt) \ - _(GPU_FusionZeroDimBroadcast) \ - _(GPU_FusionZeroDimReduction) \ - _(GPU_FusionReductionMultiConsumer) +#define TH_FORALL_TESTS_CUDA(_) \ + _(ArgumentSpec) \ + _(CompleteArgumentSpec) \ + _(Fusion) \ + _(GraphExecutor) \ + _(ModuleConversion) \ + _(Interp) \ + _(GPU_IrGraphGenerator) \ + _(GPU_FusionDispatch) \ + _(GPU_FusionClear) \ + _(GPU_FusionCopy) \ + _(GPU_FusionMove) \ + _(GPU_FusionSimpleArith) \ + _(GPU_FusionExprEvalConstants) \ + _(GPU_FusionExprEvalBindings) \ + _(GPU_FusionExprEvalBasic) \ + _(GPU_FusionExprEvalComplex) \ + _(GPU_FusionExprEvalPostLower) \ + _(GPU_FusionSimpleTypePromote) \ + _(GPU_FusionMutator) \ + _(GPU_FusionRegister) \ + _(GPU_FusionTopoSort) \ + _(GPU_FusionTensor) \ + _(GPU_FusionTVSplit) \ + _(GPU_FusionTVMerge) \ + _(GPU_FusionTVReorder) \ + _(GPU_FusionEquality) \ + _(GPU_FusionReplaceAll) \ + _(GPU_FusionParser) \ + _(GPU_FusionDependency) \ + _(GPU_FusionCodeGen) \ + _(GPU_FusionCodeGen2) \ + _(GPU_FusionSimplePWise) \ + _(GPU_FusionExecKernel) \ + _(GPU_FusionForLoop) \ + _(GPU_FusionLoopUnroll) \ + _(GPU_FusionUnaryOps) \ + _(GPU_FusionBinaryOps) \ + _(GPU_FusionTernaryOps) \ + _(GPU_FusionCompoundOps) \ + _(GPU_FusionCastOps) \ + _(GPU_FusionAdvancedComputeAt) \ + _(GPU_FusionScalarInputs) \ + _(GPU_FusionRFactorReplay) \ + _(GPU_FusionReduction) \ + _(GPU_FusionReduction2) \ + _(GPU_FusionReduction3) \ + _(GPU_FusionReduction4) \ + _(GPU_FusionReduction5) \ + _(GPU_FusionReductionTFT) \ + _(GPU_FusionSimpleBCast) \ + _(GPU_FusionSimpleGemm) \ + _(GPU_FusionSoftmax1D) \ + _(GPU_FusionSoftmax1DNormalized) \ + 
_(GPU_FusionSoftmax3D) \ + _(GPU_FusionSoftmax3DNormalized) \ + _(GPU_FusionSoftmaxComputeAt) \ + _(GPU_FusionGridReduction1) \ + _(GPU_FusionGridReduction2) \ + _(GPU_FusionGridReduction3dim1) \ + _(GPU_FusionGridReduction3dim0) \ + _(GPU_FusionGridReduction4) \ + _(GPU_FusionGridReduction5) \ + _(GPU_FusionGridReduction6) \ + _(GPU_FusionNonRedAxisBind) \ + _(GPU_FusionBCastInnerDim) \ + _(GPU_FusionBCastReduce) \ + _(GPU_FusionSplitBCast) \ + _(GPU_FusionComputeAtExprOrder) \ + _(GPU_FusionZeroDimComputeAt) \ + _(GPU_FusionZeroDimBroadcast) \ + _(GPU_FusionZeroDimReduction) \ + _(GPU_FusionReductionMultiConsumer) \ + _(GPU_FusionBCastAfterReduce) #else #define TH_FORALL_TESTS_CUDA(_) \ _(ArgumentSpec) \ diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 79232188a3532..ada966debd1d1 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -1,6 +1,8 @@ #include #include #include +#include +#include #include @@ -478,15 +480,50 @@ void IRPrinter::handle(const ReductionOp* rop) { } void IRPrinter::handle(const BroadcastOp* bop) { - indent(); - handle(bop->out()); - os << "\n"; - indent_size++; - indent(); - os << " = "; - handle(bop->in()); - indent_size--; - os << ";\n"; + // Check if we've lowered yet. + bool lowered = bop->out()->getValType() == ValType::TensorIndex; + if (!lowered) { + os << bop->out() << " = broadcast( " << bop->in() << " )\n"; + return; + } + + const ir_utils::ParallelTypeBitmap domains = + ir_utils::getParallelBroadcastDomains(bop, getThreadPredicateMap()); + const bool thread_x = domains.get(ParallelType::TIDx); + const bool thread_y = domains.get(ParallelType::TIDy); + const bool thread_z = domains.get(ParallelType::TIDz); + const bool block_x = domains.get(ParallelType::BIDx); + const bool block_y = domains.get(ParallelType::BIDy); + const bool block_z = domains.get(ParallelType::BIDz); + + const bool grid_broadcast_needed = block_x || block_y || block_z; + const bool block_broadcast_needed = thread_x || thread_y || thread_z; + + TORCH_INTERNAL_ASSERT( + !grid_broadcast_needed, "Parallel broadcast across blocks not supported"); + + if (block_broadcast_needed) { + indent(); + os << "broadcast::blockBroadcast<"; + os << (thread_x ? "true" : "false") << ", "; + os << (thread_y ? "true" : "false") << ", "; + os << (thread_z ? 
"true" : "false"); + os << ">("; + handle(bop->out()); + os << ", "; + handle(bop->in()); + os << ");\n"; + } else { + indent(); + handle(bop->out()); + os << "\n"; + indent_size++; + indent(); + os << " = "; + handle(bop->in()); + indent_size--; + os << ";\n"; + } } void IRPrinter::handle(const ForLoop* fl) { @@ -640,6 +677,14 @@ void IRPrinter::printKernel( os << "}\n"; } +const ThreadPredicateMap& IRPrinter::getThreadPredicateMap() { + if (thread_predicates_ == nullptr) { + Fusion* fusion = FusionGuard::getCurFusion(); + thread_predicates_ = std::make_unique(fusion); + } + return *thread_predicates_; +} + std::ostream& operator<<(std::ostream& os, const Statement* stmt) { IRPrinter p(os); p.handle(stmt); diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index 18576487fe36d..3dc95d16e95c2 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -3,6 +3,7 @@ #include #include +#include #include @@ -126,6 +127,11 @@ class TORCH_CUDA_API IRPrinter : public OptInConstDispatch { void printKernel( const std::vector& exprs, const std::string& kernel_name); + + private: + std::unique_ptr thread_predicates_; + + const ThreadPredicateMap& getThreadPredicateMap(); }; TORCH_CUDA_API std::ostream& operator<<( diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp index e7e113779182a..8b4b3e875f761 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel.cpp @@ -156,7 +156,8 @@ std::pair codeGeneration(Fusion* fusion) { << code_random_number_gen << "\n" << code_helper_funcs << "\n" << code_template_block_reduction << "\n" - << code_template_grid_reduction << "\n"; + << code_template_grid_reduction << "\n" + << code_template_block_broadcast << "\n"; std::stringstream cdg; GPULower gpulw(fusion); gpulw.printKernel(str_stream, kKernelName); diff --git a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h index 0427c8f5bcd7d..c6c009fd0bb3e 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h +++ b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h @@ -570,6 +570,53 @@ __device__ void gridReduce(T& out, T inp_val, Func reduction_op, } // namespace reduction )"; +static auto code_template_block_broadcast = R"( +namespace broadcast { + +template +__host__ __device__ unsigned offset_of_source(const dim3& block_dim, const dim3& thread_idx) { + unsigned offset = 0; + if (!Z_THREAD) + offset = offset * block_dim.z + thread_idx.z; + if (!Y_THREAD) + offset = offset * block_dim.y + thread_idx.y; + if (!X_THREAD) + offset = offset * block_dim.x + thread_idx.x; + return offset; +} + +/** Broadcasts within partitioned groups of threads. + + X_THREAD: Broadcast from threadIdx.x == 0 if true + Y_THREAD: Broadcast from threadIdx.y == 0 if true + Z_THREAD: Broadcast from threadIdx.z == 0 if true + inp_val: Per-thread source value. Only valid when the thread is a source. + out: Per-thread output location + */ +template +__device__ void blockBroadcast(T& out, T inp_val) { + + // Use worst case for memory. 
+ __shared__ T shared_mem[1024]; + + const bool has_valid_data = + (!X_THREAD || threadIdx.x == 0) && + (!Y_THREAD || threadIdx.y == 0) && + (!Z_THREAD || threadIdx.z == 0); + + const auto shared_offset = offset_of_source(blockDim, threadIdx); + + if (has_valid_data) + shared_mem[shared_offset] = inp_val; + + __syncthreads(); + + out = shared_mem[shared_offset]; +} + +} // namespace broadcast +)"; + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 913c39e1eb8be..ea917f83a9e34 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -20,7 +20,7 @@ std::vector GPULower::getLoweredExprs() { // Validate and make some minor modifications in preparation to generate code. PrepareForLowering(fusion_); - auto preds = ThreadPredicates::compute(fusion_); + ThreadPredicateMap preds(fusion_); // Run our passes keeping the lowered expressions and forwarding them. auto loop_nests = LoopNestGenerator::getLoopNest( diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h index f6ecb57aca2c9..c1717e8c3d620 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.h +++ b/torch/csrc/jit/codegen/cuda/lower_loops.h @@ -4,6 +4,8 @@ #include #include +#include + namespace torch { namespace jit { namespace fuser { @@ -40,7 +42,7 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch { // Predicates from ThreadPredicates that we will extend to reduction buffer // initialization - std::unordered_map& thread_predicates_; + ThreadPredicateMap& thread_predicates_; // Create, place, and return the allocation for tv Expr* pushAlloc(TensorView*); @@ -71,16 +73,14 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch { // Run the pass and accumulate output in lowered_exprs void generate(const std::vector& exprs); - LoopNestGenerator( - Fusion* _fusion, - std::unordered_map& _thread_predicates) + LoopNestGenerator(Fusion* _fusion, ThreadPredicateMap& _thread_predicates) : fusion_(_fusion), thread_predicates_(_thread_predicates) {} public: static std::vector getLoopNest( Fusion* fusion, std::vector exprs, - std::unordered_map& thread_predicates) { + ThreadPredicateMap& thread_predicates) { FusionGuard fg(fusion); LoopNestGenerator lng(fusion, thread_predicates); lng.generate(exprs); @@ -90,4 +90,4 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch { } // namespace fuser } // namespace jit -} // namespace torch \ No newline at end of file +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp index 2f0f060af6856..913beb5dd59cb 100644 --- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp @@ -8,58 +8,25 @@ namespace torch { namespace jit { namespace fuser { -const static std::unordered_map pt_to_offset{ - {ParallelType::BIDx, 0}, - {ParallelType::BIDy, 1}, - {ParallelType::BIDz, 2}, - {ParallelType::TIDx, 3}, - {ParallelType::TIDy, 4}, - {ParallelType::TIDz, 5}}; - -const static std::unordered_map offset_to_pt{ - {0, ParallelType::BIDx}, - {1, ParallelType::BIDy}, - {2, ParallelType::BIDz}, - {3, ParallelType::TIDx}, - {4, ParallelType::TIDy}, - {5, ParallelType::TIDz}}; - -static constexpr int num_p_type = 6; - namespace { -void flip_true(std::bitset& bits, const ParallelType p_type) { - if (pt_to_offset.find(p_type) == 
pt_to_offset.end()) { - TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type."); - } - bits[pt_to_offset.at(p_type)] = true; +Val* threadPredicate(ParallelType pt) { + return eq(new NamedScalar(stringifyThread(pt), DataType::Int), new Int(0)); } -Val* threadPredicate(int i) { - if (offset_to_pt.find(i) == offset_to_pt.end()) { - TORCH_INTERNAL_ASSERT( - false, - "Invalid int for predicate computation, should be from [0-5], but recieved, ", - i, - "."); - } - return eq( - new NamedScalar(stringifyThread(offset_to_pt.at(i)), DataType::Int), - new Int(0)); -} - -Bool* getThreadPredicate(std::bitset bits) { - if (bits.none()) +Bool* getThreadPredicate(const ir_utils::ParallelTypeBitmap& bits) { + if (bits.none()) { return new Bool(true); + } Val* pred = nullptr; - for (int i = 0; i < num_p_type; i++) { - if (bits[i]) { + for (const auto& pt_bool : bits.getMap()) { + if (pt_bool.second) { if (pred == nullptr) { - pred = threadPredicate(i); + pred = threadPredicate(pt_bool.first); } else { - pred = andOp(pred, threadPredicate(i)); + pred = andOp(pred, threadPredicate(pt_bool.first)); } } } @@ -76,25 +43,16 @@ Bool* getThreadPredicate(std::bitset bits) { } // namespace -std::bitset ThreadPredicates::getThreadPredicates( - const TensorView* tv) { - TORCH_INTERNAL_ASSERT( - thread_predicates.find(tv) != thread_predicates.end(), - "Invalid predicate initialization, couldn't find ", - tv); - return thread_predicates[tv]; -} - // Update the reduction_deps bitset based on provided Expr -void ThreadPredicates::updateBitSet(Expr* expr) { +void ThreadPredicateMap::updateBitSet(Expr* expr) { // Which predicates were set for the inputs - std::bitset input_preds; + ir_utils::ParallelTypeBitmap input_preds; // Which dims are reductions in inputs - std::bitset input_reductions; + ir_utils::ParallelTypeBitmap input_reductions; // Which dims are bcast in inputs - std::bitset input_bcasts; + ir_utils::ParallelTypeBitmap input_bcasts; // Run through inputs and update bitsets for (const auto* inp : expr->inputs()) { @@ -103,28 +61,28 @@ void ThreadPredicates::updateBitSet(Expr* expr) { auto tv_inp = ir_utils::asConstTV(inp); TORCH_INTERNAL_ASSERT( - thread_predicates.find(tv_inp) != thread_predicates.end(), + thread_predicates_.find(tv_inp) != thread_predicates_.end(), "Thread predicate map was not initialized, couldn't find ", inp); - input_preds |= thread_predicates[tv_inp]; + input_preds |= thread_predicates_[tv_inp]; - std::bitset id_reductions; - std::bitset id_bcasts; - std::bitset id_ptypes; + ir_utils::ParallelTypeBitmap id_reductions; + ir_utils::ParallelTypeBitmap id_bcasts; + ir_utils::ParallelTypeBitmap id_ptypes; for (auto id : tv_inp->domain()->domain()) { if (id->isThread()) { - flip_true(id_ptypes, id->parallel_method()); + id_ptypes.set(id->parallel_method(), true); if (id->isReduction()) - flip_true(id_reductions, id->parallel_method()); + id_reductions.set(id->parallel_method(), true); if (id->isBroadcast()) - flip_true(id_bcasts, id->parallel_method()); + id_bcasts.set(id->parallel_method(), true); } } // Validate the combination of ptypes, reductions, bcasts - for (size_t i = 0; i < num_p_type; i++) { + for (size_t i = 0; i < ir_utils::ParallelTypeBitmap::num_p_type; i++) { if (input_reductions[i]) { if (id_ptypes[i]) { TORCH_INTERNAL_ASSERT( @@ -161,25 +119,48 @@ void ThreadPredicates::updateBitSet(Expr* expr) { for (const auto* out : expr->outputs()) { if (!ir_utils::isTV(out)) continue; - thread_predicates[ir_utils::asConstTV(out)] = output_preds; + 
thread_predicates_[ir_utils::asConstTV(out)] = output_preds; } } -ThreadPredicates::ThreadPredicates(Fusion* _fusion) : fusion_(_fusion) { - for (auto inp : fusion_->inputs()) - if (ir_utils::isTV(inp)) - thread_predicates[ir_utils::asConstTV(inp)] = std::bitset(); -} -std::unordered_map ThreadPredicates::compute( - Fusion* fusion) { - ThreadPredicates tp(fusion); - for (auto expr : fusion->exprs(true)) - tp.updateBitSet(expr); - std::unordered_map preds; - for (auto entry : tp.thread_predicates) { - preds[entry.first] = getThreadPredicate(entry.second); +ThreadPredicateMap::ThreadPredicateMap(Fusion* _fusion) : fusion_(_fusion) { + for (auto inp : fusion_->inputs()) { + if (ir_utils::isTV(inp)) { + thread_predicates_[ir_utils::asConstTV(inp)] = + ir_utils::ParallelTypeBitmap(); + } + } + for (auto expr : fusion_->exprs(true)) { + updateBitSet(expr); } - return preds; +} + +ThreadPredicateMap::const_iterator ThreadPredicateMap::find( + const TensorView* tv) const { + return thread_predicates_.find(tv); +} + +ThreadPredicateMap::const_iterator ThreadPredicateMap::end() const { + return thread_predicates_.end(); +} + +const ir_utils::ParallelTypeBitmap& ThreadPredicateMap::at( + const TensorView* tv) const { + return thread_predicates_.at(tv); +} + +ir_utils::ParallelTypeBitmap& ThreadPredicateMap::at(const TensorView* tv) { + return thread_predicates_.at(tv); +} + +ir_utils::ParallelTypeBitmap& ThreadPredicateMap::operator[]( + const TensorView* tv) { + return thread_predicates_[tv]; +} + +Bool* ThreadPredicateMap::getExpr(const TensorView* tv) const { + TORCH_INTERNAL_ASSERT(find(tv) != end(), "Couldn't find ", tv); + return getThreadPredicate(at(tv)); } } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h index df09b72751ee4..3c3a6ca81d4a9 100644 --- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h +++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h @@ -2,6 +2,7 @@ #include #include +#include #include @@ -9,34 +10,39 @@ namespace torch { namespace jit { namespace fuser { -class TORCH_CUDA_API ThreadPredicates { - private: - Fusion* fusion_; +/* + * Map from tensorview to bit set represnting If any dependency of TV had a parallelized reduction, we will track + * it here. This will be used for predicate generation to prevent + * parallelization on that axis. This is important if we have a reduction on + * for example TIDx, as the reduced value is only valid on threadIdx.x == 0 + * therefore if we use that value later in the kernel we have that predicate. + * If we follow a reduction parallelized on TIDx with a broadcast on TIDx we + * no longer need the predicate and can reset the bit accordingly + */ +class TORCH_CUDA_API ThreadPredicateMap { + public: + using MapType = + std::unordered_map; + using const_iterator = MapType::const_iterator; - /* - * Map from tensorview to bit set represnting If any dependency of TV had a parallelized reduction, we will track - * it here. This will be used for predicate generation to prevent - * parallelization on that axis. This is important if we have a reduction on - * for example TIDx, as the reduced value is only valid on threadIdx.x == 0 - * therefore if we use that value later in the kernel we have that predicate. 
- * If we follow a reduction parallelized on TIDx with a broadcast on TIDx we - * no longer need the predicate and can reset the bit accordingly - */ - std::unordered_map> thread_predicates; + explicit ThreadPredicateMap(Fusion* _fusion); - // Update the thread_predicates bitset based on provided Expr - void updateBitSet(Expr*); + const_iterator find(const TensorView* tv) const; + const_iterator end() const; + const ir_utils::ParallelTypeBitmap& at(const TensorView* tv) const; + ir_utils::ParallelTypeBitmap& at(const TensorView* tv); + ir_utils::ParallelTypeBitmap& operator[](const TensorView* tv); - // Safety wrapper to access thread_predicates - std::bitset<6> getThreadPredicates(const TensorView*); + // Returns a Bool predicate expression for a given TensorView. + Bool* getExpr(const TensorView* tv) const; - ThreadPredicates(Fusion* _fusion); + private: + Fusion* fusion_; + MapType thread_predicates_; - public: - // Computes any thread predicates that need to be applied when computing a - // TensorView. - static std::unordered_map compute(Fusion* fusion); + // Update the thread_predicates bitset based on provided Expr + void updateBitSet(Expr*); }; } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index d9e1fd1ee0ed3..465d210b99b55 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -10,12 +10,18 @@ namespace torch { namespace jit { namespace fuser { -Bool* UnrollPass::getThreadPredicate(const TensorView* tv) { - TORCH_INTERNAL_ASSERT( - thread_predicates_.find(tv) != thread_predicates_.end(), - "Invalid predicate initialization, couldn't find ", - tv); - return thread_predicates_[tv]; +Bool* UnrollPass::getThreadPredicate(TensorView* tv) { + // No thread predicate is needed predicate when tv is output of a + // parallel broadcast expression. 
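  // A block broadcast re-distributes the value to every thread in the
  // parallelized broadcast dimensions, so the "threadIdx == 0" predicate
  // inherited from the preceding parallel reduction would only mask out
  // threads that now hold valid data. Returning nullptr lets getPredicate()
  // below omit the thread predicate entirely.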
+ if (tv->getOrigin() != nullptr && + tv->getOrigin()->getExprType() == ExprType::BroadcastOp && + ir_utils::getParallelBroadcastDomains( + static_cast(tv->getOrigin()), thread_predicates_) + .any()) { + return nullptr; + } + + return thread_predicates_.getExpr(tv); } // Custom dispatch for Expr, want to find out of it's a TV op @@ -56,7 +62,9 @@ Bool* getPredicate(TensorView* tv, std::vector inds_, Bool* thread_pred) { std::vector all_preds = PredicateCompute::computePredicates( new TensorIndex(tv, IndexCompute::get(tv->domain(), inds))); - all_preds.push_back(thread_pred); + if (thread_pred != nullptr) { + all_preds.push_back(thread_pred); + } std::vector preds; @@ -222,7 +230,7 @@ void UnrollPass::computeMap() { std::vector UnrollPass::runPass( Fusion* fusion, const std::vector& exprs, - std::unordered_map& thread_predicates) { + const ThreadPredicateMap& thread_predicates) { FusionGuard fg(fusion); UnrollPass up(fusion, exprs, thread_predicates); up.computeMap(); @@ -241,4 +249,4 @@ std::vector UnrollPass::runPass( } // namespace fuser } // namespace jit -} // namespace torch \ No newline at end of file +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.h b/torch/csrc/jit/codegen/cuda/lower_unroll.h index d421676b8f8ab..989a131d0c09a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.h +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.h @@ -51,7 +51,7 @@ namespace fuser { class TORCH_CUDA_API UnrollPass : public OptOutDispatch { private: // Wrapper to access thread_predicates_ - Bool* getThreadPredicate(const TensorView*); + Bool* getThreadPredicate(TensorView*); // We will track which loops in the incomming IR will be replaced and by what std::unordered_map loop_replacement_map; @@ -65,7 +65,7 @@ class TORCH_CUDA_API UnrollPass : public OptOutDispatch { std::vector for_loops; // Map from TensorView - std::unordered_map& thread_predicates_; + const ThreadPredicateMap& thread_predicates_; // keep track if we're within an unrolled loop bool within_unroll = false; @@ -80,7 +80,7 @@ class TORCH_CUDA_API UnrollPass : public OptOutDispatch { UnrollPass( Fusion* _fusion, const std::vector& _incoming_exprs, - std::unordered_map& _thread_predicates) + const ThreadPredicateMap& _thread_predicates) : fusion_(_fusion), incoming_exprs_(_incoming_exprs), thread_predicates_(_thread_predicates) {} @@ -94,7 +94,7 @@ class TORCH_CUDA_API UnrollPass : public OptOutDispatch { static std::vector runPass( Fusion* fusion, const std::vector& exprs, - std::unordered_map& thread_predicates); + const ThreadPredicateMap& thread_predicates); }; } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 75dcf0ecc0b52..6c8bbda90dc61 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -1,5 +1,7 @@ #include #include +#include + namespace torch { namespace jit { namespace fuser { @@ -467,8 +469,140 @@ bool isUnrolledFor(const Expr* expr) { ParallelType::Unroll; } +const std::unordered_map ParallelTypeBitmap::pt_to_offset_{ + {ParallelType::BIDx, 0}, + {ParallelType::BIDy, 1}, + {ParallelType::BIDz, 2}, + {ParallelType::TIDx, 3}, + {ParallelType::TIDy, 4}, + {ParallelType::TIDz, 5}}; + +const std::unordered_map ParallelTypeBitmap::offset_to_pt_ = + {{0, ParallelType::BIDx}, + {1, ParallelType::BIDy}, + {2, ParallelType::BIDz}, + {3, ParallelType::TIDx}, + {4, ParallelType::TIDy}, + {5, ParallelType::TIDz}}; + +bool ParallelTypeBitmap::get(ParallelType pt) const { + if 
(pt_to_offset_.find(pt) == pt_to_offset_.end()) { + TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type."); + } + return bitset_[pt_to_offset_.at(pt)]; +} + +bool ParallelTypeBitmap::set(ParallelType pt, bool new_val) { + if (pt_to_offset_.find(pt) == pt_to_offset_.end()) { + TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type."); + } + bool old_val = bitset_[pt_to_offset_.at(pt)]; + bitset_[pt_to_offset_.at(pt)] = new_val; + return old_val; +} + +ParallelTypeBitmap ParallelTypeBitmap::operator&=( + const ParallelTypeBitmap& other) { + bitset_ &= other.bitset_; + return *this; +} + +ParallelTypeBitmap ParallelTypeBitmap::operator|=( + const ParallelTypeBitmap& other) { + bitset_ |= other.bitset_; + return *this; +} + +ParallelTypeBitmap ParallelTypeBitmap::operator^=( + const ParallelTypeBitmap& other) { + bitset_ ^= other.bitset_; + return *this; +} + +ParallelTypeBitmap ParallelTypeBitmap::operator~() const { + return ParallelTypeBitmap(~bitset_); +} + +bool ParallelTypeBitmap::none() const { + return bitset_.none(); +} + +bool ParallelTypeBitmap::any() const { + return bitset_.any(); +} + +bool ParallelTypeBitmap::all() const { + return bitset_.all(); +} + +bool ParallelTypeBitmap::operator[](size_t pos) const { + TORCH_INTERNAL_ASSERT( + pos < num_p_type, "Invalid index to ParallelTypeBitset: ", pos); + return bitset_[pos]; +} + +std::map ParallelTypeBitmap::getMap() const { + std::map map; + for (const auto& pt_offset : pt_to_offset_) { + map.emplace(std::make_pair(pt_offset.first, bitset_[pt_offset.second])); + } + return map; +} + +ParallelTypeBitmap operator&( + const ParallelTypeBitmap& lhs, + const ParallelTypeBitmap& rhs) { + auto x = lhs; + x &= rhs; + return x; +} + +ParallelTypeBitmap operator|( + const ParallelTypeBitmap& lhs, + const ParallelTypeBitmap& rhs) { + auto x = lhs; + x |= rhs; + return x; +} + +ParallelTypeBitmap operator^( + const ParallelTypeBitmap& lhs, + const ParallelTypeBitmap& rhs) { + auto x = lhs; + x ^= rhs; + return x; +} + +ParallelTypeBitmap getParallelBroadcastDomains( + const BroadcastOp* const bop, + const ThreadPredicateMap& preds) { + const Val* bop_out = bop->out(); + if (bop_out->getValType().value() == ValType::TensorIndex) { + bop_out = bop_out->as()->view(); + } + TORCH_INTERNAL_ASSERT( + bop_out->getValType().value() == ValType::TensorView, + "Out is not tensor view"); + auto out_tv = bop_out->as(); + // If no pred is found for out_tv, no predicate is necessary + if (preds.find(out_tv) == preds.end()) { + return ParallelTypeBitmap(); + } + const ParallelTypeBitmap& out_pred = preds.at(out_tv); + + ParallelTypeBitmap parallel_broadcast; + const auto& iter_domains = out_tv->domain()->domain(); + for (auto id : iter_domains) { + if (id->isBroadcast() && id->isThread()) { + parallel_broadcast.set(id->parallel_method(), true); + } + } + + return parallel_broadcast & out_pred; +} + } // namespace ir_utils } // namespace fuser } // namespace jit -} // namespace torch \ No newline at end of file +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index 8af46dab4eacd..c282eff120703 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -4,12 +4,16 @@ #include +#include + // Provides utilities for dealing with nested ForLoop and IfThenElse scopes namespace torch { namespace jit { namespace fuser { +class ThreadPredicateMap; + namespace scope_utils { // Grab the ForLoop starting from scope working out @@ 
-75,6 +79,50 @@ const TensorView* asConstTV(const Val* const); bool isUnrolledFor(const Expr*); +// Represents mapping to bool from BIDx, BIDy, BIDz, TIDx, TIDy and TIDz. +class ParallelTypeBitmap { + public: + static constexpr int num_p_type = 6; + ParallelTypeBitmap() = default; + bool get(ParallelType pt) const; + bool set(ParallelType pt, bool); + ParallelTypeBitmap operator&=(const ParallelTypeBitmap& other); + ParallelTypeBitmap operator|=(const ParallelTypeBitmap& other); + ParallelTypeBitmap operator^=(const ParallelTypeBitmap& other); + ParallelTypeBitmap operator~() const; + bool none() const; + bool any() const; + bool all() const; + bool operator[](size_t pos) const; + std::map getMap() const; + + private: + ParallelTypeBitmap(const std::bitset& bs) : bitset_(bs) {} + std::bitset bitset_; + const static std::unordered_map pt_to_offset_; + const static std::unordered_map offset_to_pt_; +}; + +ParallelTypeBitmap operator&( + const ParallelTypeBitmap& lhs, + const ParallelTypeBitmap& rhs); + +ParallelTypeBitmap operator|( + const ParallelTypeBitmap& lhs, + const ParallelTypeBitmap& rhs); + +ParallelTypeBitmap operator^( + const ParallelTypeBitmap& lhs, + const ParallelTypeBitmap& rhs); + +// Returns a ParallelTypeBitmap representing which domain needs +// blockBroadcast. +// Even when a domain is broadcast and parallelized, it does not need +// blockBroadcast unless it is predicated. +ParallelTypeBitmap getParallelBroadcastDomains( + const BroadcastOp* const, + const ThreadPredicateMap& preds); + } // namespace ir_utils } // namespace fuser } // namespace jit
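
As a rough, illustrative sketch (not part of the diff), the pieces above compose as follows. sketchBroadcastLowering, bop, out_tv, and preds are hypothetical names introduced only for this example; the calls themselves (getParallelBroadcastDomains, ThreadPredicateMap::getExpr/at/find, ParallelTypeBitmap::get/any/none) are the ones declared in lower_utils.h and lower_thread_predicate.h, and the flow mirrors IRPrinter::handle(const BroadcastOp*) and UnrollPass::getThreadPredicate:

// Illustrative only -- assumes the cuda codegen headers touched by this diff.
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>

using namespace torch::jit::fuser;

// Hypothetical helper: decides, for a lowered BroadcastOp, whether a
// blockBroadcast call is needed and whether its output still needs the
// thread predicate inherited from an upstream parallel reduction.
void sketchBroadcastLowering(
    const BroadcastOp* bop,
    const TensorView* out_tv,
    const ThreadPredicateMap& preds) {
  // Broadcast axes that are thread/block-parallelized and predicated upstream.
  const ir_utils::ParallelTypeBitmap bcast_dims =
      ir_utils::getParallelBroadcastDomains(bop, preds);

  // Grid-level (BIDx/y/z) broadcasts are rejected, as in IRPrinter::handle.
  TORCH_INTERNAL_ASSERT(
      !(bcast_dims.get(ParallelType::BIDx) ||
        bcast_dims.get(ParallelType::BIDy) ||
        bcast_dims.get(ParallelType::BIDz)),
      "Parallel broadcast across blocks not supported");

  if (bcast_dims.any()) {
    // broadcast::blockBroadcast<...> would be emitted here, and consumers of
    // out_tv no longer need the thread predicate (UnrollPass returns nullptr).
  } else if (preds.find(out_tv) != preds.end() && !preds.at(out_tv).none()) {
    // Otherwise the ordinary predicate expression still applies.
    Bool* thread_pred = preds.getExpr(out_tv);
    (void)thread_pred;
  }
}

Moving the ParallelType-to-bit mapping into ParallelTypeBitmap, instead of the file-static pt_to_offset/offset_to_pt tables that used to live in lower_thread_predicate.cpp, is what lets the lowering passes and the IR printer share this logic.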