Fix div(Val, TensorView) (#1778)
* Fix div(scalar, tensor)

* lintrunner: clang-format
IvanYashchuk authored Jun 28, 2022
Parent: d3de227 · Commit: 86f46aa
Showing 2 changed files with 25 additions and 1 deletion.
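
Summary of the bug: in the macro that stamps out the scalar/tensor binary-op overloads, the Val* overload passed v2 for both operands, so any call of the form div(scalar, tensor) silently computed tensor / tensor. A minimal sketch of the symptom, using the same nvfuser API as the new test below (an illustration, not code from the commit):

  Fusion fusion;
  FusionGuard fg(&fusion);
  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);
  // Intended: elementwise 2.0 / tv0.
  // Before this fix the call built tv0 / tv0 instead, which evaluates to
  // 1.0 everywhere for nonzero inputs (plausibly how issue #1777 surfaced).
  TensorView* tv1 = div(IrBuilder::create<Double>(2.0), tv0);
  fusion.addOutput(tv1);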
torch/csrc/jit/codegen/cuda/arith.cpp (1 addition, 1 deletion)
@@ -733,7 +733,7 @@ TensorView* binaryOp(
   }                                                                      \
   TensorView* op_name(Val* v1, TensorView* v2) {                         \
     return binaryOp(                                                     \
-        BinaryOpType::op_type, v2, v2, TypePromotion::float_op_config);  \
+        BinaryOpType::op_type, v1, v2, TypePromotion::float_op_config);  \
   }                                                                      \
   TensorView* op_name(TensorView* v1, TensorView* v2) {                  \
     return binaryOp(                                                     \
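For readability, here is roughly what the fixed macro arm expands to when instantiated for div; this is a hand-expansion assuming op_name substitutes to div and op_type to Div, not the literal generated source:

  // Generated overload for a scalar left-hand side. After the fix, the
  // scalar v1 is forwarded as the left operand instead of v2 being passed twice.
  TensorView* div(Val* v1, TensorView* v2) {
    return binaryOp(
        BinaryOpType::Div, v1, v2, TypePromotion::float_op_config);
  }
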
torch/csrc/jit/codegen/cuda/test/test_gpu.cpp (24 additions)
@@ -23354,6 +23354,30 @@ TEST_F(NVFuserTest, FusionContigPredicate_CUDA) {
   testValidate(fe.kernel(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
 }
 
+// Repro of https://github.com/csarofeen/pytorch/issues/1777
+TEST_F(NVFuserTest, FusionDivScalarLhs_CUDA) {
+  // tv1 = 2.0 / tv0
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+  TensorView* tv1 = div(IrBuilder::create<Double>(2.0), tv0);
+  fusion.addOutput(tv1);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({3, 3}, options);
+  // There's no overload div(Scalar, Tensor) in ATen
+  auto aten_output = at::div(
+      at::native::wrapped_scalar_tensor(at::Scalar(2.0), options.device()), t0);
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {t0});
+  auto cg_outputs = fe.runFusion({t0});
+
+  testValidate(&fusion, cg_outputs, {t0}, {aten_output}, __LINE__, __FILE__);
+}
+
 // Repro of an issue of the reduction scheduler with a broadcast
 // domain concretized to multiple domains that are not proven to have
 // the same extent
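
A note on the reference computation above: ATen exposes div(Tensor, Tensor) and div(Tensor, Scalar), but no div(Scalar, Tensor), hence the 0-dim tensor wrapping in the test. An equivalent sketch using the public at::scalar_tensor helper instead of the internal at::native::wrapped_scalar_tensor (an illustration, not part of the commit):

  #include <ATen/ATen.h>

  at::Tensor reference_div_scalar_lhs(double s, const at::Tensor& t) {
    // Promote the scalar to a 0-dim tensor so the Tensor/Tensor overload
    // applies; broadcasting then yields elementwise s / t.
    auto s_t = at::scalar_tensor(s, t.options());
    return at::div(s_t, t);
  }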
