From cdb56f41c9248e1284a486ae0941b7a2f9f172dd Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Wed, 15 Jun 2022 15:07:07 -0700
Subject: [PATCH] Proper propagation of IterType

---
 torch/csrc/jit/codegen/cuda/arith.cpp         | 51 ++++++++++++++++---
 .../jit/codegen/cuda/test/test_gpu_shift.cpp  | 40 +++++++++++++++
 2 files changed, 85 insertions(+), 6 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp
index 8894a32ddc106..086c8e73c20ce 100644
--- a/torch/csrc/jit/codegen/cuda/arith.cpp
+++ b/torch/csrc/jit/codegen/cuda/arith.cpp
@@ -110,6 +110,49 @@ Val* newScalar(ValType vtype, DataType dtype) {
       " in newScalar.");
 }
 
+IterType promoteIterType(IterType type1, IterType type2) {
+  // Iteration: Default
+  // Reduction: Should not appear here
+  // Broadcast: Propagated only if type1 and type2 are Broadcast
+  // Gather: Converted to Iteration
+  // Stride: Should not appear here
+  // VectorComponent: Converted to Iteration
+
+  TORCH_INTERNAL_ASSERT(
+      type1 != IterType::Reduction && type1 != IterType::Stride,
+      "Invalid IterType: ",
+      type1);
+  TORCH_INTERNAL_ASSERT(
+      type2 != IterType::Reduction && type2 != IterType::Stride,
+      "Invalid IterType: ",
+      type2);
+
+  // Do not propagate Gather and VectorComponent
+  if (type1 == IterType::Gather || type1 == IterType::VectorComponent) {
+    type1 = IterType::Iteration;
+  }
+  if (type2 == IterType::Gather || type2 == IterType::VectorComponent) {
+    type2 = IterType::Iteration;
+  }
+
+  // At this point, type1 and type2 must be either Iteration or
+  // Broadcast
+  TORCH_INTERNAL_ASSERT(
+      type1 == IterType::Iteration || type1 == IterType::Broadcast,
+      "Unexpected IterType: ",
+      type1);
+  TORCH_INTERNAL_ASSERT(
+      type2 == IterType::Iteration || type2 == IterType::Broadcast,
+      "Unexpected IterType: ",
+      type2);
+
+  if (type1 == IterType::Broadcast) {
+    return type2;
+  } else {
+    return type1;
+  }
+}
+
 TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
   std::vector<TensorView*> tvs;
   for (auto val : vals) {
@@ -155,12 +198,8 @@ TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
     }
     extent_vals[i] = promoteSize(extent_vals[i], dom[i]->extent());
     if (iter_types[i].has_value()) {
-      // TODO: Enable, see conv tests and gather promotion/gather broadcast
-      // behavior.
-      //
-      // TORCH_INTERNAL_ASSERT(
-      //     iter_types[i].value() == dom[i]->getIterType(),
-      //     "Invalid iter type promotion in newOutputTv for expression.");
+      iter_types[i] =
+          promoteIterType(iter_types[i].value(), dom[i]->getIterType());
     } else {
       iter_types[i] = dom[i]->getIterType();
     }
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp
index d4ebe51c893e1..8c45bb37bbeb8 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp
@@ -5384,6 +5384,46 @@ TEST_F(NVFuserTest, FusionValidateParallelizeShift_CUDA) {
   testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
 }
 
+// Test IterType promotion with gather
+TEST_F(NVFuserTest, FusionGatherIterTypePromotion_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  const int s1 = 11;
+  const int s2 = 3;
+
+  auto tv0 = makeConcreteTensor({s1});
+  fusion.addInput(tv0);
+  auto tv1 = makeConcreteTensor({s1, s2});
+  fusion.addInput(tv1);
+
+  const std::vector<int> window_shape = {3};
+  const std::vector<std::vector<int>> padding_width = {{1, 1}};
+
+  auto tv2 = gather(tv0, window_shape, padding_width);
+  auto tv3 = add(tv2, tv1);
+
+  fusion.addOutput(tv3);
+
+  TORCH_CHECK(
+      tv3->axis(1)->getIterType() == IterType::Iteration,
+      "Invalid IterType promotion: ",
+      tv3->axis(1)->toString());
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({s1}, options);
+  at::Tensor t1 = at::randn({s1, s2}, options);
+  std::vector<IValue> inputs = {t0, t1};
+
+  auto ref = gather(t0, window_shape, padding_width) + t1;
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, inputs);
+  auto outputs = fe.runFusion(inputs);
+
+  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
+}
+
 } // namespace jit
 } // namespace torch
 
 #endif // #if defined(USE_CUDA)