diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 3450eea1bb6ec4..9772f5877a880e 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -451,8 +451,11 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/compute_at.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/dispatch.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/expr_evaluator.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/executor.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/executor_kernel_arg.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/executor_launch_params.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/executor_utils.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/fusion.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/scheduler.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/graph_fuser.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/index_compute.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -463,9 +466,6 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/iter_visitor.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_cache.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/manager.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/shape_inference.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/mutator.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_index.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_loops.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_thread_predicate.cpp @@ -473,16 +473,20 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_utils.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_validation.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower2device.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/manager.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/mutator.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/parser.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/partition.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/predicate_compute.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/register_interface.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/scheduler.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/shape_inference.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/tensor_view.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_iter.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_replay.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_rfactor.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/type.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/utils.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/register_interface.cpp ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp ) add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 3949b98f77d662..08e6fde5b2f530 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -3,6 +3,8 @@ #include #include +#include +#include #include #include #include @@ -25,6 +27,7 @@ namespace torch { namespace jit { +using namespace torch::jit::fuser; using namespace torch::jit::fuser; namespace { @@ -101,6 +104,7 @@ void testGPU_IrGraphGenerator() { fusion.addOutput(tv6); + tv4->axis(2)->parallelize(ParallelType::BIDy); tv6->merge(0); tv6->split(0, 4); tv6->axis(0)->parallelize(ParallelType::BIDx); @@ -189,6 +193,9 @@ void testGPU_FusionExprEvalBindings() { checkIntValue(&eval_context, ceilDiv(a, b), 3); checkIntValue(&eval_context, d, -4); + // Reset evaluation context + eval_context = EvaluationContext(&fusion); + eval_context.bind(a, 2); 
eval_context.bind(b, 5); @@ -368,10 +375,8 @@ void testGPU_FusionExprEvalPostLower() { } void testGPU_FusionClear() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // 1. Create a dummy IR @@ -379,13 +384,13 @@ void testGPU_FusionClear() { TensorView* tv0 = makeDummyTensor(2); TensorView* tv1 = makeDummyTensor(2); - fusion->addInput(tv0); - fusion->addInput(tv1); + fusion.addInput(tv0); + fusion.addInput(tv1); TensorView* tv2 = add(tv1, new Float(2.0)); TensorView* tv3 = add(tv0, tv2); - fusion->addOutput(tv3); + fusion.addOutput(tv3); tv3->split(0, 4); tv0->computeAt(tv3, 1); @@ -395,24 +400,24 @@ void testGPU_FusionClear() { tv2->axis(1)->parallelize(ParallelType::Unroll); tv3->axis(-1)->parallelize(ParallelType::TIDx); - fusion->setLaunchConfig(LaunchConfigType::Compatible, new Int(1)); + fusion.setLaunchConfig(LaunchConfigType::Compatible, new Int(1)); } // 2. Clear the IR - fusion->clear(); + fusion.clear(); - TORCH_CHECK(fusion->exprs().empty()); - TORCH_CHECK(fusion->vals().empty()); + TORCH_CHECK(fusion.exprs().empty()); + TORCH_CHECK(fusion.vals().empty()); - TORCH_CHECK(fusion->inputs().empty()); - TORCH_CHECK(fusion->outputs().empty()); + TORCH_CHECK(fusion.inputs().empty()); + TORCH_CHECK(fusion.outputs().empty()); - TORCH_CHECK(fusion->launch_configs().empty()); + TORCH_CHECK(fusion.launch_configs().empty()); - TORCH_CHECK(!fusion->hasReduction()); - TORCH_CHECK(!fusion->hasBlockReduction()); - TORCH_CHECK(!fusion->hasGridReduction()); + TORCH_CHECK(!fusion.hasReduction()); + TORCH_CHECK(!fusion.hasBlockReduction()); + TORCH_CHECK(!fusion.hasGridReduction()); // 3. Rebuild the IR @@ -422,9 +427,9 @@ void testGPU_FusionClear() { TensorView* tv2 = add(tv1, new Float(2.0)); TensorView* tv3 = add(tv0, tv2); - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addOutput(tv3); + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv3); // tv3 [i0, i1, i2] tv3->reorder({{0, 2}, {2, 0}}); @@ -438,31 +443,19 @@ void testGPU_FusionClear() { tv3->axis(1)->parallelize(ParallelType::BIDx); } - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 1, // tid_x - 1, // tid_y - 1, // tid_z - 4, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({16, 8, 8}, options); at::Tensor input2 = at::randn_like(input1); - at::Tensor output = at::empty_like(input1); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input1, input2}, {output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({input1, input2}); at::Tensor tv2_ref = input2 + 2.0; at::Tensor output_ref = input1 + tv2_ref; - TORCH_CHECK(output_ref.equal(output)); + TORCH_CHECK(output_ref.equal(outputs[0])); } void testGPU_FusionCopy() { @@ -490,10 +483,6 @@ void testGPU_FusionCopy() { tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); - - original_fusion.setLaunchConfig(LaunchConfigType::Compatible, new Int(1)); - original_fusion.setLaunchConfig( - LaunchConfigType::BIDx, tv3->axis(0)->rawExtent()); } // Test copy before lowering @@ -1066,32 +1055,32 @@ void testGPU_FusionParser() { __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3){ float T2[4]; if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) 
+ threadIdx.x ) < T3.size[0] ) ) { - for(size_t i29 = 0; i29 < 4; ++i29 ) { - T2[ i29 ] - = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i29 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] - * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i29 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; + for(size_t i27 = 0; i27 < 4; ++i27 ) { + T2[ i27 ] + = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i27 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] + * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i27 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; } } else { - for(size_t i29 = 0; i29 < 4; ++i29 ) { - if ( ( ( ( ( ( blockIdx.x * 4 ) + i29 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { - T2[ i29 ] - = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i29 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] - * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i29 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; + for(size_t i27 = 0; i27 < 4; ++i27 ) { + if ( ( ( ( ( ( blockIdx.x * 4 ) + i27 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { + T2[ i27 ] + = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i27 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] + * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i27 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; } } } if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { - for(size_t i30 = 0; i30 < 4; ++i30 ) { - T3[ ( ( ( ( ( blockIdx.x * 4 ) + i30 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] - = T2[ i30 ] - * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i30 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; + for(size_t i28 = 0; i28 < 4; ++i28 ) { + T3[ ( ( ( ( ( blockIdx.x * 4 ) + i28 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] + = T2[ i28 ] + * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i28 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; } } else { - for(size_t i30 = 0; i30 < 4; ++i30 ) { - if ( ( ( ( ( ( blockIdx.x * 4 ) + i30 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { - T3[ ( ( ( ( ( blockIdx.x * 4 ) + i30 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] - = T2[ i30 ] - * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i30 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; + for(size_t i28 = 0; i28 < 4; ++i28 ) { + if ( ( ( ( ( ( blockIdx.x * 4 ) + i28 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { + T3[ ( ( ( ( ( blockIdx.x * 4 ) + i28 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] + = T2[ i28 ] + * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i28 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; } } } @@ -1155,17 +1144,15 @@ void testGPU_FusionForLoop() { } void testGPU_FusionCodeGen() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(3); new BinaryOp(BinaryOpType::Add, tv0, new Float(0.0), new Float(1.0)); TensorView* tv1 = add(tv0, new Float(2.0)); TensorView* tv2 = add(tv1, new Float(3.0)); - fusion->addOutput(tv2); + fusion.addOutput(tv2); //[I0, I1, I2] tv2 = tv2->split(0, 4); @@ -1179,26 +1166,13 @@ void testGPU_FusionCodeGen() { tv0->computeAt(tv2, -1); - prog.setDevice(0); - // These can be set to anything as there are no bindings! - // All CTAS and threads execute the same thing. 
- setupLaunchConfig( - prog.fusion(), - 32, // tid_x - 1, // tid_y - 1, // tid_z - 4, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor output = at::empty({16, 8, 8}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {}, {output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({}, {output}); at::Tensor output_ref = at::zeros_like(output, options); output_ref = output_ref + 0.0 + 1.0 + 2.0 + 3.0; @@ -1207,19 +1181,17 @@ void testGPU_FusionCodeGen() { } void testGPU_FusionCodeGen2() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(3); TensorView* tv1 = makeDummyTensor(3); TensorView* tv2 = add(tv1, new Float(2.0)); TensorView* tv3 = add(tv0, tv2); - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addOutput(tv3); + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv3); //[I0, I1, I2] tv3->reorder({{0, 2}, {2, 0}}); @@ -1235,39 +1207,24 @@ void testGPU_FusionCodeGen2() { tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 8, // tid_x - 1, // tid_y - 1, // tid_z - 4, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({16, 8, 8}, options); at::Tensor input2 = at::randn_like(input1); - ; - at::Tensor output = at::empty_like(input1); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input1, input2}, {output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({input1, input2}); at::Tensor tv2_ref = input2 + 2.0; at::Tensor output_ref = input1 + tv2_ref; - TORCH_CHECK(output_ref.equal(output)); + TORCH_CHECK(output_ref.equal(outputs[0])); } void testGPU_FusionSimplePWise() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // dimensionality of the problem int nDims = 3; @@ -1276,8 +1233,8 @@ void testGPU_FusionSimplePWise() { TensorView* tv1 = makeDummyTensor(nDims); // Register your inputs - fusion->addInput(tv0); - fusion->addInput(tv1); + fusion.addInput(tv0); + fusion.addInput(tv1); // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView @@ -1285,7 +1242,7 @@ void testGPU_FusionSimplePWise() { TensorView* tv3 = add(tv0, tv2); // Register your outputs - fusion->addOutput(tv3); + fusion.addOutput(tv3); // Do transformations, remember, transformations are outputs to inputs // This doesn't have to be in this order @@ -1306,26 +1263,15 @@ void testGPU_FusionSimplePWise() { tv3->axis(-2)->parallelize(ParallelType::TIDy); tv3->axis(-1)->parallelize(ParallelType::TIDx); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 128, // tid_x - 2, // tid_y - 1, // tid_z - 64, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({64, 2, 128}, options); at::Tensor input2 = at::rand_like(input1); 
at::Tensor output = at::empty_like(input1); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input1, input2}, {output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({input1, input2}, {output}); at::Tensor tv2_ref = input2 + 2.0; at::Tensor output_ref = input1 + tv2_ref; @@ -1334,18 +1280,16 @@ void testGPU_FusionSimplePWise() { } void testGPU_FusionExecKernel() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); TensorView* tv1 = makeDummyTensor(2); // Register your inputs - fusion->addInput(tv0); - fusion->addInput(tv1); + fusion.addInput(tv0); + fusion.addInput(tv1); // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView @@ -1353,7 +1297,7 @@ void testGPU_FusionExecKernel() { TensorView* tv3 = add(tv0, tv2); // Register your outputs - fusion->addOutput(tv3); + fusion.addOutput(tv3); tv3->merge(0); tv3->split(0, 128); @@ -1371,31 +1315,18 @@ void testGPU_FusionExecKernel() { tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 128, // tid_x - 1, // tid_y - 1, // tid_z - 1, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::ones({1, 128}, options); at::Tensor input2 = at::ones_like(input1); - at::Tensor output = at::empty_like(input1); - - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input1, input2}, {output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({input1, input2}); at::Tensor check = at::full({1, 128}, 4, options); ; - TORCH_CHECK(output.equal(check)); + TORCH_CHECK(outputs[0].equal(check)); } int ceilDiv_(int a, int b) { @@ -1414,13 +1345,11 @@ void testGPU_FusionAdvancedComputeAt() { * tv7 = tv1 + tv4 */ { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); TensorView* tv1 = mul(tv0, new Float(0.5)); TensorView* tv2 = mul(tv1, new Float(-1.0)); @@ -1431,8 +1360,8 @@ void testGPU_FusionAdvancedComputeAt() { TensorView* tv6 = add(tv5, tv4); TensorView* tv7 = add(tv1, tv4); - fusion->addOutput(tv6); - fusion->addOutput(tv7); + fusion.addOutput(tv6); + fusion.addOutput(tv7); // Lets setup to actually run tv7->merge(0); @@ -1451,8 +1380,8 @@ void testGPU_FusionAdvancedComputeAt() { TORCH_CHECK(tv6->getComputeAtView() == tv7 && tv6->nDims() == 3); TORCH_CHECK(!tv7->hasComputeAt()); - for (Val* val : fusion->vals()) { - if (!fusion->hasInput(val) && + for (Val* val : fusion.vals()) { + if (!fusion.hasInput(val) && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); tv->axis(1)->parallelize(ParallelType::Unroll); @@ -1475,22 +1404,9 @@ void testGPU_FusionAdvancedComputeAt() { at::Tensor kernel_tv6 = at::empty_like(t0, options); at::Tensor kernel_tv7 = at::empty_like(t0, options); - prog.setDevice(0); - - int blocks = ceilDiv_( - ceilDiv_(t0.numel(), 128), 4); // numel / unroll 
factor / threads - setupLaunchConfig( - prog.fusion(), - 128, // tid_x - 1, // tid_y - 1, // tid_z - blocks, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {t0}, {kernel_tv6, kernel_tv7}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({t0}, {kernel_tv6, kernel_tv7}); TORCH_CHECK(at::allclose(kernel_tv6, t6)); TORCH_CHECK(at::allclose(kernel_tv7, t7)); @@ -1506,13 +1422,11 @@ void testGPU_FusionAdvancedComputeAt() { * tv6 = tv5 + tv3 */ { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); TensorView* tv1 = mul(tv0, new Float(-1.0)); TensorView* tv2 = add(tv0, new Float(3.0)); @@ -1522,8 +1436,8 @@ void testGPU_FusionAdvancedComputeAt() { TensorView* tv5 = add(tv4, tv3); TensorView* tv6 = add(tv5, tv3); - fusion->addOutput(tv5); - fusion->addOutput(tv6); + fusion.addOutput(tv5); + fusion.addOutput(tv6); // Lets setup to actually run tv6->merge(0); @@ -1534,8 +1448,8 @@ void testGPU_FusionAdvancedComputeAt() { tv0->computeAt(tv6, 1); - for (Val* val : fusion->vals()) { - if (!fusion->hasInput(val) && + for (Val* val : fusion.vals()) { + if (!fusion.hasInput(val) && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -1554,53 +1468,35 @@ void testGPU_FusionAdvancedComputeAt() { auto t5 = t4.add(t3); auto t6 = t5.add(t3); - at::Tensor kernel_tv5 = at::empty_like(t0, options); - at::Tensor kernel_tv6 = at::empty_like(t0, options); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0}); - prog.setDevice(0); - - int blocks = ceilDiv_( - ceilDiv_(t0.numel(), 128), 4); // numel / unroll factor / threads - setupLaunchConfig( - prog.fusion(), - 128, // tid_x - 1, // tid_y - 1, // tid_z - blocks, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {t0}, {kernel_tv5, kernel_tv6}); - - GPULower gpulw(fusion); + GPULower gpulw(&fusion); std::stringstream actual_kernel; gpulw.printKernel(actual_kernel); - TORCH_CHECK(at::allclose(kernel_tv5, t5), actual_kernel.str()); - TORCH_CHECK(at::allclose(kernel_tv6, t6)); + TORCH_CHECK(at::allclose(outputs[0], t5), actual_kernel.str()); + TORCH_CHECK(at::allclose(outputs[1], t6)); } // Case 3 // T2 = T1 * 0.979361 // T3 = T2 * T0 { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(4); - fusion->addInput(tv0); + fusion.addInput(tv0); TensorView* tv1 = makeDummyTensor(4); - fusion->addInput(tv1); + fusion.addInput(tv1); TensorView* tv2 = mul(tv1, new Float(.979361)); TensorView* tv3 = mul(tv2, tv0); - fusion->addOutput(tv3); + fusion.addOutput(tv3); // Lets setup to actually run while (tv3->nDims() > 1) @@ -1613,8 +1509,8 @@ void testGPU_FusionAdvancedComputeAt() { tv3->axis(0)->parallelize(ParallelType::BIDx); - for (Val* val : fusion->vals()) { - if (!fusion->hasInput(val) && + for (Val* val : fusion.vals()) { + if (!fusion.hasInput(val) && val->getValType().value() == ValType::TensorView) { TensorView* tv = 
static_cast(val); @@ -1632,25 +1528,11 @@ void testGPU_FusionAdvancedComputeAt() { at::Tensor kernel_tv3 = at::empty_like(t0, options); - prog.setDevice(0); - - int blocks = ceilDiv_( - ceilDiv_(t0.numel(), 128), 4); // numel / unroll factor / threads - - setupLaunchConfig( - prog.fusion(), - 128, // tid_x - 1, // tid_y - 1, // tid_z - blocks, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {t0, t1}, {kernel_tv3}); - - GPULower gpulw(fusion); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({t0, t1}, {kernel_tv3}); + + GPULower gpulw(&fusion); std::stringstream actual_kernel; gpulw.printKernel(actual_kernel); @@ -1662,28 +1544,26 @@ void testGPU_FusionAdvancedComputeAt() { // T5 = T1 + T4 // T6 = T5 - T0 { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(4); - fusion->addInput(tv0); + fusion.addInput(tv0); TensorView* tv1 = makeDummyTensor(4); - fusion->addInput(tv1); + fusion.addInput(tv1); TensorView* tv2 = makeDummyTensor(4); - fusion->addInput(tv2); + fusion.addInput(tv2); TensorView* tv3 = makeDummyTensor(4); - fusion->addInput(tv3); + fusion.addInput(tv3); TensorView* tv4 = sub(tv2, tv3); TensorView* tv5 = add(tv1, tv4); TensorView* tv6 = sub(tv5, tv0); - fusion->addOutput(tv6); + fusion.addOutput(tv6); // Lets setup to actually run while (tv6->nDims() > 1) @@ -1698,8 +1578,8 @@ void testGPU_FusionAdvancedComputeAt() { tv6->axis(0)->parallelize(ParallelType::BIDx); - for (Val* val : fusion->vals()) { - if (!fusion->hasInput(val) && + for (Val* val : fusion.vals()) { + if (!fusion.hasInput(val) && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -1718,53 +1598,35 @@ void testGPU_FusionAdvancedComputeAt() { auto t5 = t1.add(t4); auto t6 = t5.sub(t0); - at::Tensor kernel_tv6 = at::empty_like(t0, options); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t1, t2, t3}); - prog.setDevice(0); - - int blocks = ceilDiv_( - ceilDiv_(t0.numel(), 128), 4); // numel / unroll factor / threads - - setupLaunchConfig( - prog.fusion(), - 128, // tid_x - 1, // tid_y - 1, // tid_z - blocks, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {t0, t1, t2, t3}, {kernel_tv6}); - - GPULower gpulw(fusion); + GPULower gpulw(&fusion); std::stringstream actual_kernel; gpulw.printKernel(actual_kernel); - TORCH_CHECK(at::allclose(kernel_tv6, t6), actual_kernel.str()); + TORCH_CHECK(at::allclose(outputs[0], t6), actual_kernel.str()); } } void testGPU_FusionScalarInputs() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); TensorView* tv1 = makeDummyTensor(2); - fusion->addInput(tv1); + fusion.addInput(tv1); Float* f0 = new Float(); - fusion->addInput(f0); + fusion.addInput(f0); Float* f1 = new Float(); - fusion->addInput(f1); + fusion.addInput(f1); Float* f2 = new Float(); - fusion->addInput(f2); + fusion.addInput(f2); Float* f3 = new Float(); - 
fusion->addInput(f3); + fusion.addInput(f3); Val* f4 = mul(f0, f1); Val* f5 = sub(f2, f3); @@ -1772,7 +1634,7 @@ void testGPU_FusionScalarInputs() { TensorView* tv3 = add(tv0, f5); TensorView* tv4 = mul(tv3, tv2); - fusion->addOutput(tv4); + fusion.addOutput(tv4); // Lets setup to actually run while (tv4->nDims() > 1) @@ -1785,8 +1647,8 @@ void testGPU_FusionScalarInputs() { tv4->axis(0)->parallelize(ParallelType::BIDx); - for (Val* val : fusion->vals()) { - if (!fusion->hasInput(val) && + for (Val* val : fusion.vals()) { + if (!fusion.hasInput(val) && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -1819,26 +1681,11 @@ void testGPU_FusionScalarInputs() { at::Tensor kernel_tv4 = at::empty_like(t0, options); - prog.setDevice(0); - - int blocks = - ceilDiv_(ceilDiv_(t0.numel(), 128), 4); // numel / unroll factor / threads - - setupLaunchConfig( - prog.fusion(), - 128, // tid_x - 1, // tid_y - 1, // tid_z - blocks, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - torch::jit::fuser::cuda::compileKernel(&prog); at::Scalar test(fl0); - torch::jit::fuser::cuda::runKernel( - &prog, + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion( {t0, t1, at::Scalar(fl0), @@ -1847,7 +1694,7 @@ void testGPU_FusionScalarInputs() { at::Scalar(fl3)}, {kernel_tv4}); - GPULower gpulw(fusion); + GPULower gpulw(&fusion); std::stringstream actual_kernel; gpulw.printKernel(actual_kernel); @@ -1855,18 +1702,16 @@ void testGPU_FusionScalarInputs() { } void testGPU_FusionLoopUnroll() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(3); TensorView* tv1 = makeDummyTensor(3); // Register your inputs - fusion->addInput(tv0); - fusion->addInput(tv1); + fusion.addInput(tv0); + fusion.addInput(tv1); // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView @@ -1874,7 +1719,7 @@ void testGPU_FusionLoopUnroll() { TensorView* tv3 = add(tv0, tv2); // Register your outputs - fusion->addOutput(tv3); + fusion.addOutput(tv3); int block_size = 16; @@ -1898,29 +1743,16 @@ void testGPU_FusionLoopUnroll() { int inp_size = 129 * 13 * 3; - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - block_size, // tid_x - 1, // tid_y - 1, // tid_z - (inp_size + 63) / 64, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input0 = at::rand({129, 13, 3}, options); at::Tensor input1 = at::rand({129, 13, 3}, options); - at::Tensor output = at::empty_like(input1); - - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input0, input1}, {output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({input0, input1}); - TORCH_CHECK(output.equal(input0.add(input1.add(2.0)))); + TORCH_CHECK(outputs[0].equal(input0.add(input1.add(2.0)))); } /* @@ -2011,21 +1843,19 @@ void test_op( OutputPair op, InputTuple it, std::index_sequence) { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Generate Input JIT function Inputs and add them as Inputs to the Fusion // Graph std::array jit_inputs = { 
gen_jit_operand(std::get(it))...}; - std::for_each(jit_inputs.begin(), jit_inputs.end(), [fusion](Val* v) { - fusion->addInput(v); + std::for_each(jit_inputs.begin(), jit_inputs.end(), [&fusion](Val* v) { + fusion.addInput(v); }); TensorView* out = static_cast(jf(std::get(jit_inputs)...)); - fusion->addOutput(out); + fusion.addOutput(out); std::for_each(jit_inputs.begin(), jit_inputs.end(), [out](Val* v) { if (v->getValType() == ValType::TensorView) @@ -2034,19 +1864,6 @@ void test_op( out->axis(0)->parallelize(ParallelType::BIDx); out->axis(-1)->parallelize(ParallelType::TIDx); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - threads, // tid_x - 1, // tid_y - 1, // tid_z - blocks, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - torch::jit::fuser::cuda::compileKernel(&prog); - std::array aten_inputs = {gen_aten_operand( std::get(it), blocks, threads, /*rand*/ true)...}; const at::ArrayRef aten_inputs_ivalues(aten_inputs); @@ -2055,12 +1872,15 @@ void test_op( gen_aten_operand(op, blocks, threads, /*rand*/ false).toTensor(); std::vector output_vect = {output}; cudaDeviceSynchronize(); - if (fusion->hasRNG()) + if (fusion.hasRNG()) at::manual_seed(0); - torch::jit::fuser::cuda::runKernel(&prog, aten_inputs_ivalues, output_vect); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion(aten_inputs_ivalues, output_vect); cudaDeviceSynchronize(); - if (fusion->hasRNG()) + if (fusion.hasRNG()) at::manual_seed(0); at::Tensor ref_output = af(aten_inputs); cudaDeviceSynchronize(); // This sync shouldn't be necessary; @@ -2391,52 +2211,37 @@ void testGPU_FusionCompoundOps() { } void testGPU_FusionCastOps() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2, DataType::Half); TensorView* intrm1 = castOp(DataType::Float, tv0); TensorView* out = castOp(DataType::Half, intrm1); - fusion->addInput(tv0); - fusion->addOutput(out); + fusion.addInput(tv0); + fusion.addOutput(out); tv0->computeAt(out, -1); out->axis(0)->parallelize(ParallelType::BIDx); out->axis(-1)->parallelize(ParallelType::TIDx); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 4, // tid_x - 1, // tid_y - 1, // tid_z - 1, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); at::Tensor input1 = at::rand({1, 4}, options); - at::Tensor output = at::empty_like(input1); at::Tensor ref_output = at::empty_like(input1); std::array inputs = {input1}; const at::ArrayRef input_ivalues(inputs); - std::vector outputs{{output}}; - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, input_ivalues, outputs); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion(input_ivalues); ref_output = at::_cast_Half(at::_cast_Float(input1)); TORCH_CHECK( - output.equal(ref_output), + outputs[0].equal(ref_output), "\nOp Type: -- ", "cast FP16->FP32->FP16", " -- had a mismatch.\n", @@ -2444,7 +2249,7 @@ void testGPU_FusionCastOps() { input1, "\n", "JIT: ", - output, + outputs[0], "\n", "REF: ", ref_output, @@ -2498,7 +2303,7 @@ void testGPU_FusionRFactorReplay() { // new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf] // casp[I0oi{16}, I0oo*I0i{32}, R1oi{4}] - casp->split(1, 2); + casp->split(1, new Int(2)); // casp [I0oi{16}, 
(I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4} ] // new_domain[I0oi{16}, I0oo*I0i{32} , ir1oi{4}rf, // R(R1oo*R1i{8})rf] @@ -2547,20 +2352,18 @@ void testGPU_FusionRFactorReplay() { // Start off simple, block on the outer dim // block stride + thread all reduce + unrolling on inner dim void testGPU_FusionReduction() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); - TORCH_CHECK(fusion->hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(1, 128); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] @@ -2594,24 +2397,13 @@ void testGPU_FusionReduction() { int numel_x = 65000; int numel_y = 1025; - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 128, // tid_x - 1, // tid_y - 1, // tid_z - numel_x, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); @@ -2619,19 +2411,17 @@ void testGPU_FusionReduction() { void testGPU_FusionReduction2() { { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); // switches to try some different scenarios. maybe we should iterate on all // permutations. @@ -2680,47 +2470,33 @@ void testGPU_FusionReduction2() { tv1->axis(-1)->parallelize(ParallelType::TIDx); } - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - bind_tidx ? tidx : 1, // tid_x - bind_tidy ? tidy : 1, // tid_y - 1, // tid_z - bind_bidx ? bidx : 1, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({input}); c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); AT_CUDA_CHECK(cudaStreamSynchronize(stream)); auto aten_output = input.sum({1}); - TORCH_CHECK(aten_output.allclose(cg_output)); + TORCH_CHECK(aten_output.allclose(outputs[0])); } { // What if Z participates in the reduction with X? 
- torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); int numel_x = 1025; // Cannot exceed block dim max size / tidy int numel_y = 129; @@ -2746,24 +2522,13 @@ void testGPU_FusionReduction2() { tv2->axis(-2)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDz); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - tidx, // tid_x - 1, // tid_y - tidz, // tid_z - numel_x, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({input}, {cg_output}); c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); AT_CUDA_CHECK(cudaStreamSynchronize(stream)); @@ -2776,10 +2541,8 @@ void testGPU_FusionReduction2() { // TODO: Fix and reenable this test. void testGPU_FusionReduction3() { { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); @@ -2788,18 +2551,18 @@ void testGPU_FusionReduction3() { TensorView* tv2 = add(tv0, tv1); // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1] - fusion->addInput(tv0); - fusion->addInput(tv1); + fusion.addInput(tv0); + fusion.addInput(tv1); TensorView* tv3 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv2); // tv3[I0, R1] = tv2[I0, I1] TensorView* tv4 = makeDummyTensor(1); - fusion->addInput(tv4); + fusion.addInput(tv4); // tv5[I0] = tv3[I0, R1] * tv4[I0] TensorView* tv5 = mul(tv3, tv4); - fusion->addOutput(tv5); + fusion.addOutput(tv5); int tidx = 16; @@ -2828,18 +2591,6 @@ void testGPU_FusionReduction3() { int numel_y = 129; int bidx = numel_x; - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - tidx, // tid_x - 1, // tid_y - 1, // tid_z - bidx, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::rand({numel_x, numel_y}, options); at::Tensor t1 = at::rand({numel_x, numel_y}, options); @@ -2848,33 +2599,30 @@ void testGPU_FusionReduction3() { at::Tensor t4 = at::rand({numel_x}, options); auto t5 = t3.mul(t4); - at::Tensor cg_output = at::empty({numel_x}, options); - - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {t0, t1, t4}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t1, t4}); c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); AT_CUDA_CHECK(cudaStreamSynchronize(stream)); TORCH_CHECK( - t5.allclose(cg_output), "Error of: ", t5.sub(cg_output).abs().max()); + t5.allclose(outputs[0]), "Error of: ", t5.sub(outputs[0]).abs().max()); } } void 
testGPU_FusionReduction4() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(3); - fusion->addInput(tv0); + fusion.addInput(tv0); TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); int bidy = 2; int tidy = 4; @@ -2890,8 +2638,8 @@ void testGPU_FusionReduction4() { tv1->axis(0)->parallelize(ParallelType::BIDy); - for (auto* val : fusion->vals()) { - if (!fusion->hasInput(val) && + for (auto* val : fusion.vals()) { + if (!fusion.hasInput(val) && val->getValType().value() == ValType::TensorView) { val->as()->axis(-1)->parallelize(ParallelType::TIDx); } @@ -2900,48 +2648,35 @@ void testGPU_FusionReduction4() { tv2->axis(-2)->parallelize(ParallelType::TIDy); tv1->axis(-2)->parallelize(ParallelType::TIDy); - prog.setDevice(0); - torch::jit::fuser::cuda::compileKernel(&prog); - setupLaunchConfig( - prog.fusion(), - tidx, // tid_x - tidy, // tid_y - 1, // tid_z - 1, // gid_x - bidy, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::randn({bidy, dim1, tidx}, options); at::Tensor cg_output = at::empty({bidy, tidx}, options); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); } void testGPU_FusionReduction5() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); const int bdimx = 64; const int bdimy = 8; // Set up your input tensor views TensorView* tv0 = makeDummyTensor(3); - fusion->addInput(tv0); + fusion.addInput(tv0); // tv1[I0, R1, R2] = tv0[I0, I1, I2] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); - TORCH_CHECK(fusion->hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(2, bdimx); // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2] @@ -2976,43 +2711,29 @@ void testGPU_FusionReduction5() { int numel_y = 1000; int numel_z = 1000; - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - bdimx, // tid_x - bdimy, // tid_y - 1, // tid_z - numel_x, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y, numel_z}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({input}); auto aten_output = input.sum({1, 2}); - TORCH_CHECK(aten_output.allclose(cg_output)); + TORCH_CHECK(aten_output.allclose(outputs[0])); } void testGPU_FusionReductionTFT() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up 
your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); int numel_x = 1025; int numel_y = 129; @@ -3043,24 +2764,13 @@ void testGPU_FusionReductionTFT() { tv1->axis(-2)->parallelize(ParallelType::TIDz); tv2->axis(-2)->parallelize(ParallelType::TIDz); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - tidx, // tid_x - tidy, // tid_y - tidz, // tid_z - 1, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({input}, {cg_output}); c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); AT_CUDA_CHECK(cudaStreamSynchronize(stream)); @@ -3071,16 +2781,14 @@ void testGPU_FusionReductionTFT() { void testGPU_FusionSimpleBCast() { { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); TensorView* tv1 = makeDummyTensor(2); - fusion->addInput(tv0); - fusion->addInput(tv1); + fusion.addInput(tv0); + fusion.addInput(tv1); TensorView* tv2 = broadcast(tv0, {false, false, true}); TensorView* tv3 = broadcast(tv1, {true, false, false}); @@ -3088,7 +2796,7 @@ void testGPU_FusionSimpleBCast() { TensorView* tv4 = add(tv2, tv3); tv4->split(-1, 4); tv4->split(0, 8); - fusion->addOutput(tv4); + fusion.addOutput(tv4); tv0->computeAt(tv4, -1); tv1->computeAt(tv4, -1); @@ -3103,40 +2811,26 @@ void testGPU_FusionSimpleBCast() { at::Tensor t0 = at::randn({x, y}, options); at::Tensor t1 = at::randn({y, z}, options); - at::Tensor cg_output = at::empty({x, y, z}, options); - - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 4, // tid_x - 1, // tid_y - 1, // tid_z - ceilDiv_(x, 8), // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {t0, t1}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t1}); auto t2 = t0.unsqueeze(-1).expand({x, y, z}); auto t3 = t1.expand({x, y, z}); auto t4 = t2.add(t3); - TORCH_CHECK(t4.allclose(cg_output)); + TORCH_CHECK(t4.allclose(outputs[0])); } { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); TensorView* tv1 = makeDummyTensor(2); - fusion->addInput(tv0); - fusion->addInput(tv1); + fusion.addInput(tv0); + fusion.addInput(tv1); // TODO add pointwise ops on the begining before the bcast. 
@@ -3147,7 +2841,7 @@ void testGPU_FusionSimpleBCast() { tv4->merge(0, 1); - fusion->addOutput(tv4); + fusion.addOutput(tv4); tv0->computeAt(tv4, -1); tv1->computeAt(tv4, -1); @@ -3163,19 +2857,9 @@ void testGPU_FusionSimpleBCast() { at::Tensor cg_output = at::empty({x, y, z}, options); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 1, // tid_x - 1, // tid_y - 1, // tid_z - x * y, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {t0, t1}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({t0, t1}, {cg_output}); auto t2 = t0.unsqueeze(-1).expand({x, y, z}); auto t3 = t1.expand({x, y, z}); @@ -3185,17 +2869,16 @@ void testGPU_FusionSimpleBCast() { } } +// Test a simple Gemm but also play around with fusion executor features void testGPU_FusionSimpleGemm() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); // M, K TensorView* tv1 = makeDummyTensor(2); // K, N - fusion->addInput(tv0); - fusion->addInput(tv1); + fusion.addInput(tv0); + fusion.addInput(tv1); TensorView* tv2 = broadcast(tv0, {false, false, true}); // tv2[I0, I1, B] = tv0[I0, I1] @@ -3207,7 +2890,7 @@ void testGPU_FusionSimpleGemm() { TensorView* tv4 = mul(tv2, tv3); // tv5[I0, R1, I2] = tv4[I0, I1, I2] TensorView* tv5 = sum(tv4, {1}); - fusion->addOutput(tv5); + fusion.addOutput(tv5); tv5->split(1, 32); // tv5[I0, R1o, R1i{32}, I2] @@ -3253,42 +2936,37 @@ void testGPU_FusionSimpleGemm() { at::Tensor t0 = at::randn({M, K}, options); at::Tensor t1 = at::randn({K, N}, options); - at::Tensor cg_output = at::empty({M, N}, options); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + // Lets specify a few bounds in launch params to make sure it works + fe.runFusion( + {t0, t1}, torch::jit::fuser::cuda::LaunchParams(1, -1, -1, 32, 4, 4)); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 32, // tid_x - 4, // tid_y - 4, // tid_z - 1, // gid_x - ceilDiv_(N, 4), // gid_y - ceilDiv_(M, 4), // gid_z - 0 // shared_memory size - ); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {t0, t1}, {cg_output}); + // Make sure bad launch params throws + ASSERT_ANY_THROW(fe.runFusion( + {t0, t1}, torch::jit::fuser::cuda::LaunchParams(1, 2, 3, 4, 5, 6))); + + // Don't specify any launch params + auto outputs = fe.runFusion({t0, t1}); auto t2 = t0.matmul(t1); TORCH_CHECK( - t2.allclose(cg_output, 1e-5, 1e-5), + t2.allclose(outputs[0], 1e-5, 1e-5), "Error of: ", - t2.sub(cg_output).abs().max()); + t2.sub(outputs[0]).abs().max()); } // Softmax with a 1D tensor. Parallelized only with a single thread block. 
void testGPU_FusionSoftmax1D() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); const int tidx = 128; const int dimx = 1000; // Set up your input tensor views TensorView* input_tv0 = makeDummyTensor(1); - fusion->addInput(input_tv0); + fusion.addInput(input_tv0); TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0); TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); @@ -3300,7 +2978,7 @@ void testGPU_FusionSoftmax1D() { TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); - fusion->addOutput(output_tv4); + fusion.addOutput(output_tv4); sum_exp_tv2->split(-1, tidx); TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2}); @@ -3317,25 +2995,14 @@ void testGPU_FusionSoftmax1D() { tv->axis(-1)->parallelize(ParallelType::TIDx); } - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - tidx, // tid_x - 1, // tid_y - 1, // tid_z - 1, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({dimx}, options); at::Tensor cg_output = at::empty({dimx}, options); at::Tensor t3_output = at::empty_like(cg_output, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {t0}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({t0}, {cg_output}); auto t2 = at::_softmax(t0, -1, false); TORCH_CHECK( @@ -3346,17 +3013,15 @@ void testGPU_FusionSoftmax1D() { // Softmax with a 1D tensor with input normalization. void testGPU_FusionSoftmax1DNormalized() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); const int tidx = 128; const int dimx = 1000; // Set up your input tensor views TensorView* input_tv0 = makeDummyTensor(1); - fusion->addInput(input_tv0); + fusion.addInput(input_tv0); // Normalize with the max value before computing exp. TensorView* max_val_tv1 = @@ -3374,7 +3039,7 @@ void testGPU_FusionSoftmax1DNormalized() { TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6); - fusion->addOutput(output_tv7); + fusion.addOutput(output_tv7); max_val_tv1->split(-1, tidx); TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2}); @@ -3399,40 +3064,26 @@ void testGPU_FusionSoftmax1DNormalized() { tv->axis(-1)->parallelize(ParallelType::TIDx); } - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - tidx, // tid_x - 1, // tid_y - 1, // tid_z - 1, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({dimx}, options); - at::Tensor cg_output = at::empty({dimx}, options); - at::Tensor t3_output = at::empty_like(cg_output, options); - torch::jit::fuser::cuda::compileKernel(&prog); + at::Tensor t3_output = at::empty({dimx}, options); - torch::jit::fuser::cuda::runKernel(&prog, {t0}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0}); auto t2 = at::_softmax(t0, -1, false); TORCH_CHECK( - t2.allclose(cg_output, 1e-5, 1e-5), + t2.allclose(outputs[0], 1e-5, 1e-5), "Error of: ", - t2.sub(cg_output).abs().max()); + t2.sub(outputs[0]).abs().max()); } // Softmax with a 3D tensor, where the inner-most 3rd dimension is // normalized. 
Pallelized with multiple thread blocks. void testGPU_FusionSoftmax3D() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); const int tidx = 32; const int dimx = 32; @@ -3441,7 +3092,7 @@ void testGPU_FusionSoftmax3D() { // Set up your input tensor views TensorView* input_tv0 = makeDummyTensor(3); - fusion->addInput(input_tv0); + fusion.addInput(input_tv0); TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0); TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); @@ -3453,7 +3104,7 @@ void testGPU_FusionSoftmax3D() { TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); - fusion->addOutput(output_tv4); + fusion.addOutput(output_tv4); sum_exp_tv2->split(-1, tidx); TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2}); @@ -3472,25 +3123,13 @@ void testGPU_FusionSoftmax3D() { tv->axis(-1)->parallelize(ParallelType::TIDx); } - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - tidx, // tid_x - 1, // tid_y - 1, // tid_z - dimx, // gid_x - dimy, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({dimx, dimy, dimz}, options); at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); at::Tensor t3_output = at::empty_like(cg_output, options); - torch::jit::fuser::cuda::compileKernel(&prog); - - torch::jit::fuser::cuda::runKernel(&prog, {t0}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({t0}, {cg_output}); auto t2 = at::_softmax(t0, -1, false); TORCH_CHECK( @@ -3501,10 +3140,8 @@ void testGPU_FusionSoftmax3D() { // Softmax with a 3D tensor with input normalization. void testGPU_FusionSoftmax3DNormalized() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); const int tidx = 32; const int dimx = 32; @@ -3513,7 +3150,7 @@ void testGPU_FusionSoftmax3DNormalized() { // Set up your input tensor views TensorView* input_tv0 = makeDummyTensor(3); - fusion->addInput(input_tv0); + fusion.addInput(input_tv0); // Normalize with the max value before computing exp. 
TensorView* max_val_tv1 = @@ -3531,7 +3168,7 @@ void testGPU_FusionSoftmax3DNormalized() { TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6); - fusion->addOutput(output_tv7); + fusion.addOutput(output_tv7); max_val_tv1->split(-1, tidx); TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2}); @@ -3558,42 +3195,28 @@ void testGPU_FusionSoftmax3DNormalized() { tv->axis(-1)->parallelize(ParallelType::TIDx); } - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - tidx, // tid_x - 1, // tid_y - 1, // tid_z - dimx, // gid_x - dimy, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({dimx, dimy, dimz}, options); - at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); - at::Tensor t3_output = at::empty_like(cg_output, options); - torch::jit::fuser::cuda::compileKernel(&prog); + at::Tensor t3_output = at::empty({dimx, dimy, dimz}, options); - torch::jit::fuser::cuda::runKernel(&prog, {t0}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0}); auto t2 = at::_softmax(t0, -1, false); TORCH_CHECK( - t2.allclose(cg_output, 1e-5, 1e-5), + t2.allclose(outputs[0], 1e-5, 1e-5), "Error of: ", - t2.sub(cg_output).abs().max()); + t2.sub(outputs[0]).abs().max()); } void testGPU_FusionSoftmaxComputeAt() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); auto tv1 = sum(tv0, {1}); auto tv2 = broadcast(tv1, {false, true}); @@ -3606,7 +3229,7 @@ void testGPU_FusionSoftmaxComputeAt() { auto tv6 = broadcast(tv5, {false, true}); auto tv7 = sub(tv6, tv4); - fusion->addOutput(tv7); + fusion.addOutput(tv7); tv1->computeAt(tv7, 1); ASSERT_ANY_THROW(tv1->computeAt(tv7, -1)); @@ -3616,20 +3239,19 @@ void testGPU_FusionSoftmaxComputeAt() { void testGPU_FusionGridReduction1() { const int gdimx = 32; const int bdimx = 128; - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); - TORCH_CHECK(fusion->hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(1, bdimx); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] @@ -3657,24 +3279,15 @@ void testGPU_FusionGridReduction1() { int numel_x = 10000; int numel_y = 65000; - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - bdimx, // tid_x - 1, // tid_y - 1, // tid_z - gdimx, // gid_x - numel_x, // gid_y - 1, // gid_z - 0 // shared_memory size - ); + // fusion.printKernel(); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + 
fe.compileFusion(&fusion); + fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); @@ -3684,20 +3297,19 @@ void testGPU_FusionGridReduction1() { void testGPU_FusionGridReduction2() { const int gdimy = 32; const int bdimx = 128; - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); - TORCH_CHECK(fusion->hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(1, bdimx); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] @@ -3725,47 +3337,34 @@ void testGPU_FusionGridReduction2() { int numel_x = 10000; int numel_y = 65000; - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - bdimx, // tid_x - 1, // tid_y - 1, // tid_z - numel_x, // gid_x - gdimy, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({input}); auto aten_output = input.sum({1}); - TORCH_CHECK(aten_output.allclose(cg_output)); + TORCH_CHECK(aten_output.allclose(outputs[0])); } // Same test but uses BIDy and BIDz for reduction. No TID used. void testGPU_FusionGridReduction3dim1() { const int gdimz = 32; const int gdimy = 128; - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); - TORCH_CHECK(fusion->hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(1, gdimy); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] @@ -3793,27 +3392,13 @@ void testGPU_FusionGridReduction3dim1() { int numel_x = 100; int numel_y = 6500; - prog.setDevice(0); - // This number should not affect the output as TIDx is not - // used. All threads in a thread block redundantly computes the - // same value. 
- setupLaunchConfig( - prog.fusion(), - 128, // tid_x - 1, // tid_y - 1, // tid_z - numel_x, // gid_x - gdimy, // gid_y - gdimz, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); @@ -3824,20 +3409,19 @@ void testGPU_FusionGridReduction3dim0() { const int rdim = 0; const int gdimy = 128; const int gdimz = 32; - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); // tv1[R0, I1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {rdim}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); - TORCH_CHECK(fusion->hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(rdim, gdimy); // tv1[R0o, R0i{128}, I1] = tv0[I0, I1] @@ -3862,51 +3446,34 @@ void testGPU_FusionGridReduction3dim0() { int numel_x = 6500; int numel_y = 100; - prog.setDevice(0); - // This number should not affect the output as TIDx is not - // used. All threads in a thread block redundantly computes the - // same value. - setupLaunchConfig( - prog.fusion(), - 1, // tid_x - 1, // tid_y - 1, // tid_z - numel_y, // gid_x - gdimy, // gid_y - gdimz, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); - at::Tensor cg_output = at::empty({numel_y}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({input}); auto aten_output = input.sum({0}); - TORCH_CHECK(aten_output.allclose(cg_output)); + TORCH_CHECK(aten_output.allclose(outputs[0])); } // This is similar to the FusionReduction, but swaps BIDx and TIDx void testGPU_FusionGridReduction4() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); const int bdimx = 128; const int gdimx = 1024; // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); - TORCH_CHECK(fusion->hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(1, gdimx); // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1] @@ -3940,24 +3507,13 @@ void testGPU_FusionGridReduction4() { int numel_x = bdimx; int numel_y = 65000; - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - bdimx, // tid_x - 1, // tid_y - 1, // tid_z - gdimx, // gid_x - 1, 
// gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); @@ -3966,10 +3522,8 @@ void testGPU_FusionGridReduction4() { // Grid reduction with 2D thread blocks but only TIDx and BIDx are // mapped to a reduction dim void testGPU_FusionGridReduction5() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); const int bdimx = 64; const int bdimy = 16; @@ -3977,13 +3531,13 @@ void testGPU_FusionGridReduction5() { // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); - TORCH_CHECK(fusion->hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(1, bdimx); // tv1[I0, R1o, R1i{64}] = tv0[I0, I1] @@ -4007,45 +3561,31 @@ void testGPU_FusionGridReduction5() { int numel_x = bdimy; int numel_y = 6500; - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - bdimx, // tid_x - bdimy, // tid_y - 1, // tid_z - gdimx, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({input}); auto aten_output = input.sum({1}); - TORCH_CHECK(aten_output.allclose(cg_output)); + TORCH_CHECK(aten_output.allclose(outputs[0])); } // Similar to FusionGridReduction1 but with 3D tensors void testGPU_FusionGridReduction6() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(3); - fusion->addInput(tv0); + fusion.addInput(tv0); // tv1[I0, R1, R2] = tv0[I0, I1, I2] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); - TORCH_CHECK(fusion->hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); // Splitting for TID tv1->split(2, 128); @@ -4082,24 +3622,13 @@ void testGPU_FusionGridReduction6() { int numel_y = 200; int numel_z = numel_y; - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 128, // tid_x - 1, // tid_y - 1, // tid_z - 128, // gid_x - numel_x, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, 
numel_y, numel_z}, options); at::Tensor cg_output = at::empty({numel_x}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1, 2}); TORCH_CHECK(aten_output.allclose(cg_output)); @@ -4110,61 +3639,45 @@ void testGPU_FusionNonRedAxisBind() { int tid_x = 2; int red_dim = 0; - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); TensorView* tv1 = reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); tv1->split(-1, tid_x); tv1->axis(-2)->parallelize(ParallelType::BIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - tid_x, // tid_x - 1, // tid_y - 1, // tid_z - bid_x, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({16, bid_x * tid_x}, options); - at::Tensor cg_output = at::empty({bid_x * tid_x}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({input}); auto aten_output = input.sum({red_dim}); TORCH_CHECK( - aten_output.allclose(cg_output), + aten_output.allclose(outputs[0]), "Error of: ", - aten_output.sub(cg_output).abs().max()); + aten_output.sub(outputs[0]).abs().max()); } void testGPU_FusionSplitBCast() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* input_tv0 = makeDummyTensor(3); TensorView* input_tv1 = makeDummyTensor(3); - fusion->addInput(input_tv0); - fusion->addInput(input_tv1); + fusion.addInput(input_tv0); + fusion.addInput(input_tv1); TensorView* sum_tv2 = reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0); @@ -4192,35 +3705,24 @@ void testGPU_FusionSplitBCast() { bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx); output_tv4->axis(-1)->parallelize(ParallelType::TIDx); - fusion->addOutput(output_tv4); + fusion.addOutput(output_tv4); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 32, // tid_x - 1, // tid_y - 1, // tid_z - 32, // gid_x - 32, // gid_y - 1, // gid_z - 0 // shared_memory size - ); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({32, 32, 128}, options); at::Tensor t1 = at::randn({32, 32, 128}, options); at::Tensor cg_output = at::empty({32, 32, 128}, options); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {t0, t1}, {cg_output}); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({t0, t1}, {cg_output}); } void testGPU_FusionBCastInnerDim() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); 
TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); // reduce then broadcast auto tv1 = sum(tv0, {0}); @@ -4230,10 +3732,8 @@ void testGPU_FusionBCastInnerDim() { } void testGPU_FusionBCastReduce() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); @@ -4248,17 +3748,15 @@ void testGPU_FusionBCastReduce() { // Multiple consumer reduction with computeAt // https://github.com/csarofeen/pytorch/issues/110 void testGPU_FusionReductionMultiConsumer() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); auto tv1 = unaryOp(UnaryOpType::Exp, tv0); auto tv2 = reductionOp(BinaryOpType::Max, {-1}, new Float(0), tv1); auto tv3 = reductionOp(BinaryOpType::Min, {-1}, new Float(0), tv1); auto tv4 = add(tv2, tv3); - fusion->addOutput(tv4); + fusion.addOutput(tv4); tv1->computeAt(tv2, -1); TORCH_CHECK( @@ -4269,91 +3767,65 @@ void testGPU_FusionReductionMultiConsumer() { void testGPU_FusionComputeAtExprOrder() { { for (int i = 0; i < 2; ++i) { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(1); - fusion->addInput(tv0); + fusion.addInput(tv0); auto tv1 = add(tv0, new Float(1)); auto tv2 = add(tv0, new Float(1)); TensorView* tv3 = add(tv1, tv2); if (i == 0) { tv1->computeAt(tv3, -1); - fusion->addOutput(tv2); + fusion.addOutput(tv2); } else { tv2->computeAt(tv3, -1); - fusion->addOutput(tv1); + fusion.addOutput(tv1); } - fusion->addOutput(tv3); - - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 1, // tid_x - 1, // tid_y - 1, // tid_z - 1, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - - torch::jit::fuser::cuda::compileKernel(&prog); + fusion.addOutput(tv3); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({100}, options); - at::Tensor output2 = at::empty_like(input, options); - at::Tensor output3 = at::empty_like(input, options); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {output2, output3}); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({input}); + auto aten_output = (input + 1) * 2; TORCH_CHECK( - aten_output.allclose(output3), + aten_output.allclose(outputs[1]), "Error of: ", - aten_output.sub(output3).abs().max()); + aten_output.sub(outputs[1]).abs().max()); } } { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); auto tv1 = add(tv0, new Float(1)); auto tv2 = add(tv0, new Float(1)); TensorView* tv3 = add(tv1, tv2); - fusion->addOutput(tv3); + fusion.addOutput(tv3); tv3->split(-1, 32); tv1->computeAt(tv3, -1); tv2->computeAt(tv3, -2); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 1, // tid_x - 1, // tid_y - 1, // 
tid_z - 1, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - - torch::jit::fuser::cuda::compileKernel(&prog); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({100, 100}, options); at::Tensor output = at::empty_like(input, options); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {output}); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({input}, {output}); + auto aten_output = (input + 1) * 2; TORCH_CHECK( aten_output.allclose(output), @@ -4363,85 +3835,60 @@ void testGPU_FusionComputeAtExprOrder() { } void testGPU_FusionZeroDimComputeAt() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(1); - fusion->addInput(tv0); + fusion.addInput(tv0); auto tv1 = sum(tv0, {0}); auto tv2 = add(tv1, new Float(1)); - fusion->addOutput(tv2); + fusion.addOutput(tv2); TORCH_CHECK(tv2->nDims() == 0); tv1->computeAt(tv2, 0); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 1, // tid_x - 1, // tid_y - 1, // tid_z - 1, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - - torch::jit::fuser::cuda::compileKernel(&prog); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({100}, options); - at::Tensor output = at::empty({}, options); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {output}); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({input}); + auto aten_output = input.sum() + 1; TORCH_CHECK( - aten_output.allclose(output), + aten_output.allclose(outputs[0]), "Error of: ", - aten_output.sub(output).abs().max()); + aten_output.sub(outputs[0]).abs().max()); } void testGPU_FusionZeroDimBroadcast() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(0); - fusion->addInput(tv0); + fusion.addInput(tv0); auto tv1 = broadcast(tv0, {true, true}); TORCH_CHECK(tv1->nDims() == 2); TensorView* tv2 = makeDummyTensor(2); - fusion->addInput(tv2); + fusion.addInput(tv2); auto tv3 = add(tv1, tv2); auto tv4 = sum(tv3, {0, 1}); - fusion->addOutput(tv4); + fusion.addOutput(tv4); tv3->computeAt(tv4, -1); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - 1, // tid_x - 1, // tid_y - 1, // tid_z - 1, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - - torch::jit::fuser::cuda::compileKernel(&prog); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::rand({}, options); at::Tensor input2 = at::rand({10, 10}, options); at::Tensor output = at::empty({}, options); - torch::jit::fuser::cuda::runKernel(&prog, {input1, input2}, {output}); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({input1, input2}, {output}); + auto aten_output = (input1.unsqueeze(-1).unsqueeze(-1).expand({10, 10}) + input2).sum(); TORCH_CHECK( @@ -4451,19 +3898,17 @@ void testGPU_FusionZeroDimBroadcast() { } void testGPU_FusionZeroDimReduction() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); const 
int bdimx = 32; const int gdimx = 32; TensorView* tv0 = makeDummyTensor(1); - fusion->addInput(tv0); + fusion.addInput(tv0); auto tv1 = sum(tv0, {0}); - fusion->addOutput(tv1); + fusion.addOutput(tv1); tv1->split(0, bdimx); tv1->split(0, gdimx); @@ -4474,24 +3919,14 @@ void testGPU_FusionZeroDimReduction() { tv1->axis(-2)->parallelize(ParallelType::BIDx); tv2->axis(-2)->parallelize(ParallelType::BIDx); - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - bdimx, // tid_x - 1, // tid_y - 1, // tid_z - gdimx, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - - torch::jit::fuser::cuda::compileKernel(&prog); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({1000}, options); at::Tensor output = at::empty({}, options); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {output}); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({input}, {output}); + auto aten_output = input.sum(); TORCH_CHECK( aten_output.allclose(output), @@ -4500,16 +3935,13 @@ void testGPU_FusionZeroDimReduction() { } void testGPU_FusionBCastAfterReduce() { - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); - + Fusion fusion; + FusionGuard fg(&fusion); const int tidx = 128; // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); auto tv1 = sum(tv0, {1}); auto tv2 = broadcast(tv1, {false, true}); @@ -4518,10 +3950,10 @@ void testGPU_FusionBCastAfterReduce() { auto tv3 = tv1->rFactor({-2}); TensorView* tv4 = makeDummyTensor(2); - fusion->addInput(tv4); + fusion.addInput(tv4); auto tv5 = add(tv2, tv4); - fusion->addOutput(tv5); + fusion.addOutput(tv5); tv5->split(1, tidx); tv3->computeAt(tv5, 1); @@ -4540,27 +3972,15 @@ void testGPU_FusionBCastAfterReduce() { at::Tensor t0 = at::randn({x, y}, options); at::Tensor t4 = at::randn({x, y}, options); - at::Tensor cg_output = at::empty({x, y}, options); - - prog.setDevice(0); - setupLaunchConfig( - prog.fusion(), - tidx, // tid_x - 1, // tid_y - 1, // tid_z - x, // gid_x - 1, // gid_y - 1, // gid_z - 0 // shared_memory size - ); - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {t0, t4}, {cg_output}); + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t4}); auto t3 = t0.sum({1}).unsqueeze(-1).expand({x, y}); auto t5 = t3.add(t4); // Error is larger than the default threshold - TORCH_CHECK(t5.allclose(cg_output, 1e-5, 1e-5)); + TORCH_CHECK(t5.allclose(outputs[0], 1e-5, 1e-5)); } void testGPU_FusionReductionScheduler() { @@ -4606,6 +4026,57 @@ void testGPU_FusionReductionScheduler() { aten_output.sub(cg_output).abs().max()); } +// Simple reduction parallelized on a symbolic size. +void testGPU_FusionSymbolicReduction() { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeDummyTensor(2); + fusion.addInput(tv0); + + // tv1[I0, R1] = tv0[I0, I1] + TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); + fusion.addOutput(tv1); + + // Interface should just be a direct split with a Parallel type. We can + // include the parallelize call if we do this. 
+ tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); + // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1] + + TensorView* tv2 = tv1->rFactor({1}); + // tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}] = tv0[I0, I1] + // tv1[I0, R1oi{4}, R1i{BIDx}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}] + + // Incrementally, can print in between for debugging + tv0->computeAt(tv2, 1); + tv2->computeAt(tv1, 1); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + + int numel_x = 65000; + int numel_y = 1025; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::rand({numel_x, numel_y}, options); + + // How many threads to use for the block reduction + int runtime_threadIdx_dim = 128; + + torch::jit::fuser::cuda::FusionExecutor executor; + executor.compileFusion(&fusion); + auto outputs = executor.runFusion( + {input}, + torch::jit::fuser::cuda::LaunchParams( + -1, -1, -1, runtime_threadIdx_dim, -1, -1)); + + auto aten_output = input.sum({1}); + TORCH_CHECK(aten_output.allclose(outputs[0])); +} + void testGPU_FusionReductionSchedulerMultiDimNonFastest() { const std::vector red_dims = {0, 2}; // Copy is because CodeGen requires int and Pytorch requires int64_t diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 5d7d390136370b..c1fd06ad0a7024 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -180,7 +180,8 @@ namespace jit { _(GPU_FusionCacheBcast) \ _(GPU_FusionCacheComplex) \ _(GPU_FusionCacheMultiConsumer) \ - _(GPU_FusionConstCheck) + _(GPU_FusionConstCheck) \ + _(GPU_FusionSymbolicReduction) #else #define TH_FORALL_TESTS_CUDA(_) \ _(ArgumentSpec) \ diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py index 7f88b52da2fb2e..a72b1d21c80cee 100644 --- a/test/test_jit_cuda_fuser.py +++ b/test/test_jit_cuda_fuser.py @@ -143,6 +143,7 @@ def t(x: torch.Tensor, y: torch.Tensor, z: float): o = o + z return o t_jit = torch.jit.script(t) + ''' x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") y = torch.randn(32, 32, dtype=torch.float, device="cuda") jit_o = t_jit(x, y, 2.0) @@ -150,6 +151,21 @@ def t(x: torch.Tensor, y: torch.Tensor, z: float): o = t(x, y, 2.0) self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP) + x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") + y = torch.randn(1, 32, 32, dtype=torch.float, device="cuda") + jit_o = t_jit(x, y, 2.0) + jit_o = t_jit(x, y, 2.0) + o = t(x, y, 2.0) + self.assertEqual(o, jit_o) + self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP) + ''' + x = torch.randn(4, 1, 32, 32, dtype=torch.float, device="cuda") + y = torch.randn(8, 32, 32, dtype=torch.float, device="cuda") + jit_o = t_jit(x, y, 2.0) + jit_o = t_jit(x, y, 2.0) + o = t(x, y, 2.0) + self.assertEqual(o, jit_o) + self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP) @unittest.skipIf(True, "real broadcast with different output not supported yet") @unittest.skipIf(not RUN_CUDA, "requires CUDA") @@ -173,6 +189,7 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): # Currently cannot fuse this self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP) + @unittest.skipIf(True, "real broadcast with different output not supported yet") @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") 
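For reference, the call pattern the updated C++ tests converge on looks roughly like the sketch below: build and schedule a Fusion, compile it once with FusionExecutor, then run it either letting the executor allocate the outputs or passing preallocated ones, with LaunchParams acting only as optional constraints (-1 meaning "infer from the parallelized IterDomains"). This is an illustrative sketch, not part of the patch; the include path and the assumption that `fusion` is already scheduled (as in the symbolic-size reduction test above) are mine.

#include <torch/csrc/jit/codegen/cuda/executor.h>

// Sketch: drive an already-scheduled Fusion through the new executor API.
std::vector<at::Tensor> runFusionSketch(
    torch::jit::fuser::Fusion* fusion,
    at::Tensor input) {
  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(fusion);

  // Option 1: launch dimensions are inferred from the parallelized
  // IterDomains and the outputs are allocated by the executor.
  auto outputs = fe.runFusion({input});

  // Option 2: pass preallocated outputs and/or constrain selected launch
  // dimensions. LaunchParams is (gdimx, gdimy, gdimz, bdimx, bdimy, bdimz);
  // -1 leaves a dimension to be inferred, as in the symbolic reduction test.
  at::Tensor preallocated = at::empty_like(outputs[0]);
  fe.runFusion(
      {input},
      {preallocated},
      torch::jit::fuser::cuda::LaunchParams(-1, -1, -1, 128, -1, -1));

  return outputs;
}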
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index bb047cfdb20b6b..766cb9524a96e8 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -334,8 +334,11 @@ libtorch_cuda_sources = [ "torch/csrc/jit/codegen/cuda/compute_at.cpp", "torch/csrc/jit/codegen/cuda/dispatch.cpp", "torch/csrc/jit/codegen/cuda/expr_evaluator.cpp", + "torch/csrc/jit/codegen/cuda/executor.cpp", + "torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp", + "torch/csrc/jit/codegen/cuda/executor_launch_params.cpp", + "torch/csrc/jit/codegen/cuda/executor_utils.cpp", "torch/csrc/jit/codegen/cuda/fusion.cpp", - "torch/csrc/jit/codegen/cuda/scheduler.cpp", "torch/csrc/jit/codegen/cuda/graph_fuser.cpp", "torch/csrc/jit/codegen/cuda/index_compute.cpp", "torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp", @@ -354,18 +357,19 @@ libtorch_cuda_sources = [ "torch/csrc/jit/codegen/cuda/lower_validation.cpp", "torch/csrc/jit/codegen/cuda/lower2device.cpp", "torch/csrc/jit/codegen/cuda/manager.cpp", - "torch/csrc/jit/codegen/cuda/shape_inference.cpp", "torch/csrc/jit/codegen/cuda/mutator.cpp", "torch/csrc/jit/codegen/cuda/parser.cpp", "torch/csrc/jit/codegen/cuda/partition.cpp", "torch/csrc/jit/codegen/cuda/predicate_compute.cpp", + "torch/csrc/jit/codegen/cuda/register_interface.cpp", + "torch/csrc/jit/codegen/cuda/scheduler.cpp", + "torch/csrc/jit/codegen/cuda/shape_inference.cpp", "torch/csrc/jit/codegen/cuda/tensor_view.cpp", "torch/csrc/jit/codegen/cuda/transform_iter.cpp", "torch/csrc/jit/codegen/cuda/transform_replay.cpp", "torch/csrc/jit/codegen/cuda/transform_rfactor.cpp", "torch/csrc/jit/codegen/cuda/type.cpp", "torch/csrc/jit/codegen/cuda/utils.cpp", - "torch/csrc/jit/codegen/cuda/register_interface.cpp", "torch/csrc/jit/tensorexpr/cuda_codegen.cpp", ] diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index 80d1530c66cc5b..e19777eeafe7f4 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -48,10 +48,10 @@ TensorView* newOutputTV(const std::vector& vals, DataType dtype) { "Tried to create new output TensorView but received empty list."); std::vector out_domain( - tvs[0]->domain()->noReductions().size(), nullptr); + TensorDomain::noReductions(tvs[0]->getRootDomain()).size(), nullptr); for (auto tv : tvs) { - auto dom = tv->domain()->noReductions(); + auto dom = TensorDomain::noReductions(tv->getRootDomain()); TORCH_INTERNAL_ASSERT( dom.size() == out_domain.size(), "Invalid tensor view found while producing and output, it has ", @@ -66,25 +66,64 @@ TensorView* newOutputTV(const std::vector& vals, DataType dtype) { out_domain[i] = new IterDomain(dom[i]->start(), dom[i]->extent()); } } - - std::transform( - out_domain.begin(), - out_domain.end(), - out_domain.begin(), - [](IterDomain* dom) { - if (dom == nullptr) - return new IterDomain( - new Int(0), new Int(1), ParallelType::Serial, false, false, true); - return dom; - }); + for (size_t dim_i = 0; dim_i < out_domain.size(); dim_i++) { + if (out_domain[dim_i] == nullptr) { + BroadcastType bcast_type = BroadcastType::WithoutStride; + for (const auto tv : tvs) { + auto dim = TensorDomain::noReductions(tv->getRootDomain())[dim_i]; + // If there's an unresolved bcast dim and it came from a strided dim, + // assume output of it should be strided too + if (dim->getBroadcastType() == BroadcastType::WithStride) { + bcast_type = BroadcastType::WithStride; + break; + } + } + out_domain[dim_i] = new IterDomain( + new Int(0), + new Int(1), + ParallelType::Serial, + false, 
+ false, + bcast_type); + } + } return new TensorView(new TensorDomain(out_domain), dtype); } -Val* newOutputVal(const std::vector& vals) { - TORCH_INTERNAL_ASSERT( - !vals.empty(), "Cannot promote values if there aren't any."); +std::vector maybeBroadcast(const std::vector& vals) { + std::vector out_vals(vals.size(), nullptr); + size_t n_dims = 0; + for (auto val : vals) { + if (val->getValType().value() == ValType::TensorView) { + n_dims = std::max( + n_dims, + TensorDomain::noReductions(val->as()->getRootDomain()) + .size()); + } + } + for (size_t i = 0; i < vals.size(); i++) { + if (vals[i]->getValType().value() == ValType::TensorView) { + auto tv = vals[i]->as(); + size_t tv_dims = TensorDomain::noReductions(tv->getRootDomain()).size(); + if (tv_dims < n_dims) { + std::vector bcast_flags(n_dims, false); + for (size_t j = 0; j < n_dims - tv_dims; j++) { + bcast_flags[j] = true; + } + out_vals[i] = broadcast(tv, bcast_flags); + } else { + out_vals[i] = vals[i]; + } + } else { + out_vals[i] = vals[i]; + } + } + return out_vals; +} + +Val* newOutputVal(const std::vector& vals) { ValType out_vtype = vals[0]->getValType().value(); DataType out_dtype = vals[0]->getDataType().value(); @@ -180,10 +219,11 @@ TensorView* arithOpOverloads( T1* v1, T2* v2, T3* v3) { + auto vals = maybeBroadcast({v1, v2, v3}); return func( - v1->template as(), - v2->template as(), - v3->template as()) + vals[0]->template as(), + vals[1]->template as(), + vals[2]->template as()) ->template as(); } template @@ -193,17 +233,19 @@ TensorView* arithOpOverloads( T2* v2, T3* v3, T4* v4) { + auto vals = maybeBroadcast({v1, v2, v3, v4}); return func( - v1->template as(), - v2->template as(), - v3->template as(), - v4->template as()) + vals[0]->template as(), + vals[1]->template as(), + vals[2]->template as(), + vals[3]->template as()) ->template as(); } } // namespace -Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) { - Val* out = newOutputVal({v1, v2}); +TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) { + auto vals = maybeBroadcast({v1, v2}); + Val* out = newOutputVal({vals[0], vals[1]}); if (is_logical_op(type)) { if (out->getDataType().value() != DataType::Bool) out = newValLike(out, DataType::Bool); @@ -211,7 +253,8 @@ Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) { if (out->getDataType().value() != DataType::Int) out = newValLike(out, DataType::Int); } - new BinaryOp(type, out, v1, v2); + + new BinaryOp(type, out, vals[0], vals[1]); return out; } TensorView* binaryOp(BinaryOpType type, TensorView* v1, Val* v2) { @@ -390,7 +433,7 @@ static TensorView* newForReduction( ParallelType::Serial, isReduction, false, - id->isBroadcast())); + id->getBroadcastType())); } TensorDomain* td = new TensorDomain(new_domain); @@ -464,9 +507,10 @@ TensorView* broadcast( if (ent) n_broadcasts++; TORCH_CHECK( - nBCastDims - n_broadcasts == inp->domain()->noReductions().size(), + nBCastDims - n_broadcasts == + TensorDomain::noReductions(inp->getRootDomain()).size(), "Invalid broadcast, number of false entries in is_broadcast_dim expected to be ", - inp->domain()->noReductions().size(), + TensorDomain::noReductions(inp->getRootDomain()).size(), " but received ", nBCastDims - n_broadcasts); @@ -479,14 +523,20 @@ TensorView* broadcast( } std::vector out_domain; + auto inp_domain = TensorDomain::noReductions(inp->getRootDomain()); size_t iinp = 0, ibdim = 0; while (ibdim < is_broadcast_dim.size()) { if (is_broadcast_dim[ibdim]) { out_domain.push_back(new IterDomain( - new Int(0), new Int(1), ParallelType::Serial, 
false, false, true)); + new Int(0), + new Int(1), + ParallelType::Serial, + false, + false, + BroadcastType::WithoutStride)); } else { // Don't propagate reduction IDs through arith ops. - out_domain.push_back(inp->domain()->noReductions()[iinp]); + out_domain.push_back(inp_domain[iinp]); iinp++; } ibdim++; @@ -506,8 +556,9 @@ Val* add_alpha(Val* v1, Val* v2, Val* s) { "Alpha value should be a Scalar Valtype and not ", s->getValType().value()); - Val* intrm = binaryOp(BinaryOpType::Mul, v2, s); - return binaryOp(BinaryOpType::Add, v1, intrm); + auto vals = maybeBroadcast({v1, v2, s}); + Val* intrm = binaryOp(BinaryOpType::Mul, vals[1], vals[2]); + return binaryOp(BinaryOpType::Add, vals[0], intrm); } TensorView* add_alpha(TensorView* v1, Val* v2, Val* v3) { return arithOpOverloads(add_alpha, v1, v2, v3); @@ -525,8 +576,9 @@ Val* sub_alpha(Val* v1, Val* v2, Val* s) { "Alpha value should be a Scalar Valtype and not ", s->getValType().value()); - Val* intrm = binaryOp(BinaryOpType::Mul, v2, s); - return binaryOp(BinaryOpType::Sub, v1, intrm); + auto vals = maybeBroadcast({v1, v2, s}); + Val* intrm = binaryOp(BinaryOpType::Mul, vals[1], vals[2]); + return binaryOp(BinaryOpType::Sub, vals[0], intrm); } TensorView* sub_alpha(TensorView* v1, Val* v2, Val* v3) { return arithOpOverloads(sub_alpha, v1, v2, v3); @@ -538,10 +590,11 @@ TensorView* sub_alpha(TensorView* v1, TensorView* v2, Val* v3) { return arithOpOverloads(sub_alpha, v1, v2, v3); } // lerp -Val* lerp(Val* start, Val* end, Val* weight) { - Val* intrm1 = binaryOp(BinaryOpType::Sub, end, start); - Val* intrm2 = binaryOp(BinaryOpType::Mul, weight, intrm1); - return binaryOp(BinaryOpType::Add, start, intrm2); +TORCH_CUDA_API Val* lerp(Val* start, Val* end, Val* weight) { + auto vals = maybeBroadcast({start, end, weight}); + Val* intrm1 = binaryOp(BinaryOpType::Sub, vals[1], vals[0]); + Val* intrm2 = binaryOp(BinaryOpType::Mul, vals[2], intrm1); + return binaryOp(BinaryOpType::Add, vals[0], intrm2); } TensorView* lerp(TensorView* v1, Val* v2, Val* v3) { return arithOpOverloads(lerp, v1, v2, v3); @@ -571,9 +624,10 @@ Val* addcmul(Val* v1, Val* v2, Val* v3, Val* s) { "Alpha value should be a Scalar Valtype and not ", s->getValType().value()); - Val* intrm1 = binaryOp(BinaryOpType::Mul, v3, s); - Val* intrm2 = binaryOp(BinaryOpType::Mul, v2, intrm1); - return binaryOp(BinaryOpType::Add, v1, intrm2); + auto vals = maybeBroadcast({v1, v2, v3, s}); + Val* intrm1 = binaryOp(BinaryOpType::Mul, vals[2], vals[3]); + Val* intrm2 = binaryOp(BinaryOpType::Mul, vals[1], intrm1); + return binaryOp(BinaryOpType::Add, vals[0], intrm2); } TensorView* addcmul(TensorView* v1, Val* v2, Val* v3, Val* v4) { return arithOpOverloads(addcmul, v1, v2, v3, v4); @@ -605,8 +659,9 @@ Val* where(Val* c, Val* v1, Val* v2) { "Condition should be of DataType Bool, not ", c->getDataType().value()); - Val* out = newOutputVal({v1, v2}); - new TernaryOp(TernaryOpType::Where, out, c, v1, v2); + auto vals = maybeBroadcast({c, v1, v2}); + Val* out = newOutputVal({vals[1], vals[2]}); + new TernaryOp(TernaryOpType::Where, out, vals[0], vals[1], vals[2]); return out; } TensorView* where(TensorView* v1, Val* v2, Val* v3) { diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 197c5e68e7e35b..ab6e40271b4ae1 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -102,6 +102,9 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::ReductionOp: ptr(handler)->handle(expr->as()); return; 
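// Illustrative aside (not part of this diff): the rank-alignment rule that the
// new maybeBroadcast() helper in arith.cpp applies before building an op. A
// TensorView with fewer dimensions than the widest operand gets that many
// leading broadcast dimensions, so e.g. a [32, 32] operand combined with a
// [4, 8, 32, 32] operand is treated as [b, b, 32, 32]. A minimal standalone
// sketch of just the flag computation:
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<bool> leadingBroadcastFlags(size_t tv_dims, size_t n_dims) {
  assert(tv_dims <= n_dims);
  std::vector<bool> flags(n_dims, false);
  // Mark the missing leading dimensions as broadcast, NumPy-style.
  for (size_t j = 0; j < n_dims - tv_dims; j++) {
    flags[j] = true;
  }
  return flags;
}

int main() {
  for (bool f : leadingBroadcastFlags(2, 4)) {
    std::cout << (f ? "b " : "i ");
  }
  std::cout << "\n"; // prints: b b i i
  return 0;
}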
+ case ExprType::GridReduction: + ptr(handler)->handle(expr->as()); + return; case ExprType::BroadcastOp: ptr(handler)->handle(expr->as()); return; @@ -192,6 +195,9 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::ReductionOp: ptr(handler)->handle(expr->as()); return; + case ExprType::GridReduction: + ptr(handler)->handle(expr->as()); + return; case ExprType::BroadcastOp: ptr(handler)->handle(expr->as()); return; @@ -278,6 +284,8 @@ Statement* Expr::mutatorDispatch(T mutator, Expr* expr) { return ptr(mutator)->mutate(expr->as()); case ExprType::ReductionOp: return ptr(mutator)->mutate(expr->as()); + case ExprType::GridReduction: + return ptr(mutator)->mutate(expr->as()); case ExprType::BroadcastOp: return ptr(mutator)->mutate(expr->as()); case ExprType::ForLoop: diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index 50f4451d44ac44..62f70278992836 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -73,6 +73,7 @@ class UnaryOp; class BinaryOp; class TernaryOp; class ReductionOp; +class GridReduction; class BroadcastOp; class ForLoop; class IfThenElse; @@ -116,6 +117,7 @@ class TORCH_CUDA_API OptOutConstDispatch { virtual void handle(const BinaryOp*) {} virtual void handle(const TernaryOp*) {} virtual void handle(const ReductionOp*) {} + virtual void handle(const GridReduction*) {} virtual void handle(const BroadcastOp*) {} virtual void handle(const ForLoop*) {} virtual void handle(const IfThenElse*) {} @@ -156,6 +158,7 @@ class TORCH_CUDA_API OptOutDispatch { virtual void handle(BinaryOp*) {} virtual void handle(TernaryOp*) {} virtual void handle(ReductionOp*) {} + virtual void handle(GridReduction*) {} virtual void handle(BroadcastOp*) {} virtual void handle(ForLoop*) {} virtual void handle(IfThenElse*) {} @@ -226,6 +229,9 @@ class TORCH_CUDA_API OptInConstDispatch { virtual void handle(const ReductionOp*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ReductionOp."); } + virtual void handle(const GridReduction*) { + TORCH_INTERNAL_ASSERT(false, "Handle not overriden for GridReduction."); + } virtual void handle(const BroadcastOp*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp."); } @@ -304,6 +310,9 @@ class TORCH_CUDA_API OptInDispatch { virtual void handle(ReductionOp*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ReductionOp."); } + virtual void handle(GridReduction*) { + TORCH_INTERNAL_ASSERT(false, "Handle not overriden for GridReduction."); + } virtual void handle(BroadcastOp*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp."); } @@ -375,6 +384,7 @@ class TORCH_CUDA_API OptOutMutator { virtual Statement* mutate(BinaryOp*); virtual Statement* mutate(TernaryOp*); virtual Statement* mutate(ReductionOp*); + virtual Statement* mutate(GridReduction*); virtual Statement* mutate(BroadcastOp*); virtual Statement* mutate(ForLoop*); virtual Statement* mutate(IfThenElse*); @@ -452,6 +462,9 @@ class TORCH_CUDA_API OptInMutator { virtual Statement* mutate(ReductionOp*) { TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ReductionOp."); } + virtual Statement* mutate(GridReduction*) { + TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for GridReduction."); + } virtual Statement* mutate(BroadcastOp*) { TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for BroadcastOp."); } diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp new file mode 100644 index 
00000000000000..eb2aff9a818110 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -0,0 +1,322 @@ +#include +#include +#include +#include + +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +int FusionExecutor::fusion_id_counter = 0; + +std::string FusionExecutor::getStructuredCode(const std::string& kernel) { + // generating cuda code; + std::string code = std::string("namespace ") + FusionExecutor::Namespace() + + " {\n" + executor_utils::kernelPreamble() + kernel + "}\n"; + + const char* debug_env = getenv("PYTORCH_CUDA_FUSER_DEBUG"); + if (debug_env && atoi(debug_env)) { + std::cout << "\n==== codegen output for kernel: " << KernelName() + << " ====" << std::endl + << code << std::endl + << "=====*===============================" << std::endl; + } + + return code; +} + +void FusionExecutor::compileFusion(Fusion* fusion) { + TORCH_INTERNAL_ASSERT( + !fusion->outputs().empty(), "No output found for this kernel, aborting."); + + for (auto out : fusion->outputs()) + TORCH_INTERNAL_ASSERT( + out->getValType() == ValType::TensorView, + "Output types from fusions that are not tensors are not supported at this point."); + + fusion_ = *fusion; + FusionGuard fg(&fusion_); + + fusion_id = ++fusion_id_counter; + has_random = fusion->hasRNG(); + lowered = GPULower(&fusion_); + auto kernel = lowered.getKernel(KernelName()); + auto structured_code = getStructuredCode(lowered.getKernel(KernelName())); + + compiled_kernel = executor_utils::nvrtcCompile( + structured_code, (Namespace() + "::" + KernelName()).c_str(), fusion_id); +} + +namespace { + +// Check if a value is already bound, if so validate we're trying to bind to the +// same value +void safeBind( + EvaluationContext& ec, + const Val* value, + Int::ScalarType concrete_value) { + auto already_concrete_val = ec.concreteValue(value); + + if (already_concrete_val.has_value()) { + TORCH_INTERNAL_ASSERT( + concrete_value == already_concrete_val.value(), + "Tried to bind ", + value, + " to ", + " concrete value, but it's already set to ", + already_concrete_val.value()); + } else { + ec.bind(value, concrete_value); + } +} + +EvaluationContext bindInputs( + const at::ArrayRef& aten_inputs, + Fusion* fusion) { + TORCH_INTERNAL_ASSERT( + fusion->inputs().size() == aten_inputs.size(), + "Something went wrong configuring launch. Inputs no longer match."); + + auto fusion_inputs = fusion->inputs(); + EvaluationContext ec(fusion); + + // This should probably move to EvaluationContext as we may want to bind + // input values frequently. Bind fusion input values to runtime values. + for (size_t i = 0; i < fusion->inputs().size(); i++) { + if (fusion->inputs()[i]->getValType() == ValType::TensorView) { + TensorView* cg_tensor = fusion->inputs()[i]->as(); + + TORCH_INTERNAL_ASSERT( + aten_inputs[i].isTensor(), + "Something went wrong configuring launch. Inputs no longer match."); + + auto aten_tensor = aten_inputs[i].toTensor(); + auto root_dom = TensorDomain::noReductions(cg_tensor->getRootDomain()); + TORCH_INTERNAL_ASSERT( + aten_tensor.ndimension() == root_dom.size(), + "Something went wrong configuring launch. 
Inputs no longer match."); + + for (size_t dim = 0; dim < root_dom.size(); dim++) { + safeBind(ec, root_dom[dim]->extent(), aten_tensor.sizes()[dim]); + } + } + } + return std::move(ec); +} + +at::Tensor inferAndAlloc( + TensorView* tv, + EvaluationContext& ec, + const CompileOptions& options, + bool zero_init = false) { + std::vector sizes; + for (auto id : TensorDomain::noReductions(tv->getRootDomain())) { + auto infered_val = ExpressionEvaluator::evaluate(id->rawExtent(), &ec); + TORCH_INTERNAL_ASSERT( + infered_val.has_value(), + "Could not launch kernel as program could not infer ", + id->rawExtent(), + " for the buffer ", + tv); + sizes.push_back(infered_val.value()); + } + + auto at_type = data_type_to_aten(tv->getDataType().value()); + auto tensor_options = + at::TensorOptions().dtype(at_type).device(options.device); + + if (zero_init) { + c10::IntArrayRef isizes(sizes); + return at::zeros(isizes, tensor_options); + } else { + c10::IntArrayRef isizes(sizes); + return at::empty(isizes, tensor_options); + } +} +} // namespace + +LaunchParams FusionExecutor::computeLaunchParams( + const at::ArrayRef& aten_inputs, + const LaunchParams& launch_constraints, + EvaluationContext& ec) { + LaunchParams launch_params; + + // Grab all values that are actually used in the fusion + auto unordered_vals = DependencyCheck::getAllValsBetween( + {fusion_.inputs().begin(), fusion_.inputs().end()}, fusion_.outputs()); + + // Let's collect all IterDomains that are bound to a thread binding + std::unordered_map, TypeHash> + parallel_iter_domains; + + for (auto val : unordered_vals) { + if (val->getValType().value() == ValType::TensorView) { + TensorView* tv = val->as(); + for (auto id : tv->domain()->domain()) { + if (id->isThread() && !id->isBroadcast()) { + if (parallel_iter_domains.find(id->parallel_method()) != + parallel_iter_domains.end()) { + parallel_iter_domains.at(id->parallel_method()).push_back(id); + } else { + parallel_iter_domains[id->parallel_method()] = + std::vector({id}); + } + } + } + } + } + + // If any dimension was set in launch constraints we need to run through + // IterDomains that have been parallelized, and bind those values. Or make + // sure that, if they could be inferred, the inference matches what was set. + if (launch_constraints.nBlocks() * launch_constraints.nThreads() != -1) { + for (auto& entry : parallel_iter_domains) { + auto p_type = entry.first; + if (launch_constraints.hasDim(p_type)) { + auto parallel_ids = entry.second; + for (auto parallel_id : parallel_ids) { + auto infered_val = + ExpressionEvaluator::evaluate(parallel_id->rawExtent(), &ec); + if (infered_val.has_value()) { + // This value could have been inferred, make sure it was set right.
+ TORCH_CHECK( + infered_val.value() == launch_constraints.getDim(p_type) || + launch_constraints.getRawVal(p_type) == -1, + "Infered that ", + p_type, + " should be set to ", + infered_val.value(), + " but launch constraints specified ", + launch_constraints.getDim(p_type)); + } else { + // Bind the launch constraint into our evaluation context + safeBind( + ec, + parallel_id->rawExtent(), + launch_constraints.getDim(entry.first)); + launch_params.bind(launch_constraints.getDim(p_type), p_type); + } + } + } + } + } + + // Run through the rest of the parallel IterDomains and infer their size + for (auto& entry : parallel_iter_domains) { + auto p_type = entry.first; + auto parallel_ids = entry.second; + for (auto parallel_id : parallel_ids) { + auto val = ExpressionEvaluator::evaluate(parallel_id->rawExtent(), &ec); + TORCH_INTERNAL_ASSERT( + val, + "Tried to evaluate the extent of ", + parallel_id, + " to set launch bounds but could not."); + launch_params.bind(val.value(), p_type); + } + } + + return launch_params; +} + +std::vector FusionExecutor::allocGlobalVals(EvaluationContext& ec) { + std::vector global_buffers; + for (auto alloc : lowered.global_allocations()) { + TORCH_INTERNAL_ASSERT( + alloc->buffer()->getValType() == ValType::TensorView, + "Cannot allocate global buffers that are not tensors."); + global_buffers.push_back( + inferAndAlloc(alloc->buffer()->as(), ec, options_, false)); + } + + for (auto alloc : lowered.sync_allocations()) { + TORCH_INTERNAL_ASSERT( + alloc->buffer()->getValType() == ValType::TensorView, + "Cannot allocate global buffers that are not tensors."); + global_buffers.push_back( + inferAndAlloc(alloc->buffer()->as(), ec, options_, true)); + } + + return global_buffers; +} + +std::vector FusionExecutor::allocOutputs(EvaluationContext& ec) { + std::vector outputs; + for (auto output : fusion_.outputs()) { + TORCH_INTERNAL_ASSERT( + output->getValType() == ValType::TensorView, + "Cannot allocate outputs that are not tensors."); + outputs.push_back( + inferAndAlloc(output->as(), ec, options_, false)); + } + return outputs; +} + +std::vector FusionExecutor::runFusion( + const at::ArrayRef& inputs, + const std::vector& outputs, + const LaunchParams& launch_constraints) { + TORCH_INTERNAL_ASSERT( + fusion_id > 0, "Cannot run fusion, it was not compiled."); + + FusionGuard fg(&fusion_); + + executor_utils::validateKernelInputs(&fusion_, inputs, options_.device); + + const auto prior_device = at::cuda::current_device(); + c10::DeviceGuard dg(options_.device); + auto stream = at::cuda::getCurrentCUDAStream(); + + EvaluationContext evaluation_context = bindInputs(inputs, &fusion_); + + LaunchParams launch_params = + computeLaunchParams(inputs, launch_constraints, evaluation_context); + + std::vector alloced_outputs = outputs; + if (outputs.empty() || outputs.size() != fusion_.outputs().size()) { + alloced_outputs = allocOutputs(evaluation_context); + } + + executor_utils::validateKernelOutputs( + &fusion_, alloced_outputs, options_.device); + + KernelArgumentHolder kernel_arguments; + kernel_arguments.push(inputs); + kernel_arguments.push(alloced_outputs); + auto buffers = allocGlobalVals(evaluation_context); + kernel_arguments.push(buffers); + + if (has_random) { + const auto rand_offset = 4 * + (std::ceil( + alloced_outputs[0].numel() / (4.0 * 128 * launch_params.gdimx())) + + 1); + kernel_arguments.appendPhiloxRNGSeed(rand_offset); + } + + AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLaunchKernel( + compiled_kernel.function, + launch_params.gdimx(), + 
launch_params.gdimy(), + launch_params.gdimz(), + launch_params.bdimx(), + launch_params.bdimy(), + launch_params.bdimz(), + 0, // smem + stream, + kernel_arguments.getBuffer(), + nullptr)); + AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + + return alloced_outputs; +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/executor.h b/torch/csrc/jit/codegen/cuda/executor.h new file mode 100644 index 00000000000000..f070dafa56308f --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/executor.h @@ -0,0 +1,83 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +struct TORCH_CUDA_API CompileOptions { + c10::Device device = c10::Device(c10::DeviceType::CUDA, 0); +}; + +class TORCH_CUDA_API FusionExecutor { + public: + FusionExecutor() {} + FusionExecutor(CompileOptions options) : options_(options) {} + + void compileFusion(Fusion* fusion); + + std::vector runFusion( + const at::ArrayRef& inputs, + const std::vector& outputs, + const LaunchParams& launch_constraints = LaunchParams()); + + std::vector runFusion( + const at::ArrayRef& inputs, + const LaunchParams& launch_constraints = LaunchParams()) { + return runFusion(inputs, {}, launch_constraints); + } + + private: + std::string KernelName() const { + std::stringstream ss; + ss << "kernel" << fusion_id; + return ss.str(); + } + + static std::string Namespace() { + return "CudaCodeGen"; + } + + // Add preamble and wrap in namespace + std::string getStructuredCode(const std::string& kernel); + + LaunchParams computeLaunchParams( + const at::ArrayRef& aten_inputs, + const LaunchParams& launch_constraints, + EvaluationContext& ec); + + std::vector allocGlobalVals(EvaluationContext& ec); + + std::vector allocOutputs(EvaluationContext& ec); + + Fusion fusion_; + + CompileOptions options_; + + executor_utils::NvrtcFunction compiled_kernel; + + // State of the fusion that's important + bool has_random = false; + + // Counter to be used for kernel name. 
+ int fusion_id = -1; + static int fusion_id_counter; + + GPULower lowered; +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp new file mode 100644 index 00000000000000..dd75d8c7bf64d0 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp @@ -0,0 +1,144 @@ +#include + +// Extract size and strides +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +std::unique_ptr getTensorArg( + c10::ScalarType dtype, + int nDims) { + switch (dtype) { + case c10::ScalarType::Float: + return getTensorArg(nDims); + case c10::ScalarType::Half: + return getTensorArg(nDims); + case c10::ScalarType::Bool: + return getTensorArg(nDims); + case c10::ScalarType::Long: + return getTensorArg(nDims); + default: + TORCH_CHECK( + false, + "Dtype: ", + dtype, + " not currently supported in code generated kernels."); + } +} + +// Push a tensor to the arguments +void KernelArgumentHolder::push(const at::Tensor& tensor) { + changed_ = true; + int nDims = tensor.ndimension(); + + c10::ScalarType dtype = tensor.scalar_type(); + std::unique_ptr tensor_arg = getTensorArg(dtype, nDims); + tensor_arg->setPointer(tensor.data_ptr()); + for (int i = 0; i < nDims; i++) { + tensor_arg->setSize(i, tensor.sizes()[i]); + tensor_arg->setStride(i, tensor.strides()[i]); + } + arguments_.push_back(std::move(tensor_arg)); +} + +// Push a tensor to the arguments +void KernelArgumentHolder::push( + const at::Tensor& val, + c10::optional broadcasted_size) { + changed_ = true; + ExtractSizeStride ess(val, std::move(broadcasted_size)); + int nDims = ess.sizes.size(); + + c10::ScalarType dtype = val.scalar_type(); + std::unique_ptr tensor_arg = getTensorArg(dtype, nDims); + tensor_arg->setPointer(val.data_ptr()); + for (int i = 0; i < nDims; i++) { + tensor_arg->setSize(i, ess.sizes[i]); + tensor_arg->setStride(i, ess.strides[i]); + } + arguments_.push_back(std::move(tensor_arg)); +} + +// Push a scalar or integer to the arguments +void KernelArgumentHolder::push(const IValue& val) { + changed_ = true; + TORCH_INTERNAL_ASSERT( + val.isScalar(), + "Tried to push an arg to run in a fused kernel, expected a scalar but got, ", + val); + switch (val.toScalar().type()) { + case c10::ScalarType::Double: + arguments_.push_back(std::make_unique((float)val.toDouble())); + return; + case c10::ScalarType::Long: + arguments_.push_back(std::make_unique((int)val.toInt())); + return; + default: + TORCH_INTERNAL_ASSERT( + false, + " Tried to create argument to send to a fused kernel, but got an unexpected type."); + } + TORCH_INTERNAL_ASSERT( + false, + " Tried to create argument to send to a fused kernel, but got a non-scalar type."); +} + +void KernelArgumentHolder::push(const uint64_t& val) { + arguments_.push_back(std::make_unique(val)); +} + +// Create buffer, flatten arguments into it, align by 8 Bytes, return pointers +// in the buffer +void** KernelArgumentHolder::getBuffer() { + if (changed_) { + void_ptrs_ = std::vector(arguments_.size(), nullptr); + for (size_t i = 0; i < arguments_.size(); i++) { + void_ptrs_[i] = static_cast(arguments_[i]->arg()); + } + changed_ = false; + } + return void_ptrs_.data(); +} + +void KernelArgumentHolder::push(const c10::ArrayRef& args) { + // Naive I/O setup, I'm ignoring all the potential transformation (i.e. 
I/O + // allocated here from the subgraph could be, and very likely are, different + // from I/O expected by the generated CUDA kernel. + for (const auto& arg : args) { + if (arg.isTensor()) { + push(arg.toTensor()); + } else { + push(arg); + } + } +} + +void KernelArgumentHolder::push(const std::vector& tensors) { + for (const auto& tensor : tensors) { + push(tensor); + } +} + +void KernelArgumentHolder::appendPhiloxRNGSeed(uint64_t rand_offset) { + std::pair philox_engine_inputs; + auto gen = at::cuda::detail::getDefaultCUDAGenerator(); + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen.mutex()); + philox_engine_inputs = + at::check_generator(gen)->philox_engine_inputs( + rand_offset); + } + push(philox_engine_inputs.first); + push(philox_engine_inputs.second); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_arg.h b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h similarity index 59% rename from torch/csrc/jit/codegen/cuda/kernel_arg.h rename to torch/csrc/jit/codegen/cuda/executor_kernel_arg.h index 984afb063a3ad6..4c6ab9f176521f 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_arg.h +++ b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h @@ -1,6 +1,8 @@ #pragma once -#include +#include +#include +#include namespace torch { namespace jit { @@ -103,26 +105,26 @@ struct TensorArg : public TensorArgAbstract { }; template -TensorArgAbstract* getTensorArg(int nDims) { +std::unique_ptr getTensorArg(int nDims) { switch (nDims) { case (0): - return new TensorArg>(); + return std::make_unique>>(); case (1): - return new TensorArg>(); + return std::make_unique>>(); case (2): - return new TensorArg>(); + return std::make_unique>>(); case (3): - return new TensorArg>(); + return std::make_unique>>(); case (4): - return new TensorArg>(); + return std::make_unique>>(); case (5): - return new TensorArg>(); + return std::make_unique>>(); case (6): - return new TensorArg>(); + return std::make_unique>>(); case (7): - return new TensorArg>(); + return std::make_unique>>(); case (8): - return new TensorArg>(); + return std::make_unique>>(); default: TORCH_INTERNAL_ASSERT( false, @@ -132,22 +134,41 @@ TensorArgAbstract* getTensorArg(int nDims) { } } -TensorArgAbstract* getTensorArg(c10::ScalarType dtype, int nDims) { - switch (dtype) { - case (at::kFloat): - return getTensorArg(nDims); - case (at::kHalf): - return getTensorArg(nDims); - case (at::kBool): - return getTensorArg(nDims); - default: - TORCH_CHECK( - false, - "Dtype: ", - dtype, - " not currently supported in code generated kernels."); - } -} +std::unique_ptr getTensorArg( + c10::ScalarType dtype, + int nDims); + +class KernelArgumentHolder { + public: + // Push a tensor to the arguments + void push(const at::Tensor& tensor); + + // We want to get rid of this version, it's a hack for now because we don't + // have great broadcast support for translation. 
+ void push( + const at::Tensor& tensor, + c10::optional broadcasted_size); + + // Push a scalar or integer to the arguments + void push(const IValue& val); + + void push(const uint64_t& val); + + // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers + // in the buffer + void** getBuffer(); + + void push(const c10::ArrayRef& args); + + void push(const std::vector& tensors); + + void appendPhiloxRNGSeed(uint64_t rand_offset); + + private: + std::vector> arguments_; + std::vector void_ptrs_; + bool changed_ = true; +}; } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/executor_launch_params.cpp b/torch/csrc/jit/codegen/cuda/executor_launch_params.cpp new file mode 100644 index 00000000000000..17f96f34170382 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/executor_launch_params.cpp @@ -0,0 +1,87 @@ +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +void LaunchParams::bind(int64_t val, ParallelType p_type) { + switch (p_type) { + case ParallelType::TIDx: + checkAndSet(val, bdimx_, "blockDim.x"); + break; + case ParallelType::BIDx: + checkAndSet(val, gdimx_, "gridDim.x"); + break; + case ParallelType::TIDy: + checkAndSet(val, bdimy_, "blockDim.y"); + break; + case ParallelType::BIDy: + checkAndSet(val, gdimy_, "gridDim.y"); + break; + case ParallelType::TIDz: + checkAndSet(val, bdimz_, "blockdim.z"); + break; + case ParallelType::BIDz: + checkAndSet(val, gdimz_, "gridDim.z"); + break; + default: + TORCH_INTERNAL_ASSERT( + false, + "Tried to bind invalid parallel type in launch config: ", + p_type); + } +} + +int64_t LaunchParams::getDim(ParallelType p_type) const { + switch (p_type) { + case ParallelType::TIDx: + return bdimx(); + case ParallelType::BIDx: + return gdimx(); + case ParallelType::TIDy: + return bdimy(); + case ParallelType::BIDy: + return gdimy(); + case ParallelType::TIDz: + return bdimz(); + case ParallelType::BIDz: + return gdimz(); + default: + TORCH_INTERNAL_ASSERT( + false, + "Tried to get with invalid parallel type in launch config: ", + p_type); + } +} + +bool LaunchParams::hasDim(ParallelType p_type) const { + return getRawVal(p_type) != UNINITIALIZED_VAL; +} + +const int64_t& LaunchParams::getRawVal(ParallelType p_type) const { + switch (p_type) { + case ParallelType::TIDx: + return bdimx_; + case ParallelType::BIDx: + return gdimx_; + case ParallelType::TIDy: + return bdimy_; + case ParallelType::BIDy: + return gdimy_; + case ParallelType::TIDz: + return bdimz_; + case ParallelType::BIDz: + return gdimz_; + default: + TORCH_INTERNAL_ASSERT( + false, + "Tried to get with invalid parallel type in launch config: ", + p_type); + } +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/executor_launch_params.h b/torch/csrc/jit/codegen/cuda/executor_launch_params.h new file mode 100644 index 00000000000000..8bdc316432f095 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/executor_launch_params.h @@ -0,0 +1,117 @@ +#pragma once +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +class TORCH_CUDA_API LaunchParams { + static constexpr int64_t UNINITIALIZED_VAL = -1; + + public: + LaunchParams( + int64_t gdimx = UNINITIALIZED_VAL, + int64_t gdimy = UNINITIALIZED_VAL, + int64_t gdimz = UNINITIALIZED_VAL, + int64_t bdimx = UNINITIALIZED_VAL, + int64_t bdimy = UNINITIALIZED_VAL, + int64_t bdimz = UNINITIALIZED_VAL) + : gdimx_(gdimx), + gdimy_(gdimy), + 
gdimz_(gdimz), + bdimx_(bdimx), + bdimy_(bdimy), + bdimz_(bdimz) {} + + int64_t smem() const { + return smem_; + } + int64_t nBlocks() const { + return gdimx_ * gdimy_ * gdimz_; + } + + int64_t nThreads() const { + return bdimx_ * bdimy_ * bdimz_; + } + + int64_t bdimx() const { + return static_cast(bdimx_ == UNINITIALIZED_VAL ? 1 : bdimx_); + } + + int64_t gdimx() const { + return static_cast(gdimx_ == UNINITIALIZED_VAL ? 1 : gdimx_); + } + + int64_t bdimy() const { + return static_cast(bdimy_ == UNINITIALIZED_VAL ? 1 : bdimy_); + } + + int64_t gdimy() const { + return static_cast(gdimy_ == UNINITIALIZED_VAL ? 1 : gdimy_); + } + + int64_t bdimz() const { + return static_cast(bdimz_ == UNINITIALIZED_VAL ? 1 : bdimz_); + } + + int64_t gdimz() const { + return static_cast(gdimz_ == UNINITIALIZED_VAL ? 1 : gdimz_); + } + + void checkAndSet( + const int64_t incoming_val, + int64_t& class_val, + std::string val) { + TORCH_INTERNAL_ASSERT( + class_val == UNINITIALIZED_VAL || incoming_val == class_val, + "Tried to set ", + val, + " to ", + incoming_val, + ", but it was already set and new value does not match.", + " Thread dims all have to be bound to the same value."); + TORCH_CHECK( + incoming_val > 0, + "Received a thread binding on ", + val, + " that is ", + incoming_val, + ". Cannot create negative threads."); + if (class_val == UNINITIALIZED_VAL) { + class_val = incoming_val; + } + } + + // Binds dim assocaited with p_type to val + void bind(int64_t val, ParallelType p_type); + + // Adjusted value based on get functions above for each value + int64_t getDim(ParallelType p_type) const; + + // Returns raw value which may be UNINITIALIZED_VAL + const int64_t& getRawVal(ParallelType p_type) const; + + // Returns false if value associated with p_type == UNINITIALIZED_VAL + bool hasDim(ParallelType p_type) const; + + private: + // Spell them out because I want signed ints to know if they were initialized + // or not. + // TODO: convert to c10::optional + int64_t gdimx_ = UNINITIALIZED_VAL; + int64_t gdimy_ = UNINITIALIZED_VAL; + int64_t gdimz_ = UNINITIALIZED_VAL; + int64_t bdimx_ = UNINITIALIZED_VAL; + int64_t bdimy_ = UNINITIALIZED_VAL; + int64_t bdimz_ = UNINITIALIZED_VAL; + + int64_t smem_ = 0; + + // TODO: Fill in output sizes + std::vector> output_sizes; +}; +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp new file mode 100644 index 00000000000000..11964869216fee --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -0,0 +1,306 @@ +#include +#include + +#include + +#include + +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { +namespace executor_utils { + +std::string kernelPreamble() { + std::stringstream ss; + ss << code_template_tensor_struct << "\n" + << code_fp16_support << "\n" + << code_random_number_gen << "\n" + << code_helper_funcs << "\n" + << code_template_block_reduction << "\n" + << code_template_grid_reduction << "\n" + << code_template_block_broadcast << "\n"; + return ss.str(); +} + +bool validateKernelArgTensor( + const at::Tensor& arg, + const Val* param, + c10::Device device, + std::stringstream& msg) { + // Arg is a tensor. Param must be a tensor too. + if (*param->getValType() != ValType::TensorView) { + msg << "Argument is a tensor, but the parameter is not."; + return false; + } + + // Check the rank of the tensors. 
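As an aside (illustrative only, not part of the patch), the contract that the LaunchParams class defined just above enforces is that repeated bindings of the same parallel dimension must agree, and dimensions that were never bound read back as 1 through the accessors:

// Sketch only; values are hypothetical.
LaunchParams lp;
lp.bind(128, ParallelType::TIDx);    // blockDim.x = 128
lp.bind(16, ParallelType::BIDx);     // gridDim.x  = 16
lp.bind(128, ParallelType::TIDx);    // fine: same value re-bound
// lp.bind(64, ParallelType::TIDx);  // would trip the checkAndSet assert
int64_t bdx = lp.bdimx();            // 128
int64_t gdy = lp.gdimy();            // 1, since BIDy was never bound
bool has_tidy = lp.hasDim(ParallelType::TIDy);  // false
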
+ size_t arg_dim = arg.dim(); + // Note: This requires current Fusion to be active. + size_t param_dim = TensorDomain::noReductions( + static_cast(param)->getRootDomain()) + .size(); + // see [Note - broadcast support in integration] + // Because of broadcasting support handled in integration, we relax the rank + // check as necessary. + if (arg_dim > param_dim) { + msg << "Argument tensor's rank is " << arg_dim << ", but the parameter is " + << param_dim; + return false; + } + + if (arg.device() != device) { + msg << "Argument is on device that is not compiled for"; + return false; + } + // Check element type + at::ScalarType arg_data_type = arg.scalar_type(); + DataType param_data_type = *param->getDataType(); + bool match = false; + switch (arg_data_type) { + case at::ScalarType::Half: + match = param_data_type == DataType::Half; + break; + case at::ScalarType::Float: + match = param_data_type == DataType::Float; + break; + case at::ScalarType::Bool: + match = param_data_type == DataType::Bool; + break; + default: + msg << "Argument element type, " << arg_data_type + << ", is not supported."; + return false; + } + if (!match) + msg << "Argument element type is " << arg_data_type + << ", but the parameter is " << param_data_type; + return match; +} + +bool validateKernelArgScalar( + const c10::TypePtr& arg_type, + const Val* param, + std::stringstream& msg) { + if (!param->isScalar()) { + msg << "Argument is a scalar, but the parameter is not."; + return false; + } + DataType param_type = *param->getDataType(); + bool match = false; + switch (arg_type->kind()) { + case c10::TypeKind::IntType: + match = param_type == DataType::Int; + break; + case c10::TypeKind::FloatType: + match = param_type == DataType::Float; + break; + case c10::TypeKind::BoolType: + match = param_type == DataType::Bool; + break; + default: + match = false; + } + if (!match) { + msg << "Argument type is " << *arg_type << ", but the parameter is " + << param_type; + } + return match; +} + +bool validateKernelArg( + const c10::IValue& arg, + const Val* param, + c10::Device device, + std::stringstream& msg) { + if (arg.type()->kind() != c10::TypeKind::TensorType) { + return validateKernelArgScalar(arg.type(), param, msg); + } else { + return validateKernelArgTensor(arg.toTensor(), param, device, msg); + } +} + +void validateKernelInputs( + Fusion* fusion, + const at::ArrayRef& inputs, + c10::Device device) { + // This is necessary as we were traversing the fusion graph later in the check + FusionGuard fg(fusion); + // Check inputs + TORCH_INTERNAL_ASSERT( + inputs.size() == fusion->inputs().size(), + "Wrong number of kernel inputs."); + for (size_t i = 0; i < inputs.size(); ++i) { + const IValue& arg = inputs[i]; + const Val* param = fusion->inputs()[i]; + std::stringstream msg; + TORCH_INTERNAL_ASSERT( + validateKernelArg(arg, param, device, msg), + "Input argument at position ", + i, + " is invalid; ", + msg.str()); + } +} + +void validateKernelOutputs( + Fusion* fusion, + const std::vector& outputs, + c10::Device device) { + TORCH_INTERNAL_ASSERT( + fusion->outputs().size() != 0, + "Kernel should have at least one output tensor."); + + TORCH_INTERNAL_ASSERT( + outputs.size() == fusion->outputs().size(), + "Wrong number of kernel outputs."); + for (size_t i = 0; i < outputs.size(); ++i) { + const at::Tensor& arg = outputs[i]; + const Val* param = fusion->outputs()[i]; + std::stringstream msg; + TORCH_INTERNAL_ASSERT( + validateKernelArgTensor(arg, param, device, msg), + "Output argument at position ", + i, + " is 
invalid; ", + msg.str()); + } +} + +NvrtcFunction nvrtcCompile( + const std::string& code, + const std::string& func_name, + int id) { + // lazily construct context if non-existing yet; + CUcontext pctx = nullptr; + AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx)); + if (!pctx) { + std::unique_lock cudaFreeMutexLock( + *(c10::cuda::CUDACachingAllocator::getFreeMutex())); + cudaFree(nullptr); + } + + const auto prop = at::cuda::getCurrentDeviceProperties(); + int nvrtc_major, nvrtc_minor; + AT_CUDA_NVRTC_CHECK( + at::globalContext().getNVRTC().nvrtcVersion(&nvrtc_major, &nvrtc_minor)); + + // Short-circuits if NVRTC version too low + TORCH_INTERNAL_ASSERT(nvrtc_major >= 6); + // Major and minor is determined by device properties and + // possibly "downcompiled" to a lower (compatible) compute architecture + // based on the NVRTC version + const int major = prop->major; + const int minor = prop->minor; + nvrtcProgram program; + AT_CUDA_NVRTC_CHECK(at::globalContext().getNVRTC().nvrtcCreateProgram( + &program, code.c_str(), nullptr, 0, nullptr, nullptr)); + ResourceGuard holdProgram([&] { + AT_CUDA_NVRTC_CHECK( + at::globalContext().getNVRTC().nvrtcDestroyProgram(&program)); + }); + + const std::string compute = "--gpu-architecture=compute_" + + std::to_string(major) + std::to_string(minor); + const std::vector args = { + "--std=c++14", compute.c_str(), "-default-device"}; + + at::globalContext().getNVRTC().nvrtcAddNameExpression( + program, func_name.c_str()); + const auto result = at::globalContext().getNVRTC().nvrtcCompileProgram( + program, args.size(), args.data()); + + if (result != NVRTC_SUCCESS) { + size_t logsize; + at::globalContext().getNVRTC().nvrtcGetProgramLogSize(program, &logsize); + std::vector log(logsize); + at::globalContext().getNVRTC().nvrtcGetProgramLog(program, log.data()); + + TORCH_INTERNAL_ASSERT( + false, code.c_str(), "\nCUDA NVRTC compile error: ", log.data()); + } + const char* lowered_kernel_name; + at::globalContext().getNVRTC().nvrtcGetLoweredName( + program, func_name.c_str(), &lowered_kernel_name); + + AT_CUDA_NVRTC_CHECK(result); + size_t ptx_size; + AT_CUDA_NVRTC_CHECK( + at::globalContext().getNVRTC().nvrtcGetPTXSize(program, &ptx_size)); + std::vector ptx; + ptx.resize(ptx_size); + AT_CUDA_NVRTC_CHECK( + at::globalContext().getNVRTC().nvrtcGetPTX(program, ptx.data())); + + NvrtcFunction compiled_kernel; + + // TODO: We do go through different code path, should investigate whether this + // has an impact on generated binary. 
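To make the compile path above concrete, a minimal sketch of the intended call sequence follows (not part of the patch; the kernel string, function name, and id are hypothetical placeholders for what FusionExecutor supplies). The PYTORCH_CUDA_FUSER_CUBIN handling just below additionally dumps the PTX and cubin to files named from that prefix and the id.

// Sketch only; assumes the executor_utils API declared in executor_utils.h.
std::string generated = "/* kernel body emitted by IRPrinter::printKernel */";
std::string full_src = executor_utils::kernelPreamble() + generated;
executor_utils::NvrtcFunction compiled = executor_utils::nvrtcCompile(
    full_src, /*func_name=*/"CUDAGeneratedKernel", /*id=*/1);
// compiled.module and compiled.function are what the launch path hands to the driver.
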
+ const char* prefix_env = getenv("PYTORCH_CUDA_FUSER_CUBIN"); + if (prefix_env) { + // Output ptx file + std::stringstream ptx_file_name; + ptx_file_name << prefix_env << "_" << id << ".ptx"; + std::ofstream myPtxFile(ptx_file_name.str().c_str(), std::ios::out); + if (myPtxFile.is_open()) { + myPtxFile.write(ptx.data(), ptx.size()); + myPtxFile.close(); + } + + CUlinkState linkState; + + AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkCreate( + 0, nullptr, nullptr, &linkState)); + + AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkAddData( + linkState, + CU_JIT_INPUT_PTX, + ptx.data(), + ptx_size, + "compiling PTX", + 0, + nullptr, + nullptr)); + + size_t cubinSize; + void* cubin; + AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkComplete( + linkState, &cubin, &cubinSize)); + + // Output binary file + std::stringstream cubin_file_name; + cubin_file_name << prefix_env << "_" << id << ".cubin"; + + std::ofstream myCubinFile( + cubin_file_name.str().c_str(), std::ios::out | std::ios::binary); + + if (myCubinFile.is_open()) { + myCubinFile.write(static_cast(cubin), cubinSize); + myCubinFile.close(); + } + // load compiled cubin + AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData( + &(compiled_kernel.module), cubin)); + } else { + // load ptx directly + AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData( + &(compiled_kernel.module), ptx.data())); + } + AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleGetFunction( + &(compiled_kernel.function), + compiled_kernel.module, + lowered_kernel_name)); + + return compiled_kernel; +} + +} // namespace executor_utils +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.h b/torch/csrc/jit/codegen/cuda/executor_utils.h new file mode 100644 index 00000000000000..a7b42fdb78106f --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/executor_utils.h @@ -0,0 +1,65 @@ +#pragma once + +#include +#include + +#include +#include + +#include + +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { +namespace executor_utils { + +// Include all the functions we might need in generated code +std::string kernelPreamble(); + +bool validateKernelArgTensor( + const at::Tensor& arg, + const Val* param, + c10::Device device, + std::stringstream& msg); + +bool validateKernelArgScalar( + const c10::TypePtr& arg_type, + const Val* param, + std::stringstream& msg); + +bool validateKernelArg( + const c10::IValue& arg, + const Val* param, + c10::Device device, + std::stringstream& msg); + +void validateKernelInputs( + Fusion* fusion, + const at::ArrayRef& inputs, + c10::Device device); + +void validateKernelOutputs( + Fusion* fusion, + const std::vector& outputs, + c10::Device device); + +struct NvrtcFunction { + public: + CUmodule module = CUmodule(); + CUfunction function = CUfunction(); +}; + +NvrtcFunction nvrtcCompile( + const std::string& code, + const std::string& func_name, + int id); + +} // namespace executor_utils +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp index 2a333fcbf3312f..21e103c8d8ce43 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include @@ -10,9 +11,29 @@ namespace jit { 
namespace fuser { void EvaluationContext::bind(const Val* value, Int::ScalarType concrete_value) { - TORCH_CHECK(value->isAnInt()); - TORCH_CHECK(!value->as()->value().has_value()); - TORCH_CHECK(fusion_->origin(value) == nullptr); + TORCH_INTERNAL_ASSERT( + value->isAnInt(), + "Expressoin Evaluation does not support values other than integers at this time."); + + if (value->isConstScalar()) { + auto const_value = value->as()->value().value(); + TORCH_INTERNAL_ASSERT( + concrete_value == const_value, + "Tried to bind ", + concrete_value, + " to ", + value, + " however ", + value, + " is set to a constant ", + const_value); + } + + TORCH_INTERNAL_ASSERT( + fusion_->origin(value) == nullptr, + "Tried to bind to a value that is computed in the fusion IR. ", + "Can only bind to symbolic values to the fusion that do not have an origin expr."); + bindings_[value] = concrete_value; } @@ -27,10 +48,10 @@ void EvaluationContext::print() const { std::cout << "\nEvaluation context\n"; std::cout << "--------------------\n"; for (const auto& kv : bindings_) { - const auto val = kv.first->as(); - std::cout << "i" << val->name() << " = " << kv.second; - if (!val->isSymbolic()) { - std::cout << " ; original value = " << *val->value(); + std::cout << kv.first << " = " << kv.second; + if (kv.first->isConstScalar()) { + std::cout << " ; original value = " + << kv.first->as()->value().value(); } std::cout << "\n"; } @@ -53,6 +74,15 @@ c10::optional ExpressionEvaluator::value( : c10::nullopt; } +void ExpressionEvaluator::handle(NamedScalar* i) { + if (i->isAnInt()) { + const auto& bound_value = context_->concreteValue(i); + if (bound_value.has_value()) { + values_[i] = *bound_value; + } + } +} + void ExpressionEvaluator::handle(Int* i) { if (i->value().has_value()) { values_[i] = *i->value(); diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.h b/torch/csrc/jit/codegen/cuda/expr_evaluator.h index bc29aac6b53b31..258ca1dcccde30 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.h +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.h @@ -61,6 +61,7 @@ class TORCH_CUDA_API ExpressionEvaluator : private IterVisitor { using IterVisitor::handle; + void handle(NamedScalar*) override; void handle(Int*) override; void handle(UnaryOp*) override; void handle(BinaryOp*) override; diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index 324aed2bfdecf2..bf0f488e6e94d0 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -475,7 +475,7 @@ std::unordered_set Fusion::unordered_uses(Val* val) const { } Expr* Fusion::origin(Val* val) const { - assertInFusion(val, "Cannot dettect the origin of val, "); + assertInFusion(val, "Cannot detect the origin of val, "); auto it = origin_.find(val); if (it == origin_.end()) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 898df6a5f7c4fc..af4bc3ce72380d 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -127,19 +127,23 @@ TensorIndex* Index::getGlobalProducerIndex( std::vector p_inds; auto p_root = TensorDomain::noReductions(producer->getRootDomain()); // Number of root dims that are broadcasted - size_t bcast_dims = 0; + size_t implicit_bcast_dims = 0; { auto c_root = consumer->getRootDomain(); size_t it_c = 0, it_p = 0; while (it_c < c_root.size() && it_p < p_root.size()) { const bool is_bcast = p_root[it_p]->isBroadcast(); - if (c_root[it_c]->isBroadcast() && !is_bcast) 
{ + if (c_root[it_c]->isBroadcast() && !p_root[it_p]->isBroadcast()) { it_c++; } else { - if (!is_bcast) { + if (!p_root[it_p]->isBroadcast()) { p_inds.push_back(c_inds[it_c]); } else { - bcast_dims++; + if (p_root[it_p]->getBroadcastType() == BroadcastType::WithStride) { + p_inds.push_back(new Int(0)); + } else { + implicit_bcast_dims++; + } } it_c++; it_p++; @@ -147,7 +151,7 @@ TensorIndex* Index::getGlobalProducerIndex( } } TORCH_INTERNAL_ASSERT( - p_inds.size() == p_root.size() - bcast_dims, + p_inds.size() == p_root.size() - implicit_bcast_dims, "Dimensionality error in code generator while computing tensor indices."); std::vector strided_inds; @@ -263,16 +267,42 @@ TensorIndex* Index::getGlobalConsumerIndex( computed_inds.size() == root_dom.size(), "Dimensionality error in code generator while computing indexing."); - if (computed_inds.size() == root_dom.size()) + if (computed_inds.size() == root_dom.size()) { for (size_t i = 0; i < root_dom.size(); i++) { // Do this backwards so erase offset will be right auto axis = root_dom.size() - i - 1; - if (root_dom[axis]->isReduction() || root_dom[i]->isBroadcast()) + if (root_dom[axis]->isReduction()) computed_inds.erase(computed_inds.begin() + axis); } + } + + { + size_t root_i = 0, inds_i = 0; + while (root_i < root_dom.size() && inds_i < computed_inds.size()) { + if (root_dom[root_i]->isReduction()) { + root_i++; + } else { + if (root_dom[root_i]->getBroadcastType() == + BroadcastType::WithoutStride) { + computed_inds.erase(computed_inds.begin() + inds_i); + root_i++; + } else { + if (root_dom[root_i]->getBroadcastType() == + BroadcastType::WithStride) { + computed_inds[inds_i] = new Int(0); + } + root_i++; + inds_i++; + } + } + } + } std::vector strided_inds; for (size_t i = 0; i < computed_inds.size(); i++) { + if (computed_inds[i]->isZeroInt()) { + continue; + } std::stringstream ss; ss << "T" << consumer->name() << ".stride[" << i << "]"; strided_inds.push_back( diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index c0564b2511ee04..ec2158fe1c530b 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -316,44 +316,45 @@ class TORCH_CUDA_API Scope { std::vector exprs_; }; -/* - * A Expr represents a "computation." These are functions that takes inputs - * and produce outputs, inputs and outputs all being Vals. There are - * specializations of BinaryOp which takes 2 inputs and produces 1 output, and - * UnaryOp which takes 1 input and produces 1 output. Exprs are unique and - * immutable. Conceptually, Exprs could always be manipulated using unique - * pointers, and we could add this later. However, for now Exprs can be replaced - * in a fusion, but they cannot be modified in place. - * - * The IR is static single assignment (SSA). Values can only be defined as an - * output of an Expr once. If they are re-defined the original definition is - * deleted from the program, as opposed to an ordered redefinition of the value - * in the program. - * - * Note: Registering an Expr with a Fusion is actually 2 parts, one part is done - * in the Expr constructor, so that should be called on anything that inherits - * Expr. The issue with having registration in Expr's constructor, is that the - * constructor of an Expr will set ouputs and inputs. This information is - * important for registration with Fuser, so it can track the dependency chain. - * - * Adding an Expr: - * Right now adding an Expr is quite involved. 
Expr's can be defined in ir.h or - * in their own header file. The following is what is currently needed for Expr - * definitions: - * 1) Definition inheriting from Expr. - * - Members must be private or protected - * - Accessor functions for members - * - Constructors need to register with the Fusion after inputs/outputs are - * defined - * - Implementation of bool sameAs(...) - * 2) dispatch.h/.cpp must be updated to include dispatch of the new Val - * 3) Default mutator function should be added to mutator.h/.cpp - * 4) Printing functions should be added to ir_iostream.h/.cpp - * 5) Lower case convenience functions should be added to arith.h/.cpp (If user - * facing) - * 6) An enum value must be added to ExprType in type.h 7) A string - * entry must be added in expr_type_string_map - */ +// A Expr represents a "computation." These are functions that takes inputs +// and produce outputs, inputs and outputs all being Vals. There are +// specializations of BinaryOp which takes 2 inputs and produces 1 output, and +// UnaryOp which takes 1 input and produces 1 output. Exprs are unique and +// immutable. Conceptually, Exprs could always be manipulated using unique +// pointers, and we could add this later. However, for now Exprs can be +// replaced in a fusion, but they cannot be modified in place. + +// The IR is static single assignment (SSA). Values can only be defined as an +// output of an Expr once. If they are re-defined the original definition is +// deleted from the program, as opposed to an ordered redefinition of the value +// in the program. + +// Note: Registering an Expr with a Fusion is actually 2 parts, one part is +// done in the Expr constructor, so that should be called on anything that +// inherits Expr. The issue with having registration in Expr's constructor, is +// that the constructor of an Expr will set ouputs and inputs. This information +// is important for registration with Fuser, so it can track the dependency +// chain. + +// Adding an Expr: +// Right now adding an Expr is quite involved. Expr's can be defined in ir.h or +// in their own header file. The following is what is currently needed for Expr +// definitions: +// 1) Definition inheriting from Expr. +// - Members must be private or protected +// - Accessor functions for members +// - Constructors need to register with the Fusion after inputs/outputs are +// defined +// - Implementation of bool sameAs(...) 
+// 2) dispatch.h/.cpp must be updated to include dispatch of the new Val +// 3) Default mutator function should be added to mutator.h/.cpp +// 4) Printing functions should be added to ir_iostream.h/.cpp +// 5) Lower case convenience functions should be added to arith.h/.cpp (If user +// facing) +// 6) An enum value must be added to ExprType in type.h +// 7) A string entry must be added in expr_type_string_map +// 8) Entry added to ir_graphviz .cpp/.h + class TORCH_CUDA_API Expr : public Statement { public: Expr() = delete; diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp index e434b14ccac006..02ea7277f804ef 100644 --- a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp @@ -470,6 +470,15 @@ void IrGraphGenerator::handle(const ReductionOp* op) { addArc(op, op->out()); } +void IrGraphGenerator::handle(const GridReduction* op) { + printExpr(op, "Grid Reduction"); + + // inputs & outputs + addArc(op, op->reduction_op()); + addArc(op->reduction_buffer(), op); + addArc(op->sync_buffer(), op); +} + void IrGraphGenerator::handle(const ForLoop* for_loop) { printExpr(for_loop, "ForLoop"); addArc(for_loop->index(), for_loop); @@ -488,8 +497,11 @@ void IrGraphGenerator::handle(const IfThenElse* if_then_else) { } void IrGraphGenerator::handle(const Allocate* allocate) { - printExpr(allocate, "Allocate"); - addArc(allocate->extent(), allocate); + std::stringstream msg; + msg << "Allocate( memory type = " << allocate->getMemoryType() << ")"; + + printExpr(allocate, msg.str()); + addArc(allocate->size(), allocate); addArc(allocate->buffer(), allocate); } diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.h b/torch/csrc/jit/codegen/cuda/ir_graphviz.h index c022ec0a79921a..cb6a7bc7596f27 100644 --- a/torch/csrc/jit/codegen/cuda/ir_graphviz.h +++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.h @@ -79,6 +79,7 @@ class TORCH_CUDA_API IrGraphGenerator : private OptInConstDispatch { void handle(const TernaryOp*) override; void handle(const BroadcastOp*) override; void handle(const ReductionOp*) override; + void handle(const GridReduction*) override; void handle(const ForLoop*) override; void handle(const IfThenElse*) override; diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index a44501162a428d..3113e6bc3a049f 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -269,6 +269,12 @@ class TORCH_CUDA_API TensorView : public Val { // and outer axis is size axis.size() / factor TensorView* split(int axis, unsigned int factor); + // Split "axis" into 2 axes where the inner axes is size of "factor" + // and outer axis is size axis.size() / factor. Factor can be a symbolic + // value instead of constant. 
This requires setting the symbolic value as an + // input, or using a parallel dim from NamedScalar::getParallelDim + TensorView* split(int axis, Val* factor); + // Merge axis_o and axis_i into 1 IterDomain TensorView* merge(int axis_o, int axis_i); diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index f97225d87f252c..4f6e309d10625b 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -16,6 +16,11 @@ namespace torch { namespace jit { namespace fuser { +// Returns true if both v1 and v2 are scalars, are the same type of scalars, and +// dispatches to the inherited Val type's `->sameAs` call. e.g. if both vals are +// `Int` will dispatch to v1->as()->sameAs(v2.as()) +bool areEqualScalars(Val* v1, Val* v2); + /* * TODO: improve implementation bool IterDomain::sameAs(const IterDomain*) const * TODO: Add testing of sameAs functions for these nodes @@ -182,6 +187,45 @@ class TORCH_CUDA_API ReductionOp : public Expr { Val* const in_ = nullptr; }; +// Grid reduction operation, this node is used only after lowering a fusion to +// explicitly mark a grid reduction and the buffer allocation needed to do it. +// This node provides FusionExecutor the information it needs to allocate the +// reduction and sync buffers. +class TORCH_CUDA_API GridReduction : public Expr { + public: + ~GridReduction() = default; + GridReduction(ReductionOp* _reduction_op); + GridReduction( + ReductionOp* _reduction_op, + Allocate* _reduction_buffer, + Allocate* _sync_buffer); + + GridReduction(const GridReduction* src, IrCloner* ir_cloner); + + GridReduction(const GridReduction& other) = delete; + GridReduction& operator=(const GridReduction& other) = delete; + + GridReduction(GridReduction&& other) = delete; + GridReduction& operator=(GridReduction&& other) = delete; + + ReductionOp* reduction_op() const { + return reduction_op_; + } + Allocate* reduction_buffer() const { + return reduction_buffer_; + } + Allocate* sync_buffer() const { + return sync_buffer_; + } + + bool sameAs(const GridReduction* other) const; + + private: + ReductionOp* reduction_op_ = nullptr; + Allocate* reduction_buffer_ = nullptr; + Allocate* sync_buffer_ = nullptr; +}; + class TORCH_CUDA_API TernaryOp : public Expr { public: ~TernaryOp() = default; @@ -241,13 +285,14 @@ class TORCH_CUDA_API IterDomain : public Val { ParallelType _parallel_method = ParallelType::Serial, bool _reduction_domain = false, bool _rfactor_domain = false, - bool _broadcast_domain = false); + BroadcastType _broadcast_domain = BroadcastType::Null); IterDomain(const IterDomain* src, IrCloner* ir_cloner); bool sameAs(const IterDomain* const other) const; // Returns a new IterDomain matching properties of this + // TODO: parallel_method->getParallelType IterDomain* clone() const { return new IterDomain( start(), @@ -255,13 +300,14 @@ class TORCH_CUDA_API IterDomain : public Val { parallel_method(), isReduction(), isRFactorProduct(), - isBroadcast()); + getBroadcastType()); } static IterDomain* merge(IterDomain* outer, IterDomain* inner); - static std::pair split( - IterDomain* in, - unsigned int factor); + + // TODO: Make protected and friend TensorDomain so only it can call into this + // directly, users should not be able to use this call + static std::pair split(IterDomain* in, Val* factor); bool isReduction() const { return is_reduction_domain_; @@ -272,7 +318,7 @@ class TORCH_CUDA_API IterDomain : public Val { } bool isBroadcast() const { - 
return is_broadcast_domain_; + return getBroadcastType() != BroadcastType::Null; } bool isParallelized() const { @@ -303,14 +349,6 @@ class TORCH_CUDA_API IterDomain : public Val { void parallelize(ParallelType t) { parallel_method_ = t; - // Currently a limitation as we allocate shared memory as static (not based - // off a dynamic size.) - if (isReduction()) - if (isThreadDim()) - TORCH_CHECK( - extent()->isConstScalar(), - "Reductions can only be parallelized across dimensions of compile-time known constants."); - TORCH_CHECK( t != ParallelType::Vectorize, "Vectorization not yet supported."); @@ -329,6 +367,10 @@ class TORCH_CUDA_API IterDomain : public Val { return parallel_method_; } + BroadcastType getBroadcastType() const { + return broadcast_type_; + } + Val* start() const { return start_; } @@ -349,7 +391,7 @@ class TORCH_CUDA_API IterDomain : public Val { ParallelType parallel_method_ = ParallelType::Serial; bool is_reduction_domain_ = false; bool is_rfactor_domain_ = false; - bool is_broadcast_domain_ = false; + BroadcastType broadcast_type_ = BroadcastType::Null; }; /* @@ -437,8 +479,11 @@ class TORCH_CUDA_API TensorDomain : public Val { size_t posOf(IterDomain* id) const; // Split "axis" into 2 axes where the inner axes is size of "factor" - // and outer axis is size axis.size() / factor - void split(int axis, unsigned int factor); + // and outer axis is size axis.size() / factor. Allow factor to be symbolic + // value instead of constant. + // TODO: Make protected and friend TensorDomain so only it can call into this + // directly, users should not be able to use this call + void split(int axis_, Val* factor); // Merge axis_o and axis_i. axis_i is the fast changing dimension. Resulting // axis is by default placed at original position axis_o @@ -482,7 +527,7 @@ class TORCH_CUDA_API Split : public Expr { Split(Split&& other) = delete; Split& operator=(Split&& other) = delete; - Split(IterDomain* _outer, IterDomain* _inner, IterDomain* _in, Int* _factor); + Split(IterDomain* _outer, IterDomain* _inner, IterDomain* _in, Val* _factor); Split(const Split* src, IrCloner* ir_cloner); @@ -495,7 +540,7 @@ class TORCH_CUDA_API Split : public Expr { IterDomain* in() const { return in_; } - Int* factor() const { + Val* factor() const { return factor_; } bool sameAs(const Split* const other) const; @@ -504,7 +549,7 @@ class TORCH_CUDA_API Split : public Expr { IterDomain* const outer_ = nullptr; IterDomain* const inner_ = nullptr; IterDomain* const in_ = nullptr; - Int* const factor_ = nullptr; + Val* const factor_ = nullptr; }; /* @@ -718,15 +763,13 @@ class TORCH_CUDA_API TensorIndex : public Val { std::vector indices_; }; -/* - * Allocate is a lower level Node that describes a buffer of memory that - * is required as an intermediate within a kernel. The extent is the expression - * of the size of the buffer that is generated from the TensorView that - * describes the output of an operation. - * - * TODO: The components of Allocate like Type and Name could be separated from - * the the assocated TensorView. Perhaps that is more appropriate? - */ +// Allocate is a lower level Node that describes a buffer of memory that +// is required as an intermediate within a kernel. The extent is the expression +// of the size of the buffer that is generated from the TensorView that +// describes the output of an operation. +// +// TODO: The components of Allocate like Type and Name could be separated from +// the the assocated TensorView. Perhaps that is more appropriate? 
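Stepping back to the TensorView::split(int, Val*) overload introduced above, a short sketch of a symbolic split (illustrative only, not part of the patch; the fusion setup and the makeDummyTensor helper mirror the C++ tests and are assumptions here):

Fusion fusion;
FusionGuard fg(&fusion);
TensorView* tv0 = makeDummyTensor(2);   // 2-D dummy tensor, as in test_gpu.cpp
fusion.addInput(tv0);
TensorView* tv1 = add(tv0, new Float(1.0));
fusion.addOutput(tv1);
// Split the inner axis by the symbolic blockDim.x extent rather than a constant.
tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
tv1->axis(-1)->parallelize(ParallelType::TIDx);
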
class TORCH_CUDA_API Allocate : public Expr { public: ~Allocate() = default; @@ -737,23 +780,35 @@ class TORCH_CUDA_API Allocate : public Expr { Allocate(Allocate&& other) = delete; Allocate& operator=(Allocate&& other) = delete; - Allocate(Val* _tv, Val* size); + explicit Allocate( + Val* _buffer, + MemoryType _memory_type = MemoryType::Local, + Val* _size = nullptr); Allocate(const Allocate* src, IrCloner* ir_cloner); - DataType buf_type() const; - Val* extent() const { - return extent_; - } Val* buffer() const { return buffer_; } + MemoryType getMemoryType() const { + return memory_type_; + } + + Val* size() const { + return size_; + } + + DataType buffer_type() const { + return buffer_->getDataType().value(); + } + bool sameAs(const Allocate* other) const; private: Val* buffer_ = nullptr; - Val* extent_ = nullptr; + MemoryType memory_type_ = MemoryType::Local; + Val* size_ = nullptr; }; /* @@ -787,6 +842,20 @@ class TORCH_CUDA_API NamedScalar : public Val { return other->name().compare(name()) == 0; } + // Return the named scalar extent of a parallel dimension (e.g. blockDim.x) + static NamedScalar* getParallelDim(ParallelType p_type); + + // Return the named scalar index of a parallel dimension (e.g. threadIdx.x) + static NamedScalar* getParallelIndex(ParallelType p_type); + + // Return the parallel type of this NamedScalar if it is an extent of a + // parallel dimension + c10::optional getParallelDim() const; + + // Return the parallel type of this NamedScalar if it is an index of a + // parallel dimension + c10::optional getParallelIndex() const; + private: std::string name_; }; diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 5a0fd242afbe2f..4392cdf9aa8f3d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -42,7 +42,10 @@ void IRPrinter::handle(const Expr* e) { OptInConstDispatch::handle(e); } -void IRPrinter::printHeader(Fusion* fusion, const std::string& kernel_name_) { +void IRPrinter::printHeader( + Fusion* fusion, + const std::string& kernel_name_, + const std::vector& global_buffers) { os << "__global__ void " << kernel_name_ << "("; std::vector vals; @@ -54,16 +57,20 @@ void IRPrinter::printHeader(Fusion* fusion, const std::string& kernel_name_) { vals.push_back(val); } + for (auto val : global_buffers) { + vals.push_back(val); + } + for (Val* val : vals) { switch (val->getValType().value()) { - case (ValType::TensorView): + case ValType::TensorView: os << "Tensor<" << val->getDataType().value() << ", " << TensorDomain::noReductions( static_cast(val)->getRootDomain()) .size() << "> T" << val->name(); break; - case (ValType::Scalar): + case ValType::Scalar: os << val->getDataType().value() << " " << val; break; default: @@ -79,10 +86,6 @@ void IRPrinter::printHeader(Fusion* fusion, const std::string& kernel_name_) { if (fusion->hasRNG()) os << ", unsigned long long seed, unsigned long long offset"; - if (fusion->hasGridReduction()) { - os << ", void* work_buf, unsigned* sync_flags"; - } - os << "){\n"; indent_size++; if (fusion->hasRNG()) { @@ -121,13 +124,35 @@ void IRPrinter::handle(const TensorDomain* td) { } void IRPrinter::handle(const TensorView* tv) { - os << "T" << tv->name(); - handle(tv->domain()); + if (tv->nDims() == 0) { + switch (tv->getDataType().value()) { + case DataType::Bool: + os << "b"; + break; + case DataType::Float: + os << "f"; + break; + case DataType::Half: + os << "h"; + break; + case DataType::Int: + os << "i"; + break; + default: 
+ TORCH_INTERNAL_ASSERT( + false, "Did not recognize type ", tv->getDataType().value()); + } + os << tv->name(); + + } else { + os << "T" << tv->name(); + handle(tv->domain()); - if (tv->getComputeAtView() != nullptr) { - os << " compute_at( "; - os << "T" << tv->getComputeAtView()->name(); - os << ", " << tv->getRelativeComputeAtAxis() << " )"; + if (tv->getComputeAtView() != nullptr) { + os << " compute_at( "; + os << "T" << tv->getComputeAtView()->name(); + os << ", " << tv->getRelativeComputeAtAxis() << " )"; + } } } @@ -139,13 +164,13 @@ void IRPrinter::handle(const IterDomain* id) { else os << "i"; switch (id->parallel_method()) { - case (ParallelType::Vectorize): + case ParallelType::Vectorize: os << "V"; break; - case (ParallelType::Unroll): + case ParallelType::Unroll: os << "U"; break; - case (ParallelType::Serial): + case ParallelType::Serial: os << "S"; break; default: @@ -412,6 +437,7 @@ void IRPrinter::handle(const ReductionOp* rop) { bool has_grid_reduce = out->view()->hasGridReduction(); if (!has_block_reduce && !has_grid_reduce) { + FusionGuard fg(rop->fusion()); handle(new BinaryOp(rop->getReductionOpType(), out, out, rop->in())); return; } @@ -450,29 +476,58 @@ void IRPrinter::handle(const ReductionOp* rop) { os << ", reinterpret_cast<" << d_type << "*>(shared_mem)"; os << ");\n"; } - if (has_grid_reduce) { - indent(); - // Since block-level reduction is already done, those dimensions - // with tidx/y/z being true do not participate in the grid reduction. - os << "reduction::gridReduce< " << (bidx ? "true" : "false") << ", " - << (bidy ? "true" : "false") << ", " << (bidz ? "true" : "false") << ", " - << (!tidx ? "true" : "false") << ", " << (!tidy ? "true" : "false") - << ", " << (!tidz ? "true" : "false") << " >" - << " ( "; - handle(rop->out()); - os << ", "; - if (has_block_reduce) { - os << block_result; - } else { - handle(rop->in()); - } - os << ", "; - os << "reduction_" << op_type << "_" << d_type; - os << ", static_cast<" << d_type << "*>(work_buf)"; - os << ", sync_flags"; - os << ", reinterpret_cast<" << d_type << "*>(shared_mem)"; - os << ");\n"; +} + +void IRPrinter::handle(const GridReduction* gr) { + // Check if we've lowered yet. 
+ const auto rop = gr->reduction_op(); + TORCH_INTERNAL_ASSERT( + rop->out()->getValType() == ValType::TensorIndex, + "GridReduction node is a lowered node but did not find the output to be a TensorIndex."); + + const auto out = rop->out()->as(); + TORCH_INTERNAL_ASSERT(out->view()->hasGridReduction()); + + const auto vec_domain = out->view()->domain()->domain(); + + const auto par_domains = rop->getParallelReductionDomains(); + const bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end(); + const bool tidy = par_domains.find(ParallelType::TIDy) != par_domains.end(); + const bool tidz = par_domains.find(ParallelType::TIDz) != par_domains.end(); + const bool bidx = par_domains.find(ParallelType::BIDx) != par_domains.end(); + const bool bidy = par_domains.find(ParallelType::BIDy) != par_domains.end(); + const bool bidz = par_domains.find(ParallelType::BIDz) != par_domains.end(); + + const auto d_type = rop->out()->getDataType().value(); + const auto op_type = rop->getReductionOpType(); + TORCH_INTERNAL_ASSERT( + gr->reduction_buffer()->buffer()->getValType().value() == + ValType::TensorView); + TORCH_INTERNAL_ASSERT( + gr->sync_buffer()->buffer()->getValType().value() == ValType::TensorView); + TensorView* work_buffer = gr->reduction_buffer()->buffer()->as(); + TensorView* sync_buffer = gr->sync_buffer()->buffer()->as(); + indent(); + // Since block-level reduction is already done, those dimensions + // with tidx/y/z being true do not participate in the grid reduction. + os << "reduction::gridReduce< " << (bidx ? "true" : "false") << ", " + << (bidy ? "true" : "false") << ", " << (bidz ? "true" : "false") << ", " + << (!tidx ? "true" : "false") << ", " << (!tidy ? "true" : "false") << ", " + << (!tidz ? "true" : "false") << " >" + << " ( "; + handle(rop->out()); + os << ", "; + if (out->view()->hasBlockReduction()) { + os << "block_result"; + } else { + handle(rop->in()); } + os << ", "; + os << "reduction_" << op_type << "_" << d_type; + os << ", &T" << work_buffer->name() << "[0]"; + os << ", T" << sync_buffer->name() << ""; + os << ", reinterpret_cast<" << d_type << "*>(shared_mem)"; + os << ");\n"; } void IRPrinter::handle(const BroadcastOp* bop) { @@ -580,22 +635,50 @@ void IRPrinter::handle(const IfThenElse* ite) { void IRPrinter::handle(const Allocate* a) { indent(); - os << a->buf_type(); - if (a->buffer()->getValType() == ValType::TensorView) { - os << " T" << a->buffer()->name() << "["; - print_inline(a->extent()); - os << "];\n"; - } else { - if (a->extent()->isOneInt()) { - os << " " << a->buffer() << ";\n"; - } else { - TORCH_INTERNAL_ASSERT( - false, - "Received unexpected allocation: ", - a->buffer(), - " with alloc of ", - a->extent()); + if (a->buffer()->getValType().value() == ValType::TensorView) { + auto tv = a->buffer()->as(); + + switch (tv->getMemoryType()) { + case MemoryType::Global: + os << "// Allocate global tensor " << a->buffer_type() << " T" + << tv->name() << "["; + if (a->size() == nullptr) { + handle(tv); + } else { + print_inline(a->size()); + } + os << "];\n"; + break; + case MemoryType::Shared: + os << "__shared__ "; + os << a->buffer_type(); + if (tv->nDims() == 0) { + os << tv; + } else { + os << " T" << tv->name(); + os << "["; + print_inline(a->size()); + os << "]"; + } + os << ";\n"; + break; + case MemoryType::Local: + os << a->buffer_type(); + if (tv->nDims() == 0) { + os << tv; + } else { + os << " T" << tv->name(); + os << "["; + print_inline(a->size()); + os << "]"; + } + os << ";\n"; + break; } + } else { + os << a->buffer_type() 
<< " "; + handle(a->buffer()); + os << ";\n"; } } @@ -643,6 +726,7 @@ class ReductionOps : OptOutDispatch { } // namespace void IRPrinter::printReductionOps(Fusion* fusion) { + FusionGuard fg(fusion); auto a = new NamedScalar("a", DataType::Null); auto b = new NamedScalar("b", DataType::Null); for (auto rop_pair : ReductionOps::get(fusion)) { @@ -654,6 +738,7 @@ void IRPrinter::printReductionOps(Fusion* fusion) { << d_type << "& a, " << "const " << d_type << " b) {\n"; indent_size++; + handle(new BinaryOp(op_type, a, a, b)); indent_size--; indent(); @@ -663,10 +748,18 @@ void IRPrinter::printReductionOps(Fusion* fusion) { void IRPrinter::printKernel( const std::vector& exprs, - const std::string& kernel_name) { + const std::string& kernel_name, + const std::vector& global_buffers) { Fusion* fusion = FusionGuard::getCurFusion(); + if (exprs.empty()) + return; + TORCH_INTERNAL_ASSERT( + exprs[0]->fusion() == FusionGuard::getCurFusion(), + "Incorrect fusion set during printKernel."); + printReductionOps(fusion); - printHeader(fusion, kernel_name); + printHeader(fusion, kernel_name, global_buffers); + for (auto* expr : exprs) { handle(expr); } diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index 3dc95d16e95c22..4e1b6d58d97d13 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -22,6 +22,7 @@ class UnaryOp; class BinaryOp; class TernaryOp; class ReductionOp; +class GridReduction; class BroadcastOp; class ForLoop; @@ -70,7 +71,10 @@ class TORCH_CUDA_API IRPrinter : public OptInConstDispatch { indent_size = 0; } - void printHeader(Fusion* fusion, const std::string& kernel_name_); + void printHeader( + Fusion* fusion, + const std::string& kernel_name_, + const std::vector& global_buffers); IRPrinter(std::ostream& _os) : os(_os) {} @@ -106,6 +110,7 @@ class TORCH_CUDA_API IRPrinter : public OptInConstDispatch { void handle(const BinaryOp*) override; void handle(const TernaryOp*) override; void handle(const ReductionOp*) override; + void handle(const GridReduction*) override; void handle(const BroadcastOp*) override; void handle(const ForLoop*) override; @@ -126,7 +131,8 @@ class TORCH_CUDA_API IRPrinter : public OptInConstDispatch { void printKernel( const std::vector& exprs, - const std::string& kernel_name); + const std::string& kernel_name, + const std::vector& global_buffers); private: std::unique_ptr thread_predicates_; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index a4cd8628fe4660..0439b8d12a0489 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -65,6 +65,10 @@ class ScalarCheck : OptInDispatch { } // namespace +bool areEqualScalars(Val* v1, Val* v2) { + return ScalarCheck::sameAs(v1, v2); +} + Bool::Bool(const Bool* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} @@ -286,10 +290,12 @@ std::vector ReductionOp::getReductionDomains() const { out_val->getValType() == ValType::TensorView || out_val->getValType() == ValType::TensorIndex, "Output of reduction must be TensorView or TensorIndex"); + // out is a TensorIndex after lowering if (out_val->getValType() == ValType::TensorIndex) { out_val = static_cast(out_val)->view(); } + auto vec_domain = out_val->as()->domain()->domain(); vec_domain.erase( std::remove_if( @@ -311,25 +317,51 @@ std::unordered_map ReductionOp:: return parallel_domains; } +GridReduction::GridReduction(ReductionOp* 
_reduction_op) + : Expr(ExprType::GridReduction), reduction_op_(_reduction_op) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} + +GridReduction::GridReduction( + ReductionOp* _reduction_op, + Allocate* _reduction_buffer, + Allocate* _sync_buffer) + : Expr(ExprType::GridReduction), + reduction_op_(_reduction_op), + reduction_buffer_(_reduction_buffer), + sync_buffer_(_sync_buffer) {} + +GridReduction::GridReduction(const GridReduction* src, IrCloner* ir_cloner) + : Expr(src, ir_cloner), + reduction_op_(ir_cloner->clone(src->reduction_op_)), + reduction_buffer_(ir_cloner->clone(src->reduction_buffer_)), + sync_buffer_(ir_cloner->clone(src->sync_buffer_)) {} + +bool GridReduction::sameAs(const GridReduction* other) const { + return reduction_op_->sameAs(other->reduction_op()) && + reduction_buffer_->sameAs(other->reduction_buffer()) && + sync_buffer_->sameAs(other->sync_buffer()); +} + IterDomain::IterDomain( Val* _start, Val* _extent, ParallelType _parallel_method, bool _reduction_domain, bool _rfactor_domain, - bool _broadcast_domain) + BroadcastType _broadcast_type) : Val(ValType::IterDomain, DataType::Int, false), start_(_start), extent_(_extent), parallel_method_(_parallel_method), is_reduction_domain_(_reduction_domain), is_rfactor_domain_(_rfactor_domain), - is_broadcast_domain_(_broadcast_domain) { + broadcast_type_(_broadcast_type) { TORCH_CHECK( - !(is_reduction_domain_ && is_broadcast_domain_), + !(isReduction() && isBroadcast()), "IterDomain cannot be both a broadcast and reduction domain."); TORCH_CHECK( - !(is_rfactor_domain_ && is_broadcast_domain_), + !(isRFactorProduct() && isBroadcast()), "IterDomain cannot be both a broadcast and rfactor domain."); TORCH_INTERNAL_ASSERT( @@ -352,7 +384,7 @@ IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner) parallel_method_(src->parallel_method_), is_reduction_domain_(src->is_reduction_domain_), is_rfactor_domain_(src->is_rfactor_domain_), - is_broadcast_domain_(src->is_broadcast_domain_) {} + broadcast_type_(src->broadcast_type_) {} bool IterDomain::sameAs(const IterDomain* const other) const { if (other == this) @@ -378,13 +410,22 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { "Merging IterDomains requires that their parallel types match."); Val* merged_id_size = mul(outer->extent(), inner->extent()); + BroadcastType bcast_type = BroadcastType::Null; + if (outer->isBroadcast() && inner->isBroadcast()) { + if (outer->getBroadcastType() == BroadcastType::WithStride || + inner->getBroadcastType() == BroadcastType::WithStride) { + bcast_type = BroadcastType::WithStride; + } else { + bcast_type = BroadcastType::WithoutStride; + } + } IterDomain* merged_id = new IterDomain( new Int(0), static_cast(merged_id_size), outer->parallel_method(), outer->isReduction(), outer->isRFactorProduct() || inner->isRFactorProduct(), - outer->isBroadcast() && inner->isBroadcast()); + bcast_type); new Merge(merged_id, outer, inner); @@ -393,7 +434,7 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { std::pair IterDomain::split( IterDomain* in, - unsigned int factor) { + Val* factor) { TORCH_CHECK( in->start()->isZeroInt(), "Splitting IterDomains with starting values that aren't 0 is not supported at this time."); @@ -404,9 +445,24 @@ std::pair IterDomain::split( "Splitting an axis of non-Serial iteration is not supported at this time." 
" Parallelization strategy must be set after calling split."); - Int* fact = new Int(factor); + TORCH_CHECK(factor->isAnInt(), "Cannot split by non-integer value ", factor); + + if (factor->getValType() == ValType::Scalar) { + TORCH_CHECK( + factor->isConstScalar() || + FusionGuard::getCurFusion()->hasInput(factor), + factor, + " is not a constant nor an input. It must be one or the other to be used in a split.", + " If you want a symbolic split based on a thread dimension please use IterDomain::split(IterDomain*, ParallelType);"); + } else if (factor->getValType() == ValType::NamedScalar) { + TORCH_CHECK( + factor->as()->getParallelDim() != c10::nullopt, + "Splitting a dimension by a named scalar is only supported on block or grid dimensions but received ", + factor); + } + // outer loop size - Val* vo = ceilDiv(in->extent(), fact); + Val* vo = ceilDiv(in->extent(), factor); // outer loop IterDomain IterDomain* ido = new IterDomain( @@ -415,17 +471,17 @@ std::pair IterDomain::split( in->parallel_method(), in->isReduction(), in->isRFactorProduct(), - in->isBroadcast()); + in->getBroadcastType()); // inner loop IterDomain IterDomain* idi = new IterDomain( new Int(0), - fact, + factor, in->parallel_method(), in->isReduction(), in->isRFactorProduct(), - in->isBroadcast()); - new Split(ido, idi, in, fact); + in->getBroadcastType()); + new Split(ido, idi, in, factor); return {ido, idi}; } @@ -435,8 +491,7 @@ Val* IterDomain::extent() const { if (static_cast(extent_)->isConst()) return extent_; - std::string parallel_dim = stringifyThreadSize(parallel_method_); - return new NamedScalar(parallel_dim, DataType::Int); + return NamedScalar::getParallelDim(parallel_method()); } return extent_; } @@ -605,9 +660,7 @@ size_t TensorDomain::posOf(IterDomain* id) const { TORCH_CHECK(false, "Provided id is not part of this domain."); } -// Split "axis" into 2 axes where the inner axes is size of "factor" -// and outer axis is size axis.extent() / factor -void TensorDomain::split(int axis_, unsigned int factor) { +void TensorDomain::split(int axis_, Val* factor) { TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim domain"); if (axis_ < 0) axis_ += nDims(); @@ -875,12 +928,15 @@ Split::Split( IterDomain* _outer, IterDomain* _inner, IterDomain* _in, - Int* _factor) + Val* _factor) : Expr(ExprType::Split), outer_{_outer}, inner_{_inner}, in_{_in}, factor_{_factor} { + TORCH_INTERNAL_ASSERT( + factor_->isAnInt(), + "Attempted to create a Split node with a non-integer factor."); addOutput(_outer); addOutput(_inner); addInput(_in); @@ -1012,37 +1068,60 @@ Val* TensorIndex::index(int i) const { return indices_[i]; } -Allocate::Allocate(Val* _val, Val* _size) - : Expr(ExprType::Allocate), buffer_(_val), extent_{_size} { - if (!_size->isAnInt() || !_size->isConstScalar()) { - std::stringstream flat_size; - IRPrinter irp(flat_size); - irp.print_inline(_size); +Allocate::Allocate(Val* _buffer, MemoryType _memory_type, Val* _size) + : Expr(ExprType::Allocate), + buffer_(_buffer), + memory_type_(_memory_type), + size_(_size) { + if (size_ != nullptr) { TORCH_INTERNAL_ASSERT( - false, - "Allocations must be based on constant integers but tried to alloc ", - _val, - " with size ", - flat_size.str(), - "."); + size_->isOneInt() || + buffer_->getValType().value() == ValType::TensorView, + "Cannot allocate a non-TensorView buffer with a size != 1, received buffer: ", + buffer_); + } else { + if (buffer_->getValType().value() == ValType::TensorView) { + auto tv = buffer_->as(); + size_ = tv->nDims() == 0 ? 
new Int(1) : tv->axis(0)->extent(); + for (size_t i = 1; i < tv->nDims(); i++) { + size_ = mul(size_, tv->axis(i)->extent()); + } + + if ((memory_type_ == MemoryType::Local || + memory_type_ == MemoryType::Shared)) { + if (!size_->isConstScalar()) { + std::stringstream flat_size; + IRPrinter irp(flat_size); + irp.print_inline(size_); + TORCH_INTERNAL_ASSERT( + false, + "Allocations must be based on constant integers for the memory type ", + memory_type_, + " but tried to alloc ", + buffer_, + " with size ", + flat_size.str(), + "."); + } + } + } } - addInput(_size); + addInput(size_); this->name_ = FusionGuard::getCurFusion()->registerExpr(this); } Allocate::Allocate(const Allocate* src, IrCloner* ir_cloner) : Expr(src, ir_cloner), buffer_(ir_cloner->clone(src->buffer_)), - extent_(ir_cloner->clone(src->extent_)) {} - -DataType Allocate::buf_type() const { - return buffer_->getDataType().value(); -} + memory_type_(src->memory_type_), + size_(ir_cloner->clone(src->size_)) {} bool Allocate::sameAs(const Allocate* other) const { if (!this->buffer_->sameAs(other->buffer())) return false; - if (!this->extent()->sameAs(other->extent())) + if (!this->size()->sameAs(other->size())) + return false; + if (this->getMemoryType() != other->getMemoryType()) return false; if (this->type() != other->type()) return false; @@ -1053,6 +1132,50 @@ bool Allocate::sameAs(const Allocate* other) const { NamedScalar::NamedScalar(const NamedScalar* src, IrCloner* ir_cloner) : Val(src, ir_cloner), name_(src->name_) {} +NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) { + std::string parallel_dim = stringifyThreadSize(p_type); + return new NamedScalar(parallel_dim, DataType::Int); +} + +NamedScalar* NamedScalar::getParallelIndex(ParallelType p_type) { + std::string parallel_ind = stringifyThread(p_type); + return new NamedScalar(parallel_ind, DataType::Int); +} + +c10::optional NamedScalar::getParallelDim() const { + if (stringifyThreadSize(ParallelType::TIDx).compare(name()) == 0) { + return c10::optional(ParallelType::TIDx); + } else if (stringifyThreadSize(ParallelType::TIDy).compare(name()) == 0) { + return c10::optional(ParallelType::TIDy); + } else if (stringifyThreadSize(ParallelType::TIDz).compare(name()) == 0) { + return c10::optional(ParallelType::TIDz); + } else if (stringifyThreadSize(ParallelType::BIDx).compare(name()) == 0) { + return c10::optional(ParallelType::BIDx); + } else if (stringifyThreadSize(ParallelType::BIDy).compare(name()) == 0) { + return c10::optional(ParallelType::BIDy); + } else if (stringifyThreadSize(ParallelType::BIDz).compare(name()) == 0) { + return c10::optional(ParallelType::BIDz); + } + return c10::nullopt; +} + +c10::optional NamedScalar::getParallelIndex() const { + if (stringifyThread(ParallelType::TIDx).compare(name()) == 0) { + return c10::optional(ParallelType::TIDx); + } else if (stringifyThread(ParallelType::TIDy).compare(name()) == 0) { + return c10::optional(ParallelType::TIDy); + } else if (stringifyThread(ParallelType::TIDz).compare(name()) == 0) { + return c10::optional(ParallelType::TIDz); + } else if (stringifyThread(ParallelType::BIDx).compare(name()) == 0) { + return c10::optional(ParallelType::BIDx); + } else if (stringifyThread(ParallelType::BIDy).compare(name()) == 0) { + return c10::optional(ParallelType::BIDy); + } else if (stringifyThread(ParallelType::BIDz).compare(name()) == 0) { + return c10::optional(ParallelType::BIDz); + } + return c10::nullopt; +} + } // namespace fuser } // namespace jit } // namespace torch diff --git 
a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp index 58c665676501a2..3843a967c45a51 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel.cpp @@ -1,13 +1,12 @@ #include -#include #include #include #include #include +#include #include #include -#include #include #include #include @@ -35,77 +34,6 @@ int ceilDiv(const int a, const int b) { return (a + b - 1) / b; } -struct KernelArgumentHolder { - private: - std::vector arguments; - std::vector void_ptrs; - bool changed = true; - - public: - virtual ~KernelArgumentHolder() { - for (auto arg : arguments) - delete arg; - } - - // Push a tensor to the arguments - void push( - const at::Tensor& val, - c10::optional broadcasted_size = c10::nullopt) { - changed = true; - ExtractSizeStride ess(val, std::move(broadcasted_size)); - int nDims = ess.sizes.size(); - - c10::ScalarType dtype = val.scalar_type(); - TensorArgAbstract* tensor_arg = getTensorArg(dtype, nDims); - tensor_arg->setPointer(val.data_ptr()); - for (int i = 0; i < nDims; i++) { - tensor_arg->setSize(i, ess.sizes[i]); - tensor_arg->setStride(i, ess.strides[i]); - } - arguments.push_back(tensor_arg); - } - - // Push a scalar or integer to the arguments - void push(const IValue& val) { - changed = true; - TORCH_INTERNAL_ASSERT( - val.isScalar(), - "Tried to push an arg to run in a fused kernel, expected a scalar but got, ", - val); - switch (val.toScalar().type()) { - case (c10::ScalarType::Double): - arguments.push_back(new FloatArg((float)val.toDouble())); - return; - case (c10::ScalarType::Long): - arguments.push_back(new IntArg((int)val.toInt())); - return; - default: - TORCH_INTERNAL_ASSERT( - false, - " Tried to create argument to send to a fused kernel, but got an unexpected type."); - } - TORCH_INTERNAL_ASSERT( - false, - " Tried to create argument to send to a fused kernel, but got a non-scalar type."); - } - - void push(const uint64_t& val) { - arguments.push_back(new ULongArg(val)); - } - - // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers - // in the buffer - void** getBuffer() { - if (changed) { - void_ptrs = std::vector(arguments.size(), nullptr); - for (decltype(arguments.size()) i{0}; i < arguments.size(); i++) - void_ptrs[i] = static_cast(arguments[i]->arg()); - changed = false; - } - return void_ptrs.data(); - } -}; - std::pair codeGeneration(Fusion* fusion) { std::stringstream str_stream; str_stream << "namespace " << kCgNamespace << " {\n" diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp index 04966fc5cafe4c..4a245313fb3f32 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp @@ -11,8 +11,7 @@ namespace fuser { namespace cuda { at::optional CudaKernelCache::getKernelPtr( - const at::ArrayRef inputs, - const std::vector& broadcasted_shape) { + const at::ArrayRef inputs) { for (auto& cuda_kernel : kernels_) { // bound input sizes Fusion* fusion = cuda_kernel.fusion(); @@ -20,7 +19,7 @@ at::optional CudaKernelCache::getKernelPtr( EvaluationContext eval_context(fusion); for (int i = 0; i < (int)inputs.size(); i++) { if (inputs[i].isTensor()) { - ExtractSizeStride ess(inputs[i].toTensor(), broadcasted_shape); + ExtractSizeStride ess(inputs[i].toTensor()); const int nDims = ess.sizes.size(); TensorView* tv = fusion->inputs()[i]->as(); for (int j = 0; j < nDims; j++) { diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.h 
b/torch/csrc/jit/codegen/cuda/kernel_cache.h index 7608094c9d97fb..31c5ef52847af6 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.h +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.h @@ -5,9 +5,6 @@ #include -/* - */ - namespace torch { namespace jit { namespace fuser { @@ -107,8 +104,7 @@ class CudaKernelCache { CudaKernelCache() = default; at::optional getKernelPtr( - const at::ArrayRef inputs, - const std::vector& broadcasted_shape); + const at::ArrayRef inputs); CudaKernel* allocateKernelInCache(const at::ArrayRef inputs); private: diff --git a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h index c6c009fd0bb3ee..5d2504b5bf9bad 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h +++ b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h @@ -478,10 +478,6 @@ __device__ void gridReduceLastBlock(T& out, const T *in, const size_t in_size, } } -__device__ unsigned atomic_inc(unsigned* sync_flag, unsigned max_val) { - return atomicInc(sync_flag, max_val - 1); -} - /** Reduces per-thread values across thread blocks. Function parameters: @@ -528,7 +524,7 @@ template __device__ void gridReduce(T& out, T inp_val, Func reduction_op, volatile T* work_buf, - unsigned* sync_flags, + Tensor sync_flags, T* shared_buf) { const auto seg_size = size_of_reduction_segment(gridDim); @@ -555,8 +551,8 @@ __device__ void gridReduce(T& out, T inp_val, Func reduction_op, __shared__ bool last_block; if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { __threadfence(); - auto old = atomic_inc(&sync_flags[seg_idx], seg_size); - last_block = old == seg_size - 1; + auto old = atomicAdd( (unsigned long long*) &sync_flags[seg_idx], 1); + last_block = old + 1 == seg_size; } __syncthreads(); diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index ea917f83a9e345..e690a6e5ab2fed 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -13,8 +13,62 @@ namespace torch { namespace jit { namespace fuser { -// Traverse through the fusion and print CUDA code associated with it -std::vector GPULower::getLoweredExprs() { +namespace { + +class GridReductionBuffers : OptOutDispatch { + public: + static std::vector getGlobalAllocs( + const std::vector& exprs) { + GridReductionBuffers fgr; + for (auto expr : exprs) { + fgr.handle(expr); + } + return fgr.global_allocations_; + } + + static std::vector getSyncAllocs(const std::vector& exprs) { + GridReductionBuffers fgr; + for (auto expr : exprs) { + fgr.handle(expr); + } + return fgr.sync_allocations_; + } + + private: + std::vector global_allocations_; + std::vector sync_allocations_; + + GridReductionBuffers() = default; + + void handle(Expr* expr) final { + OptOutDispatch::handle(expr); + } + + void handle(ForLoop* fl) final { + for (auto expr : fl->body().exprs()) { + OptOutDispatch::handle(expr); + } + } + + void handle(IfThenElse* ite) final { + for (auto expr : ite->body().exprs()) { + OptOutDispatch::handle(expr); + } + + for (auto expr : ite->elseBody().exprs()) { + OptOutDispatch::handle(expr); + } + } + + void handle(GridReduction* gr) final { + global_allocations_.push_back(gr->reduction_buffer()); + sync_allocations_.push_back(gr->sync_buffer()); + } +}; + +} // namespace + +void GPULower::lower() { FusionGuard fg(fusion_); // Validate and make some minor modifications in preparation to generate code. 
@@ -27,23 +81,47 @@ std::vector GPULower::getLoweredExprs() { fusion_, fusion_->exprs(true, false, true), preds); auto unrolled_loops = UnrollPass::runPass(fusion_, loop_nests, preds); - auto indexed_loops = IndexLowering::getIndexedExprs(fusion_, unrolled_loops); + lowered_exprs_ = indexed_loops; - return indexed_loops; + // Get allocations: + global_allocations_ = GridReductionBuffers::getGlobalAllocs(lowered_exprs_); + sync_allocations_ = GridReductionBuffers::getSyncAllocs(lowered_exprs_); +} + +// Traverse through the fusion and print CUDA code associated with it +std::vector GPULower::lowered_exprs() { + return lowered_exprs_; } std::ostream& GPULower::printKernel( std::ostream& os, const std::string& kernel_name) { FusionGuard fg(fusion_); - auto exprs = getLoweredExprs(); + std::vector allocs; + allocs.insert( + allocs.end(), global_allocations_.begin(), global_allocations_.end()); + allocs.insert( + allocs.end(), sync_allocations_.begin(), sync_allocations_.end()); + + std::vector global_tensors(allocs.size(), nullptr); + std::transform( + allocs.begin(), + allocs.end(), + global_tensors.begin(), + [](Allocate* alloc) { return alloc->buffer(); }); IRPrinter irp(os); - irp.printKernel(exprs, kernel_name); + irp.printKernel(lowered_exprs_, kernel_name, global_tensors); return os; } +std::string GPULower::getKernel(const std::string& kernel_name) { + std::stringstream ss; + printKernel(ss, kernel_name); + return ss.str(); +} + } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h index 57eadbfd57a878..0b8592a3033952 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.h +++ b/torch/csrc/jit/codegen/cuda/lower2device.h @@ -13,16 +13,43 @@ namespace fuser { class TORCH_CUDA_API GPULower { public: // Init printer on ostream - GPULower(Fusion* _fusion) : fusion_(_fusion) {} + explicit GPULower(Fusion* _fusion) : fusion_(_fusion) { + lower(); + } + + GPULower() = default; + GPULower(const GPULower& lower) = default; + GPULower& operator=(const GPULower& other) = default; // print generated code to ostream - std::vector getLoweredExprs(); + std::vector lowered_exprs(); + std::ostream& printKernel( std::ostream& _os, const std::string& kernel_name = "CUDAGeneratedKernel"); + std::string getKernel(const std::string& kernel_name = "CUDAGeneratedKernel"); + + std::vector global_allocations() { + return global_allocations_; + } + + std::vector sync_allocations() { + return sync_allocations_; + } + private: - Fusion* const fusion_ = nullptr; + void lower(); + + // List of global buffers (not including buffers for grid syncronization) + std::vector global_allocations_; + + // List of syncronization buffers that must be initialized to 0 when running + // the fusion + std::vector sync_allocations_; + + std::vector lowered_exprs_; + Fusion* fusion_ = nullptr; }; } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index d506d64054df19..67f61f21215a11 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -1,4 +1,6 @@ +#include #include +#include #include #include @@ -11,103 +13,75 @@ void IndexLowering::pushBack(Expr* expr) { if (active_scope == nullptr) lowered_exprs.push_back(expr); else - scope_utils::pushBack(active_scope, expr); + active_scope->push_back(expr); } -Statement* IndexLowering::mutate(Expr* expr) { - Statement* mutated_stmt = 
OptOutMutator::mutate(expr); - TORCH_INTERNAL_ASSERT( - mutated_stmt->isExpr(), - "Tried to generate a kernel but hit a non expression during lowering: ", - mutated_stmt); - return mutated_stmt; -} +void IndexLowering::handle(IfThenElse* ite) { + Expr* prev_scope_expr = active_scope_expr; + Scope* prev_scope = active_scope; + + auto new_ite = new IfThenElse(ite->cond(), {}, {}, prev_scope_expr); + pushBack(new_ite); + active_scope_expr = new_ite; + active_scope = &new_ite->body(); -Statement* IndexLowering::mutate(IfThenElse* ite) { - Expr* prev_scope = active_scope; - active_scope = ite; - std::vector mutated_exprs; - bool is_mutated = false; for (auto expr : ite->body().exprs()) { - Statement* mutated_stmt = mutate(expr); - Expr* mutated_expr = ir_utils::asExpr(mutated_stmt); - mutated_exprs.push_back(mutated_expr); - is_mutated = is_mutated | (mutated_expr != expr); + OptInDispatch::handle(expr); } - std::vector mutated_else_exprs; - for (auto expr : ite->elseBody().exprs()) { - Statement* mutated_stmt = mutate(expr); - Expr* mutated_expr = ir_utils::asExpr(mutated_stmt); - mutated_else_exprs.push_back(mutated_expr); - is_mutated = is_mutated | (mutated_expr != expr); - } + active_scope = &new_ite->elseBody(); - if (is_mutated) { - ite->body().clear(); - for (auto expr : mutated_exprs) - ite->body().push_back(expr); - ite->elseBody().clear(); - for (auto expr : mutated_else_exprs) - ite->elseBody().push_back(expr); + for (auto expr : ite->elseBody().exprs()) { + OptInDispatch::handle(expr); } active_scope = prev_scope; + active_scope_expr = prev_scope_expr; +} - if (is_mutated) { - auto new_ite = new IfThenElse( - ite->cond(), mutated_exprs, mutated_else_exprs, ite->parentScope()); - return new_ite; - } +void IndexLowering::handle(ForLoop* fl) { + Expr* prev_scope_expr = active_scope_expr; + Scope* prev_scope = active_scope; - return ite; -} + auto newFl = new ForLoop(fl->index(), fl->iter_domain(), {}, prev_scope_expr); + pushBack(newFl); + + active_scope_expr = newFl; + active_scope = &newFl->body(); -Statement* IndexLowering::mutate(ForLoop* fl) { - Expr* prev_scope = active_scope; - active_scope = fl; - std::vector mutated_exprs; - bool is_mutated = false; for (auto expr : fl->body().exprs()) { - Statement* mutated_stmt = mutate(expr); - Expr* mutated_expr = ir_utils::asExpr(mutated_stmt); - mutated_exprs.push_back(mutated_expr); - is_mutated = is_mutated | (mutated_expr != expr); + OptInDispatch::handle(expr); } active_scope = prev_scope; - if (is_mutated) { - auto newFL = new ForLoop( - fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope()); - return newFL; - } - - return fl; + active_scope_expr = prev_scope_expr; } -Statement* IndexLowering::mutate(UnaryOp* uop) { - if (!ir_utils::isTVOp(uop)) - return OptOutMutator::mutate(uop); +void IndexLowering::handle(UnaryOp* uop) { + if (!ir_utils::isTVOp(uop)) { + pushBack(uop); + return; + } TensorIndex* out = Index::getConsumerIndex( - ir_utils::asTV(uop->out()), scope_utils::getLoops(active_scope)); + ir_utils::asTV(uop->out()), scope_utils::getLoops(active_scope_expr)); Val* in = uop->in(); if (ir_utils::isTV(in)) in = Index::getProducerIndex( ir_utils::asTV(in), ir_utils::asTV(uop->out()), - scope_utils::getLoops(active_scope)); - Expr* new_op = new UnaryOp(uop->getUnaryOpType(), out, in); - - return new_op; + scope_utils::getLoops(active_scope_expr)); + pushBack(new UnaryOp(uop->getUnaryOpType(), out, in)); } -Statement* IndexLowering::mutate(BinaryOp* bop) { - if (!ir_utils::isTVOp(bop)) - return 
OptOutMutator::mutate(bop); +void IndexLowering::handle(BinaryOp* bop) { + if (!ir_utils::isTVOp(bop)) { + pushBack(bop); + return; + } TensorIndex* out = Index::getConsumerIndex( - ir_utils::asTV(bop->out()), scope_utils::getLoops(active_scope)); + ir_utils::asTV(bop->out()), scope_utils::getLoops(active_scope_expr)); Val* lhs = bop->lhs(); Val* rhs = bop->rhs(); @@ -116,25 +90,25 @@ Statement* IndexLowering::mutate(BinaryOp* bop) { lhs = Index::getProducerIndex( ir_utils::asTV(lhs), ir_utils::asTV(bop->out()), - scope_utils::getLoops(active_scope)); + scope_utils::getLoops(active_scope_expr)); if (ir_utils::isTV(rhs)) rhs = Index::getProducerIndex( ir_utils::asTV(rhs), ir_utils::asTV(bop->out()), - scope_utils::getLoops(active_scope)); - - Expr* new_op = new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs); + scope_utils::getLoops(active_scope_expr)); - return new_op; + pushBack(new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs)); } -Statement* IndexLowering::mutate(TernaryOp* top) { - if (!ir_utils::isTVOp(top)) - return OptOutMutator::mutate(top); +void IndexLowering::handle(TernaryOp* top) { + if (!ir_utils::isTVOp(top)) { + pushBack(top); + return; + } TensorIndex* out = Index::getConsumerIndex( - ir_utils::asTV(top->out()), scope_utils::getLoops(active_scope)); + ir_utils::asTV(top->out()), scope_utils::getLoops(active_scope_expr)); Val* in1 = top->in1(); Val* in2 = top->in2(); Val* in3 = top->in3(); @@ -143,80 +117,141 @@ Statement* IndexLowering::mutate(TernaryOp* top) { in1 = Index::getProducerIndex( ir_utils::asTV(in1), ir_utils::asTV(top->out()), - scope_utils::getLoops(active_scope)); + scope_utils::getLoops(active_scope_expr)); if (ir_utils::isTV(in2)) in2 = Index::getProducerIndex( ir_utils::asTV(in2), ir_utils::asTV(top->out()), - scope_utils::getLoops(active_scope)); + scope_utils::getLoops(active_scope_expr)); if (ir_utils::isTV(in3)) in3 = Index::getProducerIndex( ir_utils::asTV(in3), ir_utils::asTV(top->out()), - scope_utils::getLoops(active_scope)); - - Expr* new_op = new TernaryOp(top->getTernaryOpType(), out, in1, in2, in3); + scope_utils::getLoops(active_scope_expr)); - return new_op; + pushBack(new TernaryOp(top->getTernaryOpType(), out, in1, in2, in3)); } -Statement* IndexLowering::mutate(ReductionOp* rop) { +void IndexLowering::handle(ReductionOp* rop) { TORCH_INTERNAL_ASSERT( ir_utils::isTVOp(rop), - "Cannot have a reduction operation on something other than a tensor view."); - auto loops = scope_utils::getLoops(active_scope); + "Cannot have a reduction operation on something other than a tensor view, but received ", + rop); - bool is_private_reduce = - std::none_of(loops.begin(), loops.end(), [](ForLoop* fl) { - return fl->iter_domain()->isThread() && - fl->iter_domain()->isReduction(); - }); + auto out_tv = ir_utils::asTV(rop->out()); - TensorIndex* out = Index::getConsumerIndex(ir_utils::asTV(rop->out()), loops); + bool is_block_reduce = out_tv->hasBlockReduction(); - Val* in = rop->in(); - if (ir_utils::isTV(in)) - in = Index::getProducerIndex( - ir_utils::asTV(in), - ir_utils::asTV(rop->out()), - scope_utils::getLoops(active_scope)); + bool is_grid_reduce = out_tv->hasGridReduction(); + + // If we do a grid reduction we can't have a reduction axis that is not bound + // to a grid or block dim () + if (is_grid_reduce) { + TORCH_INTERNAL_ASSERT( + std::none_of( + out_tv->domain()->domain().begin(), + out_tv->domain()->domain().end(), + [](IterDomain* id) { + return !id->isThread() && id->isReduction(); + }), + "Found a reduction stage that has both a 
non-parallelized reduction and a grid reduction.", + " This is not supported, please use rfactor to do the serialized reduction first, then the grid reduction."); + } + auto loops = scope_utils::getLoops(active_scope_expr); - if (!is_private_reduce) - return new ReductionOp(rop->getReductionOpType(), rop->init(), out, in); + TensorIndex* out = Index::getConsumerIndex(out_tv, loops); + Val* in = rop->in(); + in = Index::getProducerIndex( + ir_utils::asTV(in), ir_utils::asTV(rop->out()), loops); + + ReductionOp* block_reduction = nullptr; + if (is_block_reduce) { + block_reduction = + new ReductionOp(rop->getReductionOpType(), rop->init(), out, in); + pushBack(block_reduction); + } - Expr* new_op = new BinaryOp(rop->getReductionOpType(), out, out, in); + if (is_grid_reduce) { + std::vector buffer_ids(out_tv->domain()->domain()); + buffer_ids.erase( + std::remove_if( + buffer_ids.begin(), + buffer_ids.end(), + [](IterDomain* id) { + return id->isReduction() & !id->isBlockDim(); + }), + buffer_ids.end()); + + Val* buffer_size = + buffer_ids.empty() ? new Int(1) : buffer_ids[0]->rawExtent(); + for (size_t i = 1; i < buffer_ids.size(); i++) { + buffer_size = mul(buffer_size, buffer_ids[i]->rawExtent()); + } + + std::vector sync_ids(out_tv->domain()->domain()); + sync_ids.erase( + std::remove_if( + sync_ids.begin(), + sync_ids.end(), + [](IterDomain* id) { + return id->isReduction() || !id->isBlockDim(); + }), + sync_ids.end()); + + Val* sync_size = sync_ids.empty() ? new Int(1) : sync_ids[0]->rawExtent(); + for (size_t i = 1; i < sync_ids.size(); i++) { + sync_size = mul(sync_size, sync_ids[i]->rawExtent()); + } + + IterDomain* buffer_id = new IterDomain(new Int(0), buffer_size); + TensorView* reduce_buffer_tv = new TensorView( + new TensorDomain({buffer_id}), out->getDataType().value()); + + IterDomain* sync_id = new IterDomain(new Int(0), sync_size); + TensorView* reduce_sync_tv = + new TensorView(new TensorDomain({sync_id}), DataType::Int); + + auto reduce_buffer = new Allocate(reduce_buffer_tv, MemoryType::Global); + auto sync_buffer = new Allocate(reduce_sync_tv, MemoryType::Global); + + pushBack(reduce_buffer); + pushBack(sync_buffer); + pushBack(new GridReduction( + block_reduction == nullptr + ? 
new ReductionOp(rop->getReductionOpType(), rop->init(), out, in) + : block_reduction, + reduce_buffer, + sync_buffer)); + } - return new_op; + if (!is_block_reduce && !is_grid_reduce) { + pushBack(new BinaryOp(rop->getReductionOpType(), out, out, in)); + } } -Statement* IndexLowering::mutate(BroadcastOp* bop) { - if (!ir_utils::isTVOp(bop)) - return OptOutMutator::mutate(bop); +void IndexLowering::handle(BroadcastOp* bop) { + TORCH_INTERNAL_ASSERT( + ir_utils::isTVOp(bop), + "Cannot have a broadcast operation on something other than a tensor view, but received ", + bop); TensorIndex* out = Index::getConsumerIndex( - ir_utils::asTV(bop->out()), scope_utils::getLoops(active_scope)); + ir_utils::asTV(bop->out()), scope_utils::getLoops(active_scope_expr)); Val* in = bop->in(); if (ir_utils::isTV(in)) in = Index::getProducerIndex( ir_utils::asTV(in), ir_utils::asTV(bop->out()), - scope_utils::getLoops(active_scope)); - Expr* new_op = new BroadcastOp(out, in); - - return new_op; + scope_utils::getLoops(active_scope_expr)); + pushBack(new BroadcastOp(out, in)); } void IndexLowering::generate(const std::vector& exprs) { // Run through loop nests and further lower the expressions for (auto* expr : exprs) { - Statement* mutated_stmt = mutate(expr); - TORCH_INTERNAL_ASSERT( - mutated_stmt->isExpr(), - "Tried to generate a kernel but hit a non expression during lowering: ", - mutated_stmt); - lowered_exprs.push_back(static_cast(mutated_stmt)); + OptInDispatch::handle(expr); } } diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h index 564cc49681b8ed..a320fffdd82706 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_index.h @@ -10,41 +10,50 @@ namespace torch { namespace jit { namespace fuser { -class TORCH_CUDA_API IndexLowering : public OptOutMutator { - private: - std::vector lowered_exprs; - Expr* active_scope = nullptr; +class TORCH_CUDA_API IndexLowering : public OptInDispatch { + public: + static std::vector getIndexedExprs( + Fusion* fusion, + std::vector incoming_exprs) { + FusionGuard fg(fusion); + IndexLowering il; + il.generate(incoming_exprs); + return il.lowered_exprs; + } - // Wrap pushBack in lower_utils if active_scope is null we want it to go + private: + // Wrap pushBack, if active_scope is null we want it to go // straight to lower_exprs void pushBack(Expr*); - // Custom dispatch for Expr, want to find out of it's a TV op - Statement* mutate(Expr*) final; - // Open the for loop. - Statement* mutate(ForLoop*) final; + void handle(ForLoop*) final; // Open the for loop. 
- Statement* mutate(IfThenElse*) final; + void handle(IfThenElse*) final; // Remake operations with TensorIndex - Statement* mutate(UnaryOp*) final; - Statement* mutate(BinaryOp*) final; - Statement* mutate(TernaryOp*) final; - Statement* mutate(ReductionOp*) final; - Statement* mutate(BroadcastOp*) final; + void handle(UnaryOp*) final; + void handle(BinaryOp*) final; + void handle(TernaryOp*) final; + void handle(ReductionOp*) final; + void handle(BroadcastOp*) final; + void handle(Allocate* expr) final { + pushBack(expr); + } + void generate(const std::vector& exprs); - public: - static std::vector getIndexedExprs( - Fusion* fusion, - std::vector incoming_exprs) { - FusionGuard fg(fusion); - IndexLowering il; - il.generate(incoming_exprs); - return il.lowered_exprs; - } + std::vector lowered_exprs; + + // This is a slight work around as scope has a couple definitions, we have the + // Scope that's in ForLoop/IfThenElse which is really just a wrapper around + // std::vector and then we have the actual ForLoop/IfThenElse. We want + // to be able to carry both around because when we push back to a scope it + // could be either the body or else body of the IfThenElse. However, we want + // to understand the nesting of IfThenElse/ForLoop nodes. + Scope* active_scope = nullptr; + Expr* active_scope_expr = nullptr; }; } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index 93f2613cff507f..fa9af334b3e313 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -67,7 +67,7 @@ Expr* LoopNestGenerator::pushAlloc(TensorView* tv) { } // Create the allocation node - Allocate* alloc = new Allocate(tv, size); + Allocate* alloc = new Allocate(tv, MemoryType::Local, size); // Place the allocation if (alloc_pos == 0) { @@ -258,8 +258,16 @@ void LoopNestGenerator::initReduction( */ void LoopNestGenerator::handle(Expr* expr) { if (!ir_utils::isTVOp(expr)) { - for (auto out : expr->outputs()) - pushBack(new Allocate(out, new Int(1))); + for (auto out : expr->outputs()) { + TORCH_INTERNAL_ASSERT( + out->getValType().value() == ValType::Scalar, + "Unrecognized output type found in expr ", + expr, + " cannot lower ", + out->getValType().value()); + + pushBack(new Allocate(out, MemoryType::Local, new Int(1))); + } pushBack(expr); return; } diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp index 913beb5dd59cbb..ea7292f5dcbd49 100644 --- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp @@ -11,7 +11,7 @@ namespace fuser { namespace { Val* threadPredicate(ParallelType pt) { - return eq(new NamedScalar(stringifyThread(pt), DataType::Int), new Int(0)); + return eq(NamedScalar::getParallelIndex(pt), new Int(0)); } Bool* getThreadPredicate(const ir_utils::ParallelTypeBitmap& bits) { diff --git a/torch/csrc/jit/codegen/cuda/manager.cpp b/torch/csrc/jit/codegen/cuda/manager.cpp index 64235f34cd41c0..746b3b409d1534 100644 --- a/torch/csrc/jit/codegen/cuda/manager.cpp +++ b/torch/csrc/jit/codegen/cuda/manager.cpp @@ -8,8 +8,10 @@ #include #include #include +#include #include +#include #include namespace torch { @@ -70,17 +72,15 @@ class CudaFusionManager { int32_t kernel_id, std::shared_ptr& graph, const at::ArrayRef inputs, - const std::vector& outputs, - const std::vector& broadcasted_shape) { + const std::vector& outputs) { std::lock_guard guard(mutex_); 
TORCH_CHECK( kernel_cache_.count(kernel_id) != 0, "kernel id not recognized"); - if (auto cuda_kernel_opt = - kernel_cache_[kernel_id].getKernelPtr(inputs, broadcasted_shape)) { + if (auto cuda_kernel_opt = kernel_cache_[kernel_id].getKernelPtr(inputs)) { // TODO: update launch config for specific sizes; // maybe we should store it in CudaKernel and compute it later - runKernel(*cuda_kernel_opt, inputs, outputs, broadcasted_shape); + runKernel(*cuda_kernel_opt, inputs, outputs); } else { // TODO: this should somehow be done after kernel compilation. // we will want compileKernel to return a heuristic @@ -113,7 +113,7 @@ class CudaFusionManager { // NVRTC compile kernel compileKernel(cuda_kernel); - runKernel(cuda_kernel, inputs, outputs, broadcasted_shape); + runKernel(cuda_kernel, inputs, outputs); } } @@ -168,24 +168,30 @@ void runCudaFusionGroup(const Node* fusion_node, Stack& stack) { const auto nInputs = graph->inputs().size(); at::ArrayRef inputs = last(stack, nInputs); - // shape inference in graph + // Only needed if we are doing codegen + // if no shape information available, we feed current shape into the kernel; + if (!IsNewExecutorEnabled()) { + EraseShapeInformation(graph); + for (size_t i = 0; i < nInputs; i++) { + graph->inputs()[i]->setType(inputs[i].type()); + } + ShapeTypePropagate(graph); + } + /* + // TODO: Delete the shape inference here once we switch to + // ExpressionEvaluator to allocate outputs + // shape inference in graph to allocate outputs // update shape information per the new inputs; - EraseShapeInformation(graph); + EraseShapeInformation(shape_inf_graph); for (size_t i = 0; i < nInputs; i++) { - graph->inputs()[i]->setType(inputs[i].type()); + shape_inf_graph->inputs()[i]->setType(inputs[i].type()); } // shape inference - ShapeTypePropagate(graph); - - // TODO: temporary WAR that allows us to handle fusion with uniform output - // shape and consistent broadcast scheme. The difinition is loose and the - // implementation is risky. We'll do this properly when we integrate proper - // broadcast support. - std::vector broadcasted_shape; + ShapeTypePropagate(shape_inf_graph); // we need to construct outputs; std::vector outputs; - for (const auto* output : graph->outputs()) { + for (const auto* output : shape_inf_graph->outputs()) { const auto type = output->type()->expect(); // Expect output to be tensor; TORCH_CHECK( @@ -207,32 +213,16 @@ void runCudaFusionGroup(const Node* fusion_node, Stack& stack) { const auto tensor = at::empty_strided(sizes, strides, options); outputs.push_back(tensor); - - // TODO: unsafe broadcast assumption. We assume all output from fusion has - // identical size when broadcasting. - if (broadcasted_shape.empty()) { - if (!hasReductionNode(graph->block())) { - broadcasted_shape = sizes; - } else if (isReductionNode(output->node())) { - auto i_type = - output->node()->inputs()[0]->type()->expect(); - TORCH_CHECK( - i_type && i_type->sizes().isComplete(), - "Complete TensorType for output is expected."); - broadcasted_shape = extractSizes(i_type); - } else { - // TODO: this assert is not fool proof. We could have ignored - // pre-reduction tensor marked as output after we first encountered - // reduction output tensor. 
- TORCH_INTERNAL_ASSERT( - false, - "pre-reduction tensor output for reduction fusion is nor properly supported yet."); - } - } } - CudaFusionManager::getManager().runFusionNode( - kernel_id, graph, inputs, outputs, broadcasted_shape); + kernel_id, graph, inputs, outputs); + */ + FusionExecutor executor; + auto fusion = parseJitIR(graph); + scheduleFusion(fusion.get(), inputs); + executor.compileFusion(fusion.get()); + auto outputs = executor.runFusion(inputs); + drop(stack, inputs.size()); stack.insert( stack.end(), diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index f6cf432d01139f..59bb32c594fc05 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -126,23 +126,33 @@ Statement* OptOutMutator::mutate(NamedScalar* ns) { // MUTATE FUNCTIONS FOR EXPRESSIONS. Statement* OptOutMutator::mutate(Allocate* a) { - TensorView* tv = static_cast(mutateAsVal(a->buffer())); - Val* ext = mutateAsVal(a->extent())->asVal(); - if (ext->sameAs(a->extent()) && tv->sameAs(a->buffer())) - return a; - FusionGuard::getCurFusion()->removeExpr(a); - return new Allocate(tv, ext); + if (a->buffer()->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(mutateAsVal(a->buffer())); + Val* ext = mutateAsVal(a->size())->asVal(); + if (ext->sameAs(a->size()) && tv->sameAs(a->buffer())) + return a; + FusionGuard::getCurFusion()->removeExpr(a); + return new Allocate(tv, a->getMemoryType(), a->size()); + } else { + Val* buffer = mutateAsVal(a->buffer())->asVal(); + Val* ext = mutateAsVal(a->size())->asVal(); + if (ext->sameAs(a->size()) && buffer->sameAs(a->buffer())) + return a; + FusionGuard::getCurFusion()->removeExpr(a); + return new Allocate(buffer, a->getMemoryType(), a->size()); + } } Statement* OptOutMutator::mutate(Split* s) { - IterDomain* ot = static_cast(mutateAsVal(s->outer())); - IterDomain* inr = static_cast(mutateAsVal(s->inner())); - IterDomain* in = static_cast(mutateAsVal(s->in())); - Int* fact = static_cast(mutateAsVal(s->factor())); + IterDomain* ot = mutateAsVal(s->outer())->as(); + IterDomain* inr = mutateAsVal(s->inner())->as(); + IterDomain* in = mutateAsVal(s->in())->as(); + Val* fact = mutateAsVal(s->factor())->as(); if (ot->sameAs(s->outer()) && inr->sameAs(s->inner()) && - in->sameAs(s->in()) && fact->sameAs(s->factor())) + in->sameAs(s->in()) && areEqualScalars(fact, s->factor())) { return s; + } FusionGuard::getCurFusion()->removeExpr(s); return new Split(ot, inr, in, fact); } @@ -202,6 +212,19 @@ Statement* OptOutMutator::mutate(ReductionOp* rop) { return new ReductionOp(rop->getReductionOpType(), init, out, in); } +Statement* OptOutMutator::mutate(GridReduction* gr) { + ReductionOp* reduction_op = mutate(gr->reduction_op())->as(); + Allocate* reduction_buffer = mutate(gr->reduction_buffer())->as(); + Allocate* sync_buffer = mutate(gr->sync_buffer())->as(); + + if (reduction_op->sameAs(gr->reduction_op()) && + reduction_buffer->sameAs(gr->reduction_buffer()) && + sync_buffer->sameAs(gr->sync_buffer())) + return gr; + + return new GridReduction(reduction_op, reduction_buffer, sync_buffer); +} + Statement* OptOutMutator::mutate(BroadcastOp* bop) { Val* out = mutateAsVal(bop->out())->asVal(); Val* in = mutateAsVal(bop->in())->asVal(); diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp index 12e4d674644c60..5b94feb0484a9b 100644 --- a/torch/csrc/jit/codegen/cuda/parser.cpp +++ b/torch/csrc/jit/codegen/cuda/parser.cpp @@ -100,45 +100,11 
@@ class IrParser { FusionGuard fg(fusion.get()); auto block = graph_->block(); - // [ Note - broadcast support in integration ] - // - // in case of broadcast, we don't support explicit broadcast, - // 1. for point-wise fusion, so we need to convert/expand all inputs - // tensors to comply to the broadcasted size. This supports very limited - // case, which we try to accomodate in graph partition, that we only merge - // nodes with identical output shapes. - // 2. in case of reduction-at-end fusion, right now we only support single - // reduction operation in fusion, hence we can use the same logig for PW - // fusion and conver/expand all inputs to the input tensor to reduction op. - - // TODO: proper broadcast support in integration - int broadcast_dim = -1; - // broadcast support hack is disabled to reduction. - if (hasReductionNode(graph_->block())) { - // reduction-at-end fusion, broadcast all inputs to tensor before - // reduction - // TODO: Not perfectly safe! We could have intermediate output that is not - // part of outputs of reduction operations. But we have similar limitation - // for broadcast support in PW fusion. We should properly fix this after - // broadcast integration. - broadcast_dim = block->outputs()[0] - ->node() - ->inputs()[0] - ->type() - ->cast() - ->dim() - .value(); - } else { - // point-wise fusion, broadcast all inputs to output size. - broadcast_dim = - block->outputs()[0]->type()->cast()->dim().value(); - } - // register all inputs; - // shape propagation during parsing is effctively done in parsing rules, as - // we only explicitly register inputs in the graph. for (auto val : block->inputs()) { - TORCH_INTERNAL_ASSERT(registerValue(val, broadcast_dim)); + TORCH_INTERNAL_ASSERT( + registerValue(val), + "Error trying to register value with code generation."); fusion->addInput(value_map_[val->unique()]); auto opt_dtype = value_map_[val->unique()]->getDataType(); @@ -180,7 +146,6 @@ class IrParser { } fusion->addOutput(out); } - return fusion; } @@ -593,8 +558,8 @@ class IrParser { } } - bool registerValue(const JitValue* val, int broadcast_dim = -1) { - return registerTensor(val, broadcast_dim) || registerScalar(val); + bool registerValue(const JitValue* val) { + return registerTensor(val) || registerScalar(val); } bool registerScalar(const JitValue* val) { @@ -640,17 +605,9 @@ class IrParser { return false; } - bool registerTensor(const JitValue* val, int broadcast_dim = -1) { + bool registerTensor(const JitValue* val) { CgValue cg_val; if (auto tensor_type = val->type()->cast()) { - // TODO: make this a static function in Tensor class; - // create tensor; - if (broadcast_dim >= 0) { - TORCH_INTERNAL_ASSERT( - broadcast_dim >= (int)*tensor_type->dim(), - "attempt to broadcast a tensor to shrinked dimension is invalid"); - tensor_type = tensor_type->withDim(broadcast_dim); - } // TODO: make this a static function in Tensor class; // create tensor; cg_val = new TensorView(tensor_type); diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index f8da6f64ae5326..1986b4941562eb 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -31,12 +31,26 @@ TensorView::TensorView(const std::shared_ptr& tensor_type) aten_opt_type_map(tensor_type->scalarType()), false) { std::vector sizes; + TORCH_CHECK( tensor_type->dim().has_value(), "Requires static rank for Tensor"); + for (decltype(tensor_type->dim().value()) i = 0; i < tensor_type->dim().value(); i++) { - 
sizes.push_back(new IterDomain(new Int(0), new Int())); + if (tensor_type->sizes()[i].has_value() && + tensor_type->sizes()[i].value() == 1) { + // If size is known to be 1, assuem it needs to be broadcasted. + sizes.push_back(new IterDomain( + new Int(0), + new Int(1), + ParallelType::Serial, + false, + false, + BroadcastType::WithStride)); + } else { + sizes.push_back(new IterDomain(new Int(0), new Int())); + } } domain_ = new TensorDomain(sizes); @@ -214,8 +228,11 @@ TensorView* TensorView::computeAt(TensorView* consumer, int axis) { return this; } -TensorView* TensorView::split(int axis, unsigned int factor) { +TensorView* TensorView::split(int axis, Val* factor) { + // Only check things associated with axis, factor will be validated in + // IterDomain TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim TensorView"); + if (axis < 0) axis += domain()->nDims(); @@ -232,6 +249,11 @@ TensorView* TensorView::split(int axis, unsigned int factor) { return this; } +TensorView* TensorView::split(int axis, unsigned int factor) { + domain()->split(axis, new Int(factor)); + return this; +} + // Merge "axis" and "axis+1" into 1 dimension TensorView* TensorView::merge(int axis_o, int axis_i) { TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim TensorView"); @@ -343,7 +365,7 @@ TensorView* TensorView::cache_before() { root->parallel_method(), false, false, - true)); + root->getBroadcastType())); } else if (!root->isBroadcast() && !root->isReduction()) { new_root_domain.push_back(new IterDomain( root->start(), root->extent(), root->parallel_method())); @@ -421,7 +443,7 @@ TensorView* TensorView::cache_after() { root->parallel_method(), false, false, - true)); + root->getBroadcastType())); } else if (!root->isBroadcast() && !root->isReduction()) { new_root_domain.push_back(new IterDomain( root->start(), root->extent(), root->parallel_method())); diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp index 1c6996ba93b8b8..f76e88a48e64f9 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp @@ -35,16 +35,13 @@ void ReplayTransformations::handle(Split* s) { } auto mapped = (*it).second; - TORCH_INTERNAL_ASSERT( - s->factor()->isConst(), - "Transform traversal does not support splitting on non-const values."); // Make sure this ID is a leaf ID (meaning it has no uses we generated) TORCH_INTERNAL_ASSERT( leaf_ids_.find(mapped) != leaf_ids_.end(), "Transform traversal failed, modified a node but it was not a leaf node."); // Replay the split onto mapped - auto outs = IterDomain::split(mapped, s->factor()->value().value()); + auto outs = IterDomain::split(mapped, s->factor()); // Remove mapped from the leaf IDs leaf_ids_.erase(mapped); diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index ef43c2ca43dba8..f15bcd44b5e821 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -48,7 +48,7 @@ class ReplaySelf : public ReplayTransformations { s->outer()->parallel_method(), s->outer()->isReduction(), s->outer()->isRFactorProduct(), - s->outer()->isBroadcast()); + s->outer()->getBroadcastType()); // inner IterDomain IterDomain* idi = new IterDomain( @@ -57,7 +57,7 @@ class ReplaySelf : public ReplayTransformations { s->inner()->parallel_method(), s->inner()->isReduction(), s->inner()->isRFactorProduct(), - s->inner()->isBroadcast()); + 
s->inner()->getBroadcastType()); // Generate the split node new Split(ido, idi, mapped, s->factor()); @@ -106,7 +106,7 @@ class ReplaySelf : public ReplayTransformations { m->out()->parallel_method(), m->out()->isReduction(), m->out()->isRFactorProduct(), - m->out()->isBroadcast()); + m->out()->getBroadcastType()); new Merge(merged_id, id_outer_mapped, id_inner_mapped); diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp index ac10289b9da6b5..8e96ee315c6094 100644 --- a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp @@ -25,9 +25,6 @@ class ReplayRFactor : public ReplayTransformations { "Transform traversal failed, dependencies not met."); // Grab the ID we're going to replay on auto mapped = (*it).second; - TORCH_INTERNAL_ASSERT( - s->factor()->isConst(), - "Transform traversal does not support splitting on non-const values."); // This ID should be a leaf ID (meaning it has no uses we generated) TORCH_INTERNAL_ASSERT( leaf_ids_.find(mapped) != leaf_ids_.end(), @@ -59,7 +56,7 @@ class ReplayRFactor : public ReplayTransformations { mapped->parallel_method(), rfactor_outer, true, - mapped->isBroadcast()); + mapped->getBroadcastType()); // inner IterDomain IterDomain* idi = new IterDomain( @@ -68,7 +65,7 @@ class ReplayRFactor : public ReplayTransformations { mapped->parallel_method(), rfactor_inner, true, - mapped->isBroadcast()); + mapped->getBroadcastType()); // Generate the split node new Split(ido, idi, mapped, s->factor()); @@ -117,13 +114,24 @@ class ReplayRFactor : public ReplayTransformations { Val* merged_id_size = mul(id_outer_mapped->extent(), id_inner_mapped->extent()); + + BroadcastType bcast_type = BroadcastType::Null; + if (id_outer_mapped->isBroadcast() && id_inner_mapped->isBroadcast()) { + if (id_outer_mapped->getBroadcastType() == BroadcastType::WithStride || + id_inner_mapped->getBroadcastType() == BroadcastType::WithStride) { + bcast_type = BroadcastType::WithStride; + } else { + bcast_type = BroadcastType::WithoutStride; + } + } + IterDomain* merged_id = new IterDomain( new Int(0), static_cast(merged_id_size), id_outer_mapped->parallel_method(), rfactor_output, true, - id_outer_mapped->isBroadcast() && id_inner_mapped->isBroadcast()); + bcast_type); new Merge(merged_id, id_outer_mapped, id_inner_mapped); @@ -249,7 +257,7 @@ TensorDomain* TransformRFactor::runReplay( id->parallel_method(), true, true, - false); + BroadcastType::Null); // If this is not an rfactor root, but a reduction root, it should be // turned into an iteration domain } else if (id->isReduction()) { @@ -259,7 +267,7 @@ TensorDomain* TransformRFactor::runReplay( id->parallel_method(), false, false, - false); + BroadcastType::Null); } else { new_root[i] = id->clone(); } diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 0ece4d1ae40343..0d5d3f5af9e3d9 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -40,7 +40,7 @@ static const char* data_type2string(DataType t) { case DataType::Half: return "__half"; case DataType::Int: - return "size_t"; + return "int64_t"; case DataType::Null: return "nullptr"; default: @@ -81,6 +81,8 @@ static const char* expr_type2string(ExprType t) { return "TernaryOp"; case ExprType::ReductionOp: return "ReductionOp"; + case ExprType::GridReduction: + return "GridReduction"; case ExprType::BroadcastOp: return "BroadcastOp"; case ExprType::ForLoop: @@ -331,21 +333,18 @@ 
static const char* memory_type2string(MemoryType t) { return nullptr; } -static DataType at_type2data_type(at::ScalarType t) { +static const char* broadcast_type2string(BroadcastType t) { switch (t) { - case at::ScalarType::Bool: - return DataType::Bool; - case at::ScalarType::Float: - return DataType::Float; - case at::ScalarType::Half: - return DataType::Half; - case at::ScalarType::Int: - return DataType::Int; + case BroadcastType::Null: + return ""; + case BroadcastType::WithStride: + return "sb"; + case BroadcastType::WithoutStride: + return "b"; default: - break; + TORCH_INTERNAL_ASSERT(false, "No string found for Broadcast type."); + return nullptr; } - TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type."); - return DataType::Null; } static const char* thread_size2string(ParallelType t) { @@ -402,7 +401,36 @@ bool is_logical_op(const BinaryOpType& bot) { } DataType aten_to_data_type(const at::ScalarType& scalar_type) { - return at_type2data_type(scalar_type); + switch (scalar_type) { + case at::ScalarType::Bool: + return DataType::Bool; + case at::ScalarType::Float: + return DataType::Float; + case at::ScalarType::Half: + return DataType::Half; + case at::ScalarType::Long: + return DataType::Int; + default: + TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type."); + return DataType::Null; + } +} + +at::ScalarType data_type_to_aten(const DataType& data_type) { + switch (data_type) { + case DataType::Bool: + return at::ScalarType::Bool; + case DataType::Float: + return at::ScalarType::Float; + case DataType::Half: + return at::ScalarType::Half; + case DataType::Int: + return at::ScalarType::Long; + default: + break; + } + TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type."); + return at::ScalarType::Undefined; } std::ostream& operator<<(std::ostream& out, const ValType vtype) { @@ -437,7 +465,14 @@ std::ostream& operator<<(std::ostream& out, const MemoryType mtype) { return out << memory_type2string(mtype); } -c10::optional inline_op_str(const UnaryOpType uotype) { +TORCH_CUDA_API std::ostream& operator<<( + std::ostream& out, + const BroadcastType bt) { + return out << broadcast_type2string(bt); +} + +TORCH_CUDA_API c10::optional inline_op_str( + const UnaryOpType uotype) { const char* str = unary_op_type_inline_op2string(uotype); return str != nullptr ? c10::optional(std::string(str)) : c10::nullopt; diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index f70ac2f8191fa4..298f9ee1ef08c1 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -32,6 +32,7 @@ enum class ExprType { TernaryOp, ReductionOp, BroadcastOp, + GridReduction, ForLoop, IfThenElse, Allocate, @@ -122,11 +123,22 @@ enum class ParallelType { enum class MemoryType { Local, Shared, Global }; +// sometimes broadcasted tensors may be inputed in the kernel with an explicit 1 +// size. If that size is there, we need to account that there's also a stride +// there, even if the stride = 0. If we don't account for that stride when +// accessing a tensor like: [b2{1}, i0, i1] we would linearize the access like: +// [i0*stride[0] + i1*stride[1]] when it should be: [i0*stride[1] + +// i1*stride[2]]. 
Broadcasts that translate to a physical memory dim we consider + // "with stride"; broadcasts realized only through our broadcast op we consider + // "without stride". +enum class BroadcastType { Null, WithStride, WithoutStride }; + ValType promote_type(const ValType& t1, const ValType& t2); DataType promote_type(const DataType& t1, const DataType& t2); bool is_logical_op(const BinaryOpType& bot); DataType aten_to_data_type(const at::ScalarType& scalar_type); +at::ScalarType data_type_to_aten(const DataType& data_type); TORCH_CUDA_API std::ostream& operator<<(std::ostream&, const ValType); TORCH_CUDA_API std::ostream& operator<<(std::ostream&, const DataType); @@ -135,6 +147,8 @@ TORCH_CUDA_API std::ostream& operator<<(std::ostream&, const UnaryOpType); TORCH_CUDA_API std::ostream& operator<<(std::ostream&, const BinaryOpType); TORCH_CUDA_API std::ostream& operator<<(std::ostream&, const TernaryOpType); TORCH_CUDA_API std::ostream& operator<<(std::ostream&, const ParallelType); +TORCH_CUDA_API std::ostream& operator<<(std::ostream&, const MemoryType); +TORCH_CUDA_API std::ostream& operator<<(std::ostream&, const BroadcastType); std::string stringifyThreadSize(const ParallelType); std::string stringifyThread(const ParallelType);
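
Note on the indexing consequence described by the BroadcastType comment in type.h: the helpers below are a hypothetical illustration (not part of the patch) of how the stride offsets shift for a tensor viewed as [b2{1}, i0, i1], depending on whether the leading broadcast dimension carries a physical sizes/strides entry.

// Hypothetical illustration only: contrasts the two linearizations the
// BroadcastType comment describes for a tensor viewed as [b2{1}, i0, i1].
#include <cstdint>

// "With stride": the explicit size-1 dim owns stride[0] (possibly 0),
// so the remaining dims must index stride[1] and stride[2].
int64_t linearIndexWithStride(int64_t i0, int64_t i1, const int64_t* stride) {
  return i0 * stride[1] + i1 * stride[2];
}

// "Without stride": the broadcast exists only in the fusion IR and the
// underlying buffer is 2-D, so the dims index stride[0] and stride[1].
int64_t linearIndexWithoutStride(int64_t i0, int64_t i1, const int64_t* stride) {
  return i0 * stride[0] + i1 * stride[1];
}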
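
For the gridReduce change in kernel_resource_strings.h, where the atomic_inc wrapper is dropped in favor of an atomicAdd on an int64 sync-flag tensor, the following is a minimal CUDA sketch of the last-block detection pattern. The function name and the sync_flag/seg_size parameters are stand-ins for the values the real kernel derives per reduction segment; this is not the kernel's actual helper.

// Minimal sketch of last-block detection: every block bumps a
// zero-initialized counter for its reduction segment, and the block that
// observes the final count performs the last reduction step.
__device__ bool markArrivalAndCheckLast(
    int64_t* sync_flag,  // one counter per reduction segment, zeroed on launch
    int64_t seg_size) {  // number of blocks participating in the segment
  __shared__ bool last_block;
  if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
    __threadfence();  // publish this block's partial result before signaling
    auto old = atomicAdd(reinterpret_cast<unsigned long long*>(sync_flag), 1ULL);
    last_block = (old + 1 == static_cast<unsigned long long>(seg_size));
  }
  __syncthreads();  // make the decision visible to every thread in the block
  return last_block;
}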
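
Finally, the runtime path in manager.cpp now goes through FusionExecutor rather than the old compileKernel/runKernel pair. Below is a condensed host-side sketch of that flow using the calls visible in the diff; the wrapper function name is made up for illustration, the return type is assumed, and caching plus error handling are omitted.

// Sketch only; assumes the torch/csrc/jit/codegen/cuda headers are included.
std::vector<at::Tensor> runGraphWithExecutor(
    const std::shared_ptr<Graph>& graph,
    const at::ArrayRef<IValue> inputs) {
  auto fusion = parseJitIR(graph);       // JIT graph -> fusion IR
  scheduleFusion(fusion.get(), inputs);  // pick a schedule for these inputs
  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(fusion.get());        // lower the IR and NVRTC-compile it
  return fe.runFusion(inputs);           // allocates outputs and launches
}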