diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp
index efa4e28c105b..664181403c79 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.cpp
+++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -559,171 +559,6 @@ void Val::mutatorDispatch(T mutator, Val* val) {
   TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
 }
 
-template <typename T>
-void Expr::mutatorDispatch(T mutator, Expr* expr) {
-  if (expr->isStrictlyA<FullOp>()) {
-    ptr(mutator)->mutate(expr->as<FullOp>());
-    return;
-  }
-  if (expr->isStrictlyA<ARangeOp>()) {
-    ptr(mutator)->mutate(expr->as<ARangeOp>());
-    return;
-  }
-  if (expr->isStrictlyA<EyeOp>()) {
-    ptr(mutator)->mutate(expr->as<EyeOp>());
-    return;
-  }
-  if (expr->isStrictlyA<UnaryOp>()) {
-    ptr(mutator)->mutate(expr->as<UnaryOp>());
-    return;
-  }
-  if (expr->isStrictlyA<BinaryOp>()) {
-    ptr(mutator)->mutate(expr->as<BinaryOp>());
-    return;
-  }
-  if (expr->isStrictlyA<TernaryOp>()) {
-    ptr(mutator)->mutate(expr->as<TernaryOp>());
-    return;
-  }
-  if (expr->isStrictlyA<SelectOp>()) {
-    ptr(mutator)->mutate(expr->as<SelectOp>());
-    return;
-  }
-  if (expr->isStrictlyA<RNGOp>()) {
-    ptr(mutator)->mutate(expr->as<RNGOp>());
-    return;
-  }
-  if (expr->isStrictlyA<ReductionOp>()) {
-    ptr(mutator)->mutate(expr->as<ReductionOp>());
-    return;
-  }
-  if (expr->isStrictlyA<GroupedReductionOp>()) {
-    ptr(mutator)->mutate(expr->as<GroupedReductionOp>());
-    return;
-  }
-  if (expr->isStrictlyA<WelfordOp>()) {
-    ptr(mutator)->mutate(expr->as<WelfordOp>());
-    return;
-  }
-  if (expr->isStrictlyA<GroupedWelfordOp>()) {
-    ptr(mutator)->mutate(expr->as<GroupedWelfordOp>());
-    return;
-  }
-  if (expr->isStrictlyA<LoadStoreOp>()) {
-    ptr(mutator)->mutate(expr->as<LoadStoreOp>());
-    return;
-  }
-  if (expr->isStrictlyA<MmaOp>()) {
-    ptr(mutator)->mutate(expr->as<MmaOp>());
-    return;
-  }
-  if (expr->isStrictlyA<BroadcastOp>()) {
-    ptr(mutator)->mutate(expr->as<BroadcastOp>());
-    return;
-  }
-  if (expr->isStrictlyA<SqueezeOp>()) {
-    ptr(mutator)->mutate(expr->as<SqueezeOp>());
-    return;
-  }
-  if (expr->isStrictlyA<Split>()) {
-    ptr(mutator)->mutate(expr->as<Split>());
-    return;
-  }
-  if (expr->isStrictlyA<Merge>()) {
-    ptr(mutator)->mutate(expr->as<Merge>());
-    return;
-  }
-  if (expr->isStrictlyA<Swizzle2D>()) {
-    ptr(mutator)->mutate(expr->as<Swizzle2D>());
-    return;
-  }
-  if (expr->isStrictlyA<TransposeOp>()) {
-    ptr(mutator)->mutate(expr->as<TransposeOp>());
-    return;
-  }
-  if (expr->isStrictlyA<ExpandOp>()) {
-    ptr(mutator)->mutate(expr->as<ExpandOp>());
-    return;
-  }
-  if (expr->isStrictlyA<ShiftOp>()) {
-    ptr(mutator)->mutate(expr->as<ShiftOp>());
-    return;
-  }
-  if (expr->isStrictlyA<GatherOp>()) {
-    ptr(mutator)->mutate(expr->as<GatherOp>());
-    return;
-  }
-  if (expr->isStrictlyA<ViewAsScalar>()) {
-    ptr(mutator)->mutate(expr->as<ViewAsScalar>());
-    return;
-  }
-  if (expr->isStrictlyA<ViewOp>()) {
-    ptr(mutator)->mutate(expr->as<ViewOp>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::Allocate>()) {
-    ptr(mutator)->mutate(expr->as<kir::Allocate>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::BlockSync>()) {
-    ptr(mutator)->mutate(expr->as<kir::BlockSync>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::GridSync>()) {
-    ptr(mutator)->mutate(expr->as<kir::GridSync>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::CpAsyncWait>()) {
-    ptr(mutator)->mutate(expr->as<kir::CpAsyncWait>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::CpAsyncCommit>()) {
-    ptr(mutator)->mutate(expr->as<kir::CpAsyncCommit>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::InitMagicZero>()) {
-    ptr(mutator)->mutate(expr->as<kir::InitMagicZero>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::UpdateMagicZero>()) {
-    ptr(mutator)->mutate(expr->as<kir::UpdateMagicZero>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::ForLoop>()) {
-    ptr(mutator)->mutate(expr->as<kir::ForLoop>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::IfThenElse>()) {
-    ptr(mutator)->mutate(expr->as<kir::IfThenElse>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::GridReduction>()) {
-    ptr(mutator)->mutate(expr->as<kir::GridReduction>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::GroupedGridReduction>()) {
-    ptr(mutator)->mutate(expr->as<kir::GroupedGridReduction>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::GridBroadcast>()) {
-    ptr(mutator)->mutate(expr->as<kir::GridBroadcast>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::GridWelford>()) {
-    ptr(mutator)->mutate(expr->as<kir::GridWelford>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::GroupedGridWelford>()) {
-    ptr(mutator)->mutate(expr->as<kir::GroupedGridWelford>());
-    return;
-  }
-  if (expr->isStrictlyA<kir::AllocateFusedReduction>()) {
-    ptr(mutator)->mutate(expr->as<kir::AllocateFusedReduction>());
-    return;
-  }
-  TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
-}
-
 template <typename T>
 void Statement::mutatorDispatch(T mutator, Statement* stmt) {
   if (stmt->isVal()) {
@@ -774,8 +609,6 @@ template void Statement::mutatorDispatch(OptOutMutator&, Statement*);
 template void Statement::mutatorDispatch(OptOutMutator*, Statement*);
 template void Val::mutatorDispatch(OptOutMutator&, Val*);
 template void Val::mutatorDispatch(OptOutMutator*, Val*);
-template void Expr::mutatorDispatch(OptOutMutator&, Expr*);
-template void Expr::mutatorDispatch(OptOutMutator*, Expr*);
 
 void OptOutDispatch::handle(Statement* s) {
   Statement::dispatch(this, s);
diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h
index 5d9d363554d4..84704f23d5c2 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.h
+++ b/torch/csrc/jit/codegen/cuda/dispatch.h
@@ -326,50 +326,6 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   virtual void mutate(kir::Predicate*);
   virtual void mutate(kir::TensorIndex*);
 
-  // Exprs
-  virtual void mutate(FullOp*);
-  virtual void mutate(ARangeOp*);
-  virtual void mutate(EyeOp*);
-  virtual void mutate(UnaryOp*);
-  virtual void mutate(BinaryOp*);
-  virtual void mutate(TernaryOp*);
-  virtual void mutate(SelectOp*);
-  virtual void mutate(RNGOp*);
-  virtual void mutate(ReductionOp*);
-  virtual void mutate(GroupedReductionOp*);
-  virtual void mutate(WelfordOp*);
-  virtual void mutate(GroupedWelfordOp*);
-  virtual void mutate(LoadStoreOp*);
-  virtual void mutate(MmaOp*);
-  virtual void mutate(BroadcastOp*);
-  virtual void mutate(SqueezeOp*);
-
-  virtual void mutate(Split*);
-  virtual void mutate(Merge*);
-  virtual void mutate(Swizzle2D*);
-  virtual void mutate(TransposeOp*);
-  virtual void mutate(ExpandOp*);
-  virtual void mutate(ShiftOp*);
-  virtual void mutate(GatherOp*);
-  virtual void mutate(ViewAsScalar*);
-  virtual void mutate(ViewOp*);
-
-  virtual void mutate(kir::Allocate*);
-  virtual void mutate(kir::BlockSync*);
-  virtual void mutate(kir::GridSync*);
-  virtual void mutate(kir::CpAsyncWait*);
-  virtual void mutate(kir::CpAsyncCommit*);
-  virtual void mutate(kir::InitMagicZero*);
-  virtual void mutate(kir::UpdateMagicZero*);
-  virtual void mutate(kir::ForLoop*);
-  virtual void mutate(kir::IfThenElse*);
-  virtual void mutate(kir::GridReduction*);
-  virtual void mutate(kir::GroupedGridReduction*);
-  virtual void mutate(kir::GridBroadcast*);
-  virtual void mutate(kir::GridWelford*);
-  virtual void mutate(kir::GroupedGridWelford*);
-  virtual void mutate(kir::AllocateFusedReduction*);
-
  protected:
   void removeExpr(IrContainer*, Expr*);
 };
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
index 64abef3d44a8..623b49c48389 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -337,7 +337,8 @@ Expr::Expr(
       attributes_(std::move(attributes)) {}
 
 Expr* Expr::shallowCopy() const {
-  auto result = newObject(inputs(), outputs(), attributes());
+  auto result =
+      newObjectFunc()(ir_container_, inputs(), outputs(), attributes());
   if (container()->isA<kir::Kernel>()) {
     result->predicate_ = predicate_;
     result->write_predicate_ = write_predicate_;
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
index e6b4f35d78d1..013034e1ecff 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -429,6 +429,12 @@ class TORCH_CUDA_CU_API Attribute : public Val {
   }
 };
 
+using newObjectFuncType = Expr*(
+    IrContainer*,
+    std::vector<Val*>,
+    std::vector<Val*>,
+    std::vector<Statement*>);
+
 //! A Expr represents a "computation." These are functions that takes inputs
 //! and produce outputs, inputs and outputs all being Vals. There are
 //! specializations of BinaryOp which takes 2 inputs and produces 1 output, and
@@ -480,19 +486,14 @@ class TORCH_CUDA_CU_API Expr : public Statement {
       std::vector<Val*> outputs,
       std::vector<Statement*> attributes);
 
+  virtual newObjectFuncType* newObjectFunc() const = 0;
+
   // Creates a new instance of the expression with all its field copied.
   // Note that unlike IrCloner, this function only do a shallow copy
   Expr* shallowCopy() const;
 
   bool sameAs(const Statement* other) const override;
 
-  // Creates a new instance of the same expression type with the given inputs,
-  // outputs, and attributes.
-  virtual Expr* newObject(
-      std::vector<Val*> inputs,
-      std::vector<Val*> outputs,
-      std::vector<Statement*> attributes) const = 0;
-
   // Input/output accessors
   const auto& inputs() const {
     return inputs_;
@@ -529,9 +530,6 @@ class TORCH_CUDA_CU_API Expr : public Statement {
   template <typename T>
   static void constDispatch(T handler, const Expr* const);
 
-  template <typename T>
-  static void mutatorDispatch(T mutator, Expr*);
-
   // TODO: Protect based on being in kernel container
   kir::Predicate* predicate() const;
 
@@ -599,20 +597,26 @@ bool Val::isDefinitionType() const {
 
 #define NVFUSER_DECLARE_CLONE_AND_CREATE                           \
   virtual Statement* clone(IrCloner* ir_cloner) const override;    \
-  virtual Expr* newObject(                                         \
+  static Expr* newObject(                                          \
+      IrContainer* container,                                      \
       std::vector<Val*> inputs,                                    \
       std::vector<Val*> outputs,                                   \
-      std::vector<Statement*> attributes) const override;
-
-#define NVFUSER_DEFINE_CLONE_AND_CREATE(ClassName)                 \
-  Statement* ClassName::clone(IrCloner* ir_cloner) const {         \
-    return IrBuilder::clone(this, ir_cloner);                      \
-  }                                                                \
-  Expr* ClassName::newObject(                                      \
-      std::vector<Val*> inputs,                                    \
-      std::vector<Val*> outputs,                                   \
-      std::vector<Statement*> attributes) const {                  \
-    return IrBuilder::create<ClassName>(inputs, outputs, attributes); \
+      std::vector<Statement*> attributes);                         \
+  virtual newObjectFuncType* newObjectFunc() const override {      \
+    return newObject;                                              \
+  }
+
+#define NVFUSER_DEFINE_CLONE_AND_CREATE(ClassName)                 \
+  Statement* ClassName::clone(IrCloner* ir_cloner) const {         \
+    return IrBuilder::clone(this, ir_cloner);                      \
+  }                                                                \
+  Expr* ClassName::newObject(                                      \
+      IrContainer* container,                                      \
+      std::vector<Val*> inputs,                                    \
+      std::vector<Val*> outputs,                                   \
+      std::vector<Statement*> attributes) {                        \
+    return IrBuilder::create<ClassName>(                           \
+        container, inputs, outputs, attributes);                   \
   }
 
 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp
index c490ef12daed..c7f28912e872 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.cpp
+++ b/torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -15,10 +15,6 @@ void OptOutMutator::mutate(Statement* s) {
   Statement::mutatorDispatch(this, s);
 }
 
-void OptOutMutator::mutate(Expr* e) {
-  Expr::mutatorDispatch(this, e);
-}
-
 void OptOutMutator::mutate(Val* v) {
   Val::mutatorDispatch(this, v);
 }
@@ -127,543 +123,60 @@ void OptOutMutator::mutate(kir::TensorIndex*) {
   TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
 }
 
-void OptOutMutator::mutate(FullOp* fop) {
-  Val* out = maybeMutated(fop->output(0));
-  Val* fill_value = maybeMutated(fop->getFillValue());
-
-  if (out->sameAs(fop->output(0))) {
-    return;
-  }
-  auto container = fop->container();
-  container->removeExpr(fop);
-  IrBuilder::create<FullOp>(container, out, fill_value, fop->dtype());
-}
-
-void OptOutMutator::mutate(SelectOp* sop) {
-  Val* out = maybeMutated(sop->output(0));
-  Val* in = maybeMutated(sop->input(0));
-  Val* index = maybeMutated(sop->input(1));
-  IterDomain* select_axis =
-      maybeMutated(sop->getSelectAxis())->as<IterDomain>();
-
-  if (out->sameAs(sop->output(0)) && in->sameAs(sop->output(0)) &&
-      index->sameAs(sop->output(1)) &&
-      select_axis->sameAs(sop->getSelectAxis())) {
-    return;
-  }
-  auto container = sop->container();
-  container->removeExpr(sop);
-  IrBuilder::create<SelectOp>(container, out, in, select_axis, index);
-}
-
-void OptOutMutator::mutate(ARangeOp* aop) {
-  Val* out = maybeMutated(aop->output(0));
-
-  if (out->sameAs(aop->output(0))) {
-    return;
-  }
-  auto container = aop->container();
-  container->removeExpr(aop);
-  IrBuilder::create<ARangeOp>(
-      container,
-      out,
-      aop->start(),
-      aop->end(),
-      aop->step(),
-      aop->dtype(),
-      aop->getLinearLogicalIndex());
-}
-
-void OptOutMutator::mutate(EyeOp* eop) {
-  Val* out = maybeMutated(eop->output(0));
-
-  if (out->sameAs(eop->output(0))) {
-    return;
-  }
-  auto container = eop->container();
-  container->removeExpr(eop);
-  IrBuilder::create<EyeOp>(
-      container, out, eop->dtype(), eop->getIndex1(), eop->getIndex2());
-}
-
-void OptOutMutator::mutate(UnaryOp* uop) {
-  Val* out = maybeMutated(uop->out());
-  Val* in = maybeMutated(uop->in());
-
-  if (out->sameAs(uop->out()) && in->sameAs(uop->in())) {
-    return;
-  }
-  auto container = uop->container();
-  auto uop_type = uop->getUnaryOpType();
-  container->removeExpr(uop);
-  IrBuilder::create<UnaryOp>(container, uop_type, out, in);
-}
-
-void OptOutMutator::mutate(BinaryOp* bop) {
-  Val* out = maybeMutated(bop->out());
-  Val* lhs = maybeMutated(bop->lhs());
-  Val* rhs = maybeMutated(bop->rhs());
-
-  if (out->sameAs(bop->out()) && lhs->sameAs(bop->lhs()) &&
-      rhs->sameAs(bop->rhs())) {
-    return;
-  }
-
-  auto container = bop->container();
-  auto bop_type = bop->getBinaryOpType();
-  container->removeExpr(bop);
-  IrBuilder::create<BinaryOp>(container, bop_type, out, lhs, rhs);
-}
-
-void OptOutMutator::mutate(TernaryOp* top) {
-  Val* out = maybeMutated(top->out());
-  Val* in1 = maybeMutated(top->in1());
-  Val* in2 = maybeMutated(top->in2());
-  Val* in3 = maybeMutated(top->in3());
-
-  if (out->sameAs(top->out()) && in1->sameAs(top->in1()) &&
-      in2->sameAs(top->in2()) && in3->sameAs(top->in3())) {
-    return;
-  }
-
-  auto container = top->container();
-  auto top_type = top->getTernaryOpType();
-  container->removeExpr(top);
-  IrBuilder::create<TernaryOp>(container, top_type, out, in1, in2, in3);
-}
-
-void OptOutMutator::mutate(RNGOp* rop) {
-  Val* out = maybeMutated(rop->output(0));
-  Val* philox_idx = maybeMutated(rop->getPhiloxIndex());
-
-  auto parameters = rop->getParameters();
-  std::vector<Val*> mutated_parameters;
-  bool all_mutated_same = true;
-  for (auto v : parameters) {
-    mutated_parameters.emplace_back(maybeMutated(v));
-    all_mutated_same = all_mutated_same && mutated_parameters.back()->sameAs(v);
-  }
-
-  if (out->sameAs(rop->output(0)) &&
-      ((philox_idx == nullptr && rop->getPhiloxIndex() == nullptr) ||
-       philox_idx->sameAs(rop->getPhiloxIndex())) &&
-      all_mutated_same) {
-    return;
-  }
-
-  auto container = rop->container();
-  auto rop_type = rop->getRNGOpType();
-  container->removeExpr(rop);
-  IrBuilder::create<RNGOp>(
-      container,
-      rop_type,
-      out,
-      rop->dtype(),
-      mutated_parameters,
-      rop->getRNGOffset(),
-      philox_idx);
-}
-
-void OptOutMutator::mutate(ReductionOp* rop) {
-  Val* out = maybeMutated(rop->out());
-  Val* in = maybeMutated(rop->in());
-  Val* init = rop->init();
-  if (out->sameAs(rop->out()) && in->sameAs(rop->in()) &&
-      init->sameAs(rop->init())) {
-    return;
-  }
-
-  auto container = rop->container();
-  auto rop_type = rop->getReductionOpType();
-  container->removeExpr(rop);
-  IrBuilder::create<ReductionOp>(
-      container, rop_type, init, out, in, rop->isAllreduce());
-}
-
-void OptOutMutator::mutate(GroupedReductionOp* rop) {
-  bool is_same = true;
-
-  std::vector<Val*> outputs;
-  for (auto out : rop->outputs()) {
-    auto maybe_mutated = maybeMutated(out);
-    is_same = is_same && maybe_mutated->sameAs(out);
-    outputs.push_back(maybe_mutated);
-  }
-
-  std::vector<Val*> inputs;
-  for (auto in : rop->inputs()) {
-    auto maybe_mutated = maybeMutated(in);
-    is_same = is_same && maybe_mutated->sameAs(in);
-    inputs.push_back(maybe_mutated);
-  }
-
-  std::vector<Val*> init_vals;
-  for (auto init : rop->initVals()) {
-    auto maybe_mutated = maybeMutated(init);
-    is_same = is_same && maybe_mutated->sameAs(init);
-    init_vals.push_back(maybe_mutated);
-  }
-
-  if (is_same) {
-    return;
-  }
-
-  auto container = rop->container();
-  const auto& rop_types = rop->getReductionOpTypes();
-  container->removeExpr(rop);
-  IrBuilder::create<GroupedReductionOp>(
-      container, rop_types, init_vals, outputs, inputs, rop->isAllreduce());
-}
-
-namespace {
-inline bool compareOptional(Val* a, Val* b) {
-  if (!a || !b) {
-    return (!a && !b);
-  }
-  return a->sameAs(b);
-}
-
-} // namespace
-
-void OptOutMutator::mutate(WelfordOp* wop) {
-  Val* out_avg = maybeMutated(wop->outAvg());
-  Val* out_var = maybeMutated(wop->outVar());
-  Val* out_N = maybeMutated(wop->outN());
-
-  Val* in_avg = maybeMutated(wop->inAvg());
-  Val* in_var = wop->inVar() ? maybeMutated(wop->inVar()) : nullptr;
-  Val* in_N = maybeMutated(wop->inN());
-
-  Val* init_avg = wop->initAvg() ? maybeMutated(wop->initAvg()) : nullptr;
-  Val* init_var = wop->initVar() ? maybeMutated(wop->initVar()) : nullptr;
-  Val* init_N = maybeMutated(wop->initN());
-
-  const bool out_compare = out_avg->sameAs(wop->outAvg()) &&
-      out_var->sameAs(wop->outVar()) && out_N->sameAs(wop->outN());
-  const bool in_compare = in_avg->sameAs(wop->inAvg()) &&
-      compareOptional(in_var, wop->inVar()) && in_N->sameAs(wop->inN());
-  const bool init_compare = compareOptional(init_avg, wop->initAvg()) &&
-      compareOptional(init_var, wop->initVar()) && init_N->sameAs(wop->initN());
-
-  if (out_compare && init_compare && in_compare) {
-    return;
-  }
-
-  auto container = wop->container();
-  container->removeExpr(wop);
-  IrBuilder::create<WelfordOp>(
-      container,
-      out_avg,
-      out_var,
-      out_N,
-      in_avg,
-      in_var,
-      in_N,
-      init_avg,
-      init_var,
-      init_N,
-      wop->isAllreduce());
-}
-
-void OptOutMutator::mutate(GroupedWelfordOp* wop) {
-  bool is_same = true;
-
-  std::vector<WelfordTriplet> output_vals;
-  for (const auto& out : wop->outputVals()) {
-    auto maybe_mutated =
-        out.transform([&](Val* val) { return maybeMutated(val); });
-    is_same = is_same && maybe_mutated.sameAs(out);
-    output_vals.push_back(maybe_mutated);
-  }
-
-  std::vector<WelfordTriplet> input_vals;
-  for (const auto& inp : wop->inputVals()) {
-    auto maybe_mutated =
-        inp.transform([&](Val* val) { return maybeMutated(val); });
-    is_same = is_same && maybe_mutated.sameAs(inp);
-    input_vals.push_back(maybe_mutated);
-  }
-
-  std::vector<WelfordTriplet> init_vals;
-  for (const auto& init : wop->initVals()) {
-    auto maybe_mutated =
-        init.transform([&](Val* val) { return maybeMutated(val); });
-    is_same = is_same && maybe_mutated.sameAs(init);
-    init_vals.push_back(maybe_mutated);
-  }
-
-  if (is_same) {
-    return;
-  }
-
-  auto container = wop->container();
-  container->removeExpr(wop);
-  IrBuilder::create<GroupedWelfordOp>(
-      container, output_vals, input_vals, init_vals, wop->isAllreduce());
-}
-
-void OptOutMutator::mutate(MmaOp* mma) {
-  Val* out = maybeMutated(mma->out());
-  Val* in_a = maybeMutated(mma->inA());
-  Val* in_b = maybeMutated(mma->inB());
-  Val* init = mma->init();
-
-  if (out->sameAs(mma->out()) && in_a->sameAs(mma->inA()) &&
-      in_b->sameAs(mma->inB())) {
-    return;
-  }
-
-  auto container = mma->container();
-  auto options = mma->options();
-  container->removeExpr(mma);
-  C10_UNUSED auto new_mma =
-      IrBuilder::create<MmaOp>(container, out, in_a, in_b, init, options);
-}
-
-void OptOutMutator::mutate(LoadStoreOp* ldst) {
-  Val* out = maybeMutated(ldst->out());
-  Val* in = maybeMutated(ldst->in());
-  auto op_type = ldst->opType();
-
-  if (out->sameAs(ldst->out()) && in->sameAs(ldst->in())) {
-    return;
-  }
-
-  auto container = ldst->container();
-  container->removeExpr(ldst);
-  IrBuilder::create<LoadStoreOp>(container, op_type, out, in);
-}
-
-void OptOutMutator::mutate(BroadcastOp* bop) {
-  Val* out = maybeMutated(bop->out());
-  Val* in = maybeMutated(bop->in());
-
-  if (out->sameAs(bop->out()) && in->sameAs(bop->in())) {
-    return;
+void OptOutMutator::mutate(Expr* op) {
+  std::vector<Val*> mutated_inputs;
+  mutated_inputs.reserve(op->inputs().size());
+  for (auto input : op->inputs()) {
+    mutated_inputs.emplace_back(maybeMutated(input));
   }
-  auto container = bop->container();
-  auto flags = bop->getBroadcastDimFlags();
-  container->removeExpr(bop);
-  IrBuilder::create<BroadcastOp>(container, out, in, flags);
-}
-
-void OptOutMutator::mutate(SqueezeOp* sop) {
-  Val* out = maybeMutated(sop->out());
-  Val* in = maybeMutated(sop->in());
-
-  if (out->sameAs(sop->out()) && in->sameAs(sop->in())) {
-    return;
+  std::vector<Val*> mutated_outputs;
+  mutated_outputs.reserve(op->outputs().size());
+  for (auto output : op->outputs()) {
+    mutated_outputs.emplace_back(maybeMutated(output));
   }
-  auto container = sop->container();
-  auto flags = sop->getSqueezeDimFlags();
-  container->removeExpr(sop);
-  IrBuilder::create<SqueezeOp>(container, out, in, flags);
-}
-
-void OptOutMutator::mutate(TransposeOp* top) {
-  TensorView* out = maybeMutated(top->out())->as<TensorView>();
-  TensorView* in = maybeMutated(top->in())->as<TensorView>();
-
-  if (out->sameAs(top->out()) && in->sameAs(top->in())) {
-    return;
+  std::vector<Statement*> mutated_attrs;
+  mutated_attrs.reserve(op->attributes().size());
+  for (auto attr : op->attributes()) {
+    if (auto attr_val = dynamic_cast<Val*>(attr)) {
+      mutated_attrs.emplace_back(maybeMutated(attr_val));
+    } else {
+      mutated_attrs.emplace_back(attr);
+    }
   }
-  auto container = top->container();
-  auto new2old = top->new2old();
-  container->removeExpr(top);
-  IrBuilder::create<TransposeOp>(container, out, in, new2old);
-}
-
-void OptOutMutator::mutate(ExpandOp* eop) {
-  bool is_same = true;
-
-  TensorView* out = maybeMutated(eop->out())->as<TensorView>();
-  is_same = is_same && out->sameAs(eop->out());
-  TensorView* in = maybeMutated(eop->in())->as<TensorView>();
-  is_same = is_same && in->sameAs(eop->in());
-
-  std::vector<Val*> expanded_extents;
-  expanded_extents.reserve(eop->expanded_extents().size());
-  for (auto expanded_extent : eop->expanded_extents()) {
-    expanded_extents.push_back(maybeMutated(expanded_extent));
-    if (!expanded_extents.back()->sameAs(expanded_extent)) {
-      is_same = false;
+  bool all_same = true;
+  for (auto i : c10::irange(op->outputs().size())) {
+    if (!all_same) {
+      break;
     }
+    all_same = all_same && mutated_outputs[i]->sameAs(op->output(i));
   }
-
-  if (is_same) {
-    return;
+  for (auto i : c10::irange(op->inputs().size())) {
+    if (!all_same) {
+      break;
+    }
+    all_same = all_same && mutated_inputs[i]->sameAs(op->input(i));
   }
-
-  auto container = eop->container();
-  container->removeExpr(eop);
-  IrBuilder::create<ExpandOp>(container, out, in, expanded_extents);
-}
-
-void OptOutMutator::mutate(ShiftOp* sop) {
-  Val* out = maybeMutated(sop->out())->asVal();
-  Val* in = maybeMutated(sop->in())->asVal();
-
-  if (out->sameAs(sop->out()) && in->sameAs(sop->in())) {
-    return;
+  for (auto i : c10::irange(op->attributes().size())) {
+    if (!all_same) {
+      break;
+    }
+    bool same =
+        ((mutated_attrs[i] == nullptr) && (op->attribute(i) == nullptr)) ||
+        mutated_attrs[i]->sameAs(op->attribute(i));
+    all_same = all_same && same;
   }
 
-  auto offsets = sop->offsets();
-  auto pad_width = sop->padWidth();
-  auto container = sop->container();
-  container->removeExpr(sop);
-  IrBuilder::create<ShiftOp>(container, out, in, offsets, pad_width);
-}
-
-void OptOutMutator::mutate(GatherOp* op) {
-  Val* out = maybeMutated(op->out())->asVal();
-  Val* in = maybeMutated(op->in())->asVal();
-
-  if (out->sameAs(op->out()) && in->sameAs(op->in())) {
+  if (all_same) {
     return;
   }
 
-  auto window_shape = op->windowShape();
-  auto pad_width = op->padWidth();
   auto container = op->container();
+  auto newObjectFunc = op->newObjectFunc();
   container->removeExpr(op);
-  IrBuilder::create<GatherOp>(container, out, in, window_shape, pad_width);
-}
-
-void OptOutMutator::mutate(ViewAsScalar* vop) {
-  TensorView* out = maybeMutated(vop->out())->as<TensorView>();
-  TensorView* in = maybeMutated(vop->in())->as<TensorView>();
-  IterDomain* vid = maybeMutated(vop->vector_id())->as<IterDomain>();
-  Val* idx = maybeMutated(vop->index());
-
-  if (out->sameAs(vop->out()) && in->sameAs(vop->in()) &&
-      vid->sameAs(vop->vector_id()) &&
-      ((idx == nullptr && vop->index() == nullptr) ||
-       idx->sameAs(vop->index()))) {
-    return;
-  }
-
-  auto container = vop->container();
-  container->removeExpr(vop);
-  IrBuilder::create<ViewAsScalar>(container, out, in, vid, idx);
-}
-
-void OptOutMutator::mutate(ViewOp* vop) {
-  TensorView* out = maybeMutated(vop->out())->as<TensorView>();
-  TensorView* in = maybeMutated(vop->in())->as<TensorView>();
-
-  if (out->sameAs(vop->out()) && in->sameAs(vop->in())) {
-    return;
-  }
-
-  auto container = vop->container();
-  container->removeExpr(vop);
-  IrBuilder::create<ViewOp>(container, out, in);
-}
-
-void OptOutMutator::mutate(Split* s) {
-  IterDomain* ot = maybeMutated(s->outer())->as<IterDomain>();
-  IterDomain* inr = maybeMutated(s->inner())->as<IterDomain>();
-  IterDomain* in = maybeMutated(s->in())->as<IterDomain>();
-  Val* fact = maybeMutated(s->factor())->as<Int>();
-  Val* start_offset = maybeMutated(s->startOffset());
-  Val* stop_offset = maybeMutated(s->stopOffset());
-
-  if (ot->sameAs(s->outer()) && inr->sameAs(s->inner()) &&
-      in->sameAs(s->in()) && areEqualScalars(fact, s->factor()) &&
-      start_offset->sameAs(s->startOffset()) &&
-      stop_offset->sameAs(s->stopOffset())) {
-    return;
-  }
-
-  auto container = s->container();
-  auto inner_split = s->innerSplit();
-  container->removeExpr(s);
-  C10_UNUSED auto new_node = IrBuilder::create<Split>(
-      container, ot, inr, in, fact, inner_split, start_offset, stop_offset);
-}
-
-void OptOutMutator::mutate(Merge* m) {
-  IterDomain* ot = maybeMutated(m->out())->as<IterDomain>();
-  IterDomain* otr = maybeMutated(m->outer())->as<IterDomain>();
-  IterDomain* in = maybeMutated(m->inner())->as<IterDomain>();
-
-  if (ot->sameAs(m->out()) && otr->sameAs(m->outer()) &&
-      in->sameAs(m->inner())) {
-    return;
-  }
-
-  auto container = m->container();
-  container->removeExpr(m);
-  C10_UNUSED auto new_node = IrBuilder::create<Merge>(container, ot, otr, in);
-}
-
-void OptOutMutator::mutate(Swizzle2D* m) {
-  IterDomain* outx = maybeMutated(m->outX())->as<IterDomain>();
-  IterDomain* outy = maybeMutated(m->outY())->as<IterDomain>();
-
-  IterDomain* inx = maybeMutated(m->inX())->as<IterDomain>();
-  IterDomain* iny = maybeMutated(m->inY())->as<IterDomain>();
-  auto swizzle_type = m->swizzleType();
-
-  if (outx->sameAs(m->outX()) && outy->sameAs(m->outY()) &&
-      inx->sameAs(m->inX()) && iny->sameAs(m->inY())) {
-    return;
-  }
-  auto container = m->container();
-  container->removeExpr(m);
-  FusionGuard::getCurFusion()->removeExpr(m);
-  C10_UNUSED auto new_node = IrBuilder::create<Swizzle2D>(
-      container, outx, outy, inx, iny, swizzle_type);
-}
-
-void OptOutMutator::mutate(kir::Allocate*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::BlockSync*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::GridSync*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::CpAsyncWait*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::CpAsyncCommit*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::InitMagicZero*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::UpdateMagicZero*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::ForLoop*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::IfThenElse*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::GridReduction*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::GroupedGridReduction*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::GridBroadcast*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::GridWelford*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::GroupedGridWelford*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::AllocateFusedReduction*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
+  newObjectFunc(container, mutated_inputs, mutated_outputs, mutated_attrs);
 }
 
 void OptOutMutator::removeExpr(IrContainer* container, Expr* expr) {
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
index 42e4925d1a03..413cdb902da9 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
@@ -2500,8 +2500,6 @@ TEST_F(NVFuserTest, FusionGeluBwdReduction_CUDA) {
   fusion.addOutput(t26);
   fusion.addOutput(t27);
 
-  fusion.printMath();
-
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
   at::manual_seed(1);
 
@@ -2524,7 +2522,7 @@ TEST_F(NVFuserTest, FusionGeluBwdReduction_CUDA) {
   auto reduction_params = getReductionHeuristics(&fusion, {at_grad, at_xvar});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
   scheduleReduction(&fusion, *reduction_params);
-  fusion.printKernel();
+
   FusionExecutor fe;
   fe.compileFusion(&fusion, {at_grad, at_xvar}, reduction_params->lparams);
   auto cg_outputs = fe.runFusion({at_grad, at_xvar}, reduction_params->lparams);
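
Below the patch, a standalone sketch (not part of the diff) of the dispatch pattern the change moves to: each concrete expression type registers a static factory through a virtual newObjectFunc() accessor, so one generic mutator can rebuild any expression from remapped operands instead of needing a mutate overload per type. Names such as Expr, Val, AddOp, NewObjectFn, and replaceVal are simplified stand-ins, not the actual nvfuser API.

// Standalone illustrative sketch; simplified stand-ins for the nvfuser types
// touched by the patch above, not the real API.
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

struct Val {
  explicit Val(int v) : value(v) {}
  int value;
};

struct Expr;
// Mirrors newObjectFuncType: a plain function type that rebuilds an Expr of
// the concrete derived type from new inputs/outputs.
using NewObjectFn = std::unique_ptr<Expr>(std::vector<Val*>, std::vector<Val*>);

struct Expr {
  Expr(std::vector<Val*> in, std::vector<Val*> out)
      : inputs(std::move(in)), outputs(std::move(out)) {}
  virtual ~Expr() = default;
  // Each derived class reports its static factory, so generic code can
  // re-create "the same kind of expression" without a per-type overload.
  virtual NewObjectFn* newObjectFunc() const = 0;
  std::vector<Val*> inputs;
  std::vector<Val*> outputs;
};

struct AddOp : Expr {
  using Expr::Expr;
  static std::unique_ptr<Expr> newObject(
      std::vector<Val*> in, std::vector<Val*> out) {
    return std::make_unique<AddOp>(std::move(in), std::move(out));
  }
  NewObjectFn* newObjectFunc() const override { return newObject; }
};

// Generic mutation in the spirit of the new OptOutMutator::mutate(Expr*):
// map every operand through a substitution and rebuild the expression
// through the registered factory.
std::unique_ptr<Expr> replaceVal(const Expr& e, Val* from, Val* to) {
  auto remap = [&](const std::vector<Val*>& vals) {
    std::vector<Val*> result;
    for (Val* v : vals) {
      result.push_back(v == from ? to : v);
    }
    return result;
  };
  return e.newObjectFunc()(remap(e.inputs), remap(e.outputs));
}

int main() {
  Val a{1}, b{2}, c{0}, b2{20};
  AddOp add({&a, &b}, {&c});
  auto mutated = replaceVal(add, &b, &b2);
  std::cout << "second input is now " << mutated->inputs[1]->value << "\n";
  return 0;
}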