Skip to content

Commit

Permalink
Make sure ops are looking at root domains, not at current transformat…
Browse files Browse the repository at this point in the history
…ions.
  • Loading branch information
csarofeen authored and rdspring1 committed Jul 8, 2020
1 parent 37dc0ca commit ce176bd
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,15 @@ BroadcastOp::BroadcastOp(Val* _out, Val* _in)
"Cannot braodcast a non-tensor object.");

int ndims = 0;
for (auto dom : static_cast<TensorView*>(out())->domain()->domain())
for (auto dom : out()->as<TensorView>()->getRootDomain())
if (!dom->isBroadcast())
ndims++;

TORCH_INTERNAL_ASSERT(
ndims == (int)in_->as<TensorView>()->domain()->noReductions().size(),
ndims ==
(int)TensorDomain::noReductions(
in_->as<TensorView>()->getRootDomain())
.size(),
"Invalid broadcast op. Non-broadcasted dims don't match from input to output.");
} else {
TORCH_INTERNAL_ASSERT(
Expand Down Expand Up @@ -234,9 +237,30 @@ ReductionOp::ReductionOp(
init_(_init),
out_(_out),
in_(_in) {
if (_out->getValType().value() == ValType::TensorView) {
TORCH_INTERNAL_ASSERT(
_in->getValType() == ValType::TensorView &&
_out->getValType() == ValType::TensorView,
"Reduction operation was created that does not have tensor inputs and outputs.");

TORCH_INTERNAL_ASSERT(
TensorDomain::noReductions(_in->as<TensorView>()->getRootDomain())
.size() == _out->as<TensorView>()->getRootDomain().size() ||
TensorDomain::noReductions(
_in->as<TensorView>()->domain()->rfactorDomain())
.size() == _out->as<TensorView>()->getRootDomain().size(),
"Reduction operation created with mismatched domains.");

} else {
TORCH_INTERNAL_ASSERT(
_in->getValType() == ValType::TensorIndex &&
_out->getValType() == ValType::TensorIndex,
"Reduction operation was created that does not have tensor inputs and outputs.");
}
TORCH_INTERNAL_ASSERT(
_init->isConstScalar(),
"Tried to create a reduction operation whith an initial value that isn't a constant.");

addOutput(_out);
addInput(_in);
this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
Expand Down

0 comments on commit ce176bd

Please sign in to comment.