From 86f46aad83cbb2aa06943419a7335d71a8798f2a Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Tue, 28 Jun 2022 14:14:56 +0300
Subject: [PATCH] Fix div(Val, TensorView) (#1778)

* Fix div(scalar, tensor)

* lintrunner: clang-format
---
 torch/csrc/jit/codegen/cuda/arith.cpp         |  2 +-
 torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 24 +++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp
index 211d5c4aad2342..32edaec3c1faf1 100644
--- a/torch/csrc/jit/codegen/cuda/arith.cpp
+++ b/torch/csrc/jit/codegen/cuda/arith.cpp
@@ -733,7 +733,7 @@ TensorView* binaryOp(
   } \
   TensorView* op_name(Val* v1, TensorView* v2) { \
     return binaryOp( \
-        BinaryOpType::op_type, v2, v2, TypePromotion::float_op_config); \
+        BinaryOpType::op_type, v1, v2, TypePromotion::float_op_config); \
   } \
   TensorView* op_name(TensorView* v1, TensorView* v2) { \
     return binaryOp( \
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index 72096120484673..eb2a14e7d80871 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -23354,6 +23354,30 @@ TEST_F(NVFuserTest, FusionContigPredicate_CUDA) {
   testValidate(fe.kernel(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
 }
 
+// Repro of https://github.com/csarofeen/pytorch/issues/1777
+TEST_F(NVFuserTest, FusionDivScalarLhs_CUDA) {
+  // tv1 = 2.0 / tv0
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+  TensorView* tv1 = div(IrBuilder::create<Double>(2.0), tv0);
+  fusion.addOutput(tv1);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({3, 3}, options);
+  // There's no overload div(Scalar, Tensor) in ATen
+  auto aten_output = at::div(
+      at::native::wrapped_scalar_tensor(at::Scalar(2.0), options.device()), t0);
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {t0});
+  auto cg_outputs = fe.runFusion({t0});
+
+  testValidate(&fusion, cg_outputs, {t0}, {aten_output}, __LINE__, __FILE__);
+}
+
 // Repro of an issue of the reduction scheduler with a broadcast
 // domain concretized to multiple domains that are not proven to have
 // the same extent