From ee2f23430791a277370ba886c99d375b255ac211 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 17 Nov 2022 12:49:56 -0800 Subject: [PATCH 01/10] structured expr squash merge --- torch/csrc/jit/codegen/cuda/codegen.cpp | 10 +- torch/csrc/jit/codegen/cuda/dispatch.cpp | 4 + torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 41 +- torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 91 +- torch/csrc/jit/codegen/cuda/ir_builder.cpp | 68 - torch/csrc/jit/codegen/cuda/ir_builder.h | 15 +- torch/csrc/jit/codegen/cuda/ir_builder_key.h | 29 + torch/csrc/jit/codegen/cuda/ir_cloner.cpp | 160 +-- torch/csrc/jit/codegen/cuda/ir_cloner.h | 85 +- .../jit/codegen/cuda/ir_interface_nodes.h | 10 + .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 809 +++++------- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 1130 ++++------------- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 569 ++++----- torch/csrc/jit/codegen/cuda/kernel_ir.h | 447 +++---- torch/csrc/jit/codegen/cuda/lower_index.cpp | 7 +- torch/csrc/jit/codegen/cuda/mutator.cpp | 2 +- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 2 + .../csrc/jit/codegen/cuda/test/test_gpu1.cpp | 10 +- torch/csrc/jit/codegen/cuda/type.h | 3 +- 19 files changed, 1279 insertions(+), 2213 deletions(-) create mode 100644 torch/csrc/jit/codegen/cuda/ir_builder_key.h diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 188167c80b13..b71a0159b457 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -1954,10 +1954,14 @@ class CudaKernelGenerator : private OptOutConstDispatch { ArgumentBuilder read_preds; ArgumentBuilder write_preds; + auto output_vals = grouped_gwop->outputVals(); + auto input_vals = grouped_gwop->inputVals(); + auto init_vals = grouped_gwop->initVals(); + for (const auto expr_index : c10::irange(grouped_gwop->numExprs())) { - const auto& output = grouped_gwop->outputVals().at(expr_index); - const auto& input = grouped_gwop->inputVals().at(expr_index); - const auto& init = grouped_gwop->initVals().at(expr_index); + const auto& output = output_vals.at(expr_index); + const auto& input = input_vals.at(expr_index); + const auto& init = init_vals.at(expr_index); for (const auto& group_index : c10::irange(index_replacement_maps.size())) { diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 8c2172f3f383..40fdf11a48cf 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -307,6 +307,10 @@ void Val::constDispatch(T handler, const Val* val) { case ValType::TensorIndex: ptr(handler)->handle(val->as()); return; + case ValType::Plain: + // Plain Val is just a wrapper for non-IR data, so there is nothing to + // handle + return; default: break; } diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index 48c6e0959b2f..d78fa5b9623c 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -34,6 +34,8 @@ Statement::Statement(const Statement* src, IrCloner* ir_cloner) { ir_container_ = ir_cloner->container(); } +DEFINE_CLONE(Statement) + void Statement::setName(IrContainerPasskey, StmtNameType name) { name_ = name; } @@ -98,6 +100,8 @@ Val::Val(IrBuilderPasskey passkey, ValType _vtype, DataType _dtype) Val::Val(const Val* src, IrCloner* ir_cloner) : Statement(src, ir_cloner), vtype_(src->vtype_), dtype_(src->dtype_) {} +DEFINE_CLONE(Val) + const std::vector& Val::uses() const { if (vtype_ 
== ValType::TensorView) { if (!fusion()->isTVUseInfoValid() && !fusion()->isUpdatingTVUseInfo()) { @@ -315,7 +319,27 @@ Expr::Expr(IrBuilderPasskey passkey) : Statement(passkey) {} Expr::Expr(const Expr* src, IrCloner* ir_cloner) : Statement(src, ir_cloner), inputs_(ir_cloner->clone(src->inputs_)), - outputs_(ir_cloner->clone(src->outputs_)) {} + outputs_(ir_cloner->clone(src->outputs_)), + attributes_(ir_cloner->clone(src->attributes_)) {} + +Expr::Expr( + IrBuilderPasskey passkey, + std::vector inputs, + std::vector outputs, + std::vector attributes) + : Statement(passkey), + inputs_(std::move(inputs)), + outputs_(std::move(outputs)), + attributes_(std::move(attributes)) {} + +Expr* Expr::shallowCopy() const { + auto result = newObject(inputs(), outputs(), attributes()); + if (container()->isA()) { + result->predicate_ = predicate_; + result->write_predicate_ = write_predicate_; + } + return result; +} bool Expr::sameAs(const Statement* other) const { if (this == other) { @@ -329,7 +353,8 @@ bool Expr::sameAs(const Statement* other) const { return false; } if (inputs().size() != other_expr->inputs().size() || - outputs().size() != other_expr->outputs().size()) { + outputs().size() != other_expr->outputs().size() || + attributes().size() != other_expr->attributes().size()) { return false; } for (const auto i : c10::irange(inputs().size())) { @@ -337,6 +362,11 @@ bool Expr::sameAs(const Statement* other) const { return false; } } + for (const auto i : c10::irange(attributes().size())) { + if (!attribute(i)->sameAs(other_expr->attribute(i))) { + return false; + } + } return true; } @@ -376,13 +406,6 @@ Expr* Expr::withWritePredicate(kir::Predicate* predicate) { return result; } -void Expr::copyPredicatesFrom(const Expr* expr) { - if (container()->isA()) { - predicate_ = expr->predicate_; - write_predicate_ = expr->write_predicate_; - } -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index f612a72352f3..3f39e2204703 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -70,6 +71,14 @@ class ExprPasskey { TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept; +#define DECLARE_CLONE \ + virtual Statement* clone(IrCloner* ir_cloner) const override; + +#define DEFINE_CLONE(ClassName) \ + Statement* ClassName::clone(IrCloner* ir_cloner) const { \ + return IrBuilder::clone(this, ir_cloner); \ + } + //! Statement is the highest level node representation. Everything that is //! considered "IR" will be derived from this class at some point. Both Values //! and Expr's are a Statement. If there will ever be any more fundamental @@ -159,6 +168,8 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { std::string toString() const; std::string toInlineString() const; + virtual Statement* clone(IrCloner* ir_cloner) const; + protected: Statement(IrBuilderPasskey); @@ -353,6 +364,8 @@ class TORCH_CUDA_CU_API Val : public Statement { void resolveIndexDtype(); + DECLARE_CLONE + protected: friend Fusion; @@ -391,6 +404,31 @@ class TORCH_CUDA_CU_API Val : public Statement { int evaluator_index_ = -1; }; +//! A Val object that stores a plain data. Note that this class is only intended +//! to hold non-IR data, such as DataType, std::vector, etc. Please don't +//! use this class to hold IR nodes or their pointers. 
+template +class TORCH_CUDA_CU_API PlainVal : public Val { + public: + T value; + PlainVal(IrBuilderPasskey passkey, const T& value) + : Val(passkey, ValType::Plain), value(value) {} + PlainVal(const PlainVal* src, IrCloner* ir_cloner) + : Val(src, ir_cloner), value(src->value) {} + template + PlainVal(IrBuilderPasskey passkey, Args... args) + : Val(passkey, ValType::Plain), value(std::forward(args)...) {} + + DECLARE_CLONE + + bool sameAs(const Statement* other) const override { + if (auto pv = dynamic_cast(other)) { + return pv->value == value; + } + return false; + } +}; + //! A Expr represents a "computation." These are functions that takes inputs //! and produce outputs, inputs and outputs all being Vals. There are //! specializations of BinaryOp which takes 2 inputs and produces 1 output, and @@ -436,12 +474,25 @@ class TORCH_CUDA_CU_API Expr : public Statement { Expr(const Expr* src, IrCloner* ir_cloner); + Expr( + IrBuilderPasskey, + std::vector inputs, + std::vector outputs, + std::vector attributes); + // Creates a new instance of the expression with all its field copied. // Note that unlike IrCloner, this function only do a shallow copy - virtual Expr* shallowCopy() const = 0; + Expr* shallowCopy() const; bool sameAs(const Statement* other) const override; + // Creates a new instance of the same expression type with the given inputs, + // outputs, and attributes. + virtual Expr* newObject( + std::vector inputs, + std::vector outputs, + std::vector attributes) const = 0; + // Input/output accessors const auto& inputs() const { return inputs_; @@ -451,12 +502,20 @@ class TORCH_CUDA_CU_API Expr : public Statement { return outputs_; } + const auto& attributes() const { + return attributes_; + } + auto input(size_t index) const { - return inputs_[index]; + return inputs_.at(index); } auto output(size_t index) const { - return outputs_[index]; + return outputs_.at(index); + } + + auto attribute(size_t index) const { + return attributes_.at(index); } // Dispatch functions, definitions in dispatch.cpp @@ -494,8 +553,6 @@ class TORCH_CUDA_CU_API Expr : public Statement { // TODO: Protect based on being in kernel container void setWritePredicate(kir::Predicate* write_predicate); - void copyPredicatesFrom(const Expr* expr); - // TODO: Add Fusion passkey void addInput(Val* input) { TORCH_INTERNAL_ASSERT(input != nullptr); @@ -508,6 +565,11 @@ class TORCH_CUDA_CU_API Expr : public Statement { outputs_.push_back(output); } + // TODO: Add Fusion passkey + void addAttribute(Val* attr) { + attributes_.push_back(attr); + } + ExprPasskey exprPasskey() { return ExprPasskey(); } @@ -515,6 +577,7 @@ class TORCH_CUDA_CU_API Expr : public Statement { private: std::vector inputs_; std::vector outputs_; + std::vector attributes_; kir::Predicate* predicate_ = nullptr; @@ -530,6 +593,24 @@ bool Val::isDefinitionType() const { return false; } +#define DECLARE_CLONE_AND_CREATE \ + virtual Statement* clone(IrCloner* ir_cloner) const override; \ + virtual Expr* newObject( \ + std::vector inputs, \ + std::vector outputs, \ + std::vector attributes) const override; + +#define DEFINE_CLONE_AND_CREATE(ClassName) \ + Statement* ClassName::clone(IrCloner* ir_cloner) const { \ + return IrBuilder::clone(this, ir_cloner); \ + } \ + Expr* ClassName::newObject( \ + std::vector inputs, \ + std::vector outputs, \ + std::vector attributes) const { \ + return IrBuilder::create(inputs, outputs, attributes); \ + } + } // namespace cuda } // namespace fuser } // namespace jit diff --git 
a/torch/csrc/jit/codegen/cuda/ir_builder.cpp b/torch/csrc/jit/codegen/cuda/ir_builder.cpp index 7b58a7d444f7..76542ba6e022 100644 --- a/torch/csrc/jit/codegen/cuda/ir_builder.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_builder.cpp @@ -8,74 +8,6 @@ namespace jit { namespace fuser { namespace cuda { -//! Clone an IR node, forwarding the arguments to the IrCloner constructor. -template -T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { - TORCH_INTERNAL_ASSERT( - ir_cloner != nullptr, - "Cannot use create when a cloner object is set. Use clone."); - - TORCH_INTERNAL_ASSERT( - ir_cloner->container() != nullptr, - "Cloner doesn't have a valid container to store cloned object."); - - T* dest = new T(src, ir_cloner); - const Statement* src_stmt = dynamic_cast(src); - Statement* dest_stmt = dynamic_cast(dest); - - auto dest_container = ir_cloner->container(); - auto src_container = src_stmt->container(); - - dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); - - if (src_container != dest_container) { - dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); - } - - ir_cloner->registerClone(src_stmt, dest_stmt); - - return dest; -} - -#define IR_BUILDER_INSTANTIATE(T) \ - template T* IrBuilder::clone(const T* src, IrCloner* ir_cloner); - -// Vals -IR_BUILDER_INSTANTIATE(IterDomain) -IR_BUILDER_INSTANTIATE(TensorDomain) -IR_BUILDER_INSTANTIATE(TensorView) -IR_BUILDER_INSTANTIATE(Bool) -IR_BUILDER_INSTANTIATE(Double) -IR_BUILDER_INSTANTIATE(Int) -IR_BUILDER_INSTANTIATE(ComplexDouble) -IR_BUILDER_INSTANTIATE(NamedScalar) - -// Exprs -IR_BUILDER_INSTANTIATE(Split) -IR_BUILDER_INSTANTIATE(Merge) -IR_BUILDER_INSTANTIATE(Swizzle2D) -IR_BUILDER_INSTANTIATE(TransposeOp) -IR_BUILDER_INSTANTIATE(ExpandOp) -IR_BUILDER_INSTANTIATE(ShiftOp) -IR_BUILDER_INSTANTIATE(GatherOp) -IR_BUILDER_INSTANTIATE(ViewAsScalar) -IR_BUILDER_INSTANTIATE(ViewOp) -IR_BUILDER_INSTANTIATE(FullOp) -IR_BUILDER_INSTANTIATE(ARangeOp) -IR_BUILDER_INSTANTIATE(EyeOp) -IR_BUILDER_INSTANTIATE(UnaryOp) -IR_BUILDER_INSTANTIATE(BinaryOp) -IR_BUILDER_INSTANTIATE(TernaryOp) -IR_BUILDER_INSTANTIATE(SelectOp) -IR_BUILDER_INSTANTIATE(RNGOp) -IR_BUILDER_INSTANTIATE(ReductionOp) -IR_BUILDER_INSTANTIATE(GroupedReductionOp) -IR_BUILDER_INSTANTIATE(WelfordOp) -IR_BUILDER_INSTANTIATE(LoadStoreOp) -IR_BUILDER_INSTANTIATE(MmaOp) -IR_BUILDER_INSTANTIATE(BroadcastOp) -IR_BUILDER_INSTANTIATE(SqueezeOp) - Val* IrBuilder::newResult(DataType dtype) { switch (dtype) { case DataType::Bool: diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.h b/torch/csrc/jit/codegen/cuda/ir_builder.h index f122232f8fb8..77f7018708aa 100644 --- a/torch/csrc/jit/codegen/cuda/ir_builder.h +++ b/torch/csrc/jit/codegen/cuda/ir_builder.h @@ -2,6 +2,7 @@ #include #include +#include #include namespace torch { @@ -15,20 +16,6 @@ class Kernel; class IrCloner; -// Passkey for builder to register properties with statements, and to call -// functions in IrContainer -class TORCH_CUDA_CU_API IrBuilderPasskey { - friend class IrBuilder; - - public: - // TODO: Collapse ir_container and Kernel once Kernel inherits from - // IrContainer - IrContainer* const ir_container_ = nullptr; - - private: - explicit IrBuilderPasskey(IrContainer* ir_container); -}; - //! 
IR builder interface class TORCH_CUDA_CU_API IrBuilder { public: diff --git a/torch/csrc/jit/codegen/cuda/ir_builder_key.h b/torch/csrc/jit/codegen/cuda/ir_builder_key.h new file mode 100644 index 000000000000..95c8f21ea4f1 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/ir_builder_key.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +class IrContainer; + +// Passkey for builder to register properties with statements, and to call +// functions in IrContainer +class TORCH_CUDA_CU_API IrBuilderPasskey { + friend class IrBuilder; + + public: + // TODO: Collapse ir_container and Kernel once Kernel inherits from + // IrContainer + IrContainer* const ir_container_ = nullptr; + + private: + explicit IrBuilderPasskey(IrContainer* ir_container); +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp index 1a538f88997d..7bdbfc3774e1 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp @@ -21,12 +21,7 @@ Statement* IrCloner::clone(const Statement* statement) { if (it != clones_map_.end()) { return it->second; } else { - // Clone the new node, saving/restoring this->clone_ - // since the cloning can be reentrant - auto saved_clone = clone_; - handle(statement); - auto new_node = clone_; - clone_ = saved_clone; + auto new_node = handle(statement); // The base cloning constructor (Statement) should have // registered the new node. Failure to do so indicates @@ -44,144 +39,8 @@ void IrCloner::registerClone(const Statement* src, Statement* clone) { TORCH_CHECK(clones_map_.insert({src, clone}).second); } -void IrCloner::handle(const Statement* s) { - OptInConstDispatch::handle(s); -} - -void IrCloner::handle(const Val* v) { - OptInConstDispatch::handle(v); -} - -void IrCloner::handle(const Expr* e) { - OptInConstDispatch::handle(e); -} - -void IrCloner::handle(const TensorDomain* td) { - clone_ = IrBuilder::clone(td, this); -} - -void IrCloner::handle(const IterDomain* id) { - clone_ = IrBuilder::clone(id, this); -} - -void IrCloner::handle(const Bool* b) { - clone_ = IrBuilder::clone(b, this); -} - -void IrCloner::handle(const Double* d) { - clone_ = IrBuilder::clone(d, this); -} - -void IrCloner::handle(const Int* i) { - clone_ = IrBuilder::clone(i, this); -} - -void IrCloner::handle(const ComplexDouble* c) { - clone_ = IrBuilder::clone(c, this); -} - -void IrCloner::handle(const NamedScalar* named_scalar) { - clone_ = IrBuilder::clone(named_scalar, this); -} - -void IrCloner::handle(const TensorView* tv) { - clone_ = IrBuilder::clone(tv, this); -} - -void IrCloner::handle(const FullOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const ARangeOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const EyeOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const UnaryOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const BinaryOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const TernaryOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const SelectOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const RNGOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const BroadcastOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void 
IrCloner::handle(const SqueezeOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const ReductionOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const GroupedReductionOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const WelfordOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const LoadStoreOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const MmaOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const TransposeOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const ExpandOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const ShiftOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const GatherOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const ViewAsScalar* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const ViewOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const Split* split) { - clone_ = IrBuilder::clone(split, this); -} - -void IrCloner::handle(const Merge* merge) { - clone_ = IrBuilder::clone(merge, this); -} - -void IrCloner::handle(const Swizzle2D* swizzle) { - clone_ = IrBuilder::clone(swizzle, this); +Statement* IrCloner::handle(const Statement* s) { + return s->clone(this); } TensorView* RecomputeTv::recompute(TensorView* tv) { @@ -228,11 +87,18 @@ RecomputeTv::RecomputeTv(Fusion* fusion, std::vector exprs) } // Clone the expressions for (auto expr : exprs) { - IrCloner::handle(expr); + handle(expr); + } +} + +Statement* RecomputeTv::handle(const Statement* s) { + if (s->isA()) { + return handle(s->as()); } + return s->clone(this); } -void RecomputeTv::handle(const TensorDomain* td) { +Statement* RecomputeTv::handle(const TensorDomain* td) { // Make sure to recompute the history of the iteration domains, explicitly go // through the expressions and send them to IrCloner. auto exprs = @@ -241,7 +107,7 @@ void RecomputeTv::handle(const TensorDomain* td) { for (auto expr : exprs) { IrCloner::handle(expr); } - IrCloner::handle(td); + return IrCloner::handle(td); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h index 9e54b074acc7..e7852906b1c2 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.h +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h @@ -20,13 +20,14 @@ class IrContainer; //! Fusion copy operations and the and limited scope of RecomputeTv below. //! It is not intended for any other uses. //! 
-class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { +class TORCH_CUDA_CU_API IrCloner { friend class Statement; friend class IrBuilder; public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) explicit IrCloner(IrContainer* container); + virtual ~IrCloner() {} Statement* clone(const Statement* statement); @@ -53,46 +54,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { protected: void registerClone(const Statement* src, Statement* clone); - - void handle(const Statement*) override; - void handle(const Val*) override; - void handle(const Expr*) override; - - void handle(const TensorDomain*) override; - void handle(const TensorView*) override; - void handle(const IterDomain*) override; - - void handle(const Bool*) override; - void handle(const Double*) override; - void handle(const Int*) override; - void handle(const ComplexDouble*) override; - void handle(const NamedScalar*) override; - - void handle(const FullOp*) override; - void handle(const ARangeOp*) override; - void handle(const EyeOp*) override; - void handle(const UnaryOp*) override; - void handle(const BinaryOp*) override; - void handle(const TernaryOp*) override; - void handle(const SelectOp*) override; - void handle(const RNGOp*) override; - void handle(const BroadcastOp*) override; - void handle(const SqueezeOp*) override; - void handle(const ReductionOp*) override; - void handle(const GroupedReductionOp*) override; - void handle(const WelfordOp*) override; - void handle(const LoadStoreOp*) override; - void handle(const MmaOp*) override; - void handle(const TransposeOp*) override; - void handle(const ExpandOp*) override; - void handle(const ShiftOp*) override; - void handle(const GatherOp*) override; - void handle(const ViewAsScalar*) override; - void handle(const ViewOp*) override; - - void handle(const Split*) override; - void handle(const Merge*) override; - void handle(const Swizzle2D*) override; + virtual Statement* handle(const Statement* s); protected: // We keep track of the original -> clone map so we don't @@ -103,11 +65,6 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { // The destination Fusion container IrContainer* ir_container_ = nullptr; - // The dispatch interface doesn't allow returning values from - // individual `handle()` methods, so they are storing the - // result here - Statement* clone_ = nullptr; - // Builder to make all the new nodes IrBuilder builder_; }; @@ -122,12 +79,44 @@ class RecomputeTv : private IrCloner { private: RecomputeTv(Fusion* fusion, std::vector exprs); - - void handle(const TensorDomain*) final; + virtual Statement* handle(const Statement* s) override; + Statement* handle(const TensorDomain*); Fusion* fusion_; }; +//! Clone an IR node, forwarding the arguments to the IrCloner constructor. +template +T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { + TORCH_INTERNAL_ASSERT( + ir_cloner != nullptr, + "Cannot use create when a cloner object is set. 
Use clone."); + + TORCH_INTERNAL_ASSERT( + ir_cloner->container() != nullptr, + "Cloner doesn't have a valid container to store cloned object."); + + T* dest = new T(src, ir_cloner); + const Statement* src_stmt = dynamic_cast(src); + Statement* dest_stmt = dynamic_cast(dest); + + auto dest_container = ir_cloner->container(); + auto src_container = src_stmt->container(); + + dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); + + if (src_container != dest_container) { + dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); + } + + ir_cloner->registerClone(src_stmt, dest_stmt); + + return dest; +} + +template +DEFINE_CLONE(PlainVal) + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index b6b32e4d7a9f..5b262a426b58 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -38,6 +38,8 @@ class TORCH_CUDA_CU_API Bool : public Val { Bool(const Bool* src, IrCloner* ir_cloner); + DECLARE_CLONE + bool isSymbolic() const { return !(maybe_value_.has_value()); } @@ -69,6 +71,8 @@ class TORCH_CUDA_CU_API Double : public Val { Double(const Double* src, IrCloner* ir_cloner); + DECLARE_CLONE + bool isSymbolic() const { return !(maybe_value_.has_value()); } @@ -99,6 +103,8 @@ class TORCH_CUDA_CU_API Int : public Val { Int(const Int* src, IrCloner* ir_cloner); + DECLARE_CLONE + bool isSymbolic() const { return !(maybe_value_.has_value()); } @@ -132,6 +138,8 @@ class TORCH_CUDA_CU_API ComplexDouble : public Val { ComplexDouble(const ComplexDouble* src, IrCloner* ir_cloner); + DECLARE_CLONE + bool isSymbolic() const { return !(maybe_value_.has_value()); } @@ -209,6 +217,8 @@ class TORCH_CUDA_CU_API TensorView : public Val { TensorView(const TensorView* src, IrCloner* ir_cloner); + DECLARE_CLONE + TensorDomain* domain() const { return domain_; } diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 2d572bbe2d61..4f3a593e95e6 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -32,33 +32,29 @@ bool areEqualScalars(Val* v1, Val* v2); class TORCH_CUDA_CU_API FullOp : public Expr { public: - FullOp(IrBuilderPasskey, Val* out, Val* fill_value, DataType dtype); + using Expr::Expr; - FullOp(const FullOp* src, IrCloner* ir_cloner); + FullOp(IrBuilderPasskey, Val* out, Val* fill_value, DataType dtype); - Expr* shallowCopy() const override; - - bool sameAs(const Statement* other) const override; + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "FullOp"; } DataType dtype() const { - return dtype_; + return attribute(0)->as>()->value; } Val* getFillValue() const { - return fill_value_; + return inputs().back(); } - - private: - const DataType dtype_; - Val* fill_value_; }; class TORCH_CUDA_CU_API SelectOp : public Expr { public: + using Expr::Expr; + SelectOp( IrBuilderPasskey, Val* out, @@ -66,30 +62,25 @@ class TORCH_CUDA_CU_API SelectOp : public Expr { IterDomain* select_id, Val* index); - SelectOp(const SelectOp* src, IrCloner* ir_cloner); - - Expr* shallowCopy() const override; - - bool sameAs(const Statement* other) const override; + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "SelectOp"; } - std::unordered_map getIndexOverridingMap() const { - return {{select_id_, input(1)}}; - } - 
IterDomain* getSelectAxis() const { - return select_id_; + return attribute(0)->as(); } - private: - IterDomain* select_id_; + std::unordered_map getIndexOverridingMap() const { + return {{getSelectAxis(), input(1)}}; + } }; class TORCH_CUDA_CU_API ARangeOp : public Expr { public: + using Expr::Expr; + ARangeOp( IrBuilderPasskey, Val* out, @@ -99,46 +90,31 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr { DataType dtype, Val* linear_index = nullptr); - ARangeOp(const ARangeOp* src, IrCloner* ir_cloner); - - Expr* shallowCopy() const override; - - bool sameAs(const Statement* other) const override; + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "ARangeOp"; } DataType dtype() const { - return dtype_; + return attribute(0)->as>()->value; } Val* start() const { - return start_; + return input(0); } Val* end() const { - return end_; + return input(1); } Val* step() const { - return step_; + return input(2); } Val* getLinearLogicalIndex() const { - return linear_index_; + return attribute(1); } - - void setLinearIndex(Val* index) { - linear_index_ = index; - } - - private: - const DataType dtype_; - Val* start_; - Val* end_; - Val* step_; - Val* linear_index_ = nullptr; }; // Tensor factory for generating identity matrices like @@ -161,6 +137,8 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr { // [0, 0, 1, 0]] class TORCH_CUDA_CU_API EyeOp : public Expr { public: + using Expr::Expr; + EyeOp( IrBuilderPasskey, Val* out, @@ -168,40 +146,23 @@ class TORCH_CUDA_CU_API EyeOp : public Expr { Val* index1 = nullptr, Val* index2 = nullptr); - EyeOp(const EyeOp* src, IrCloner* ir_cloner); - - Expr* shallowCopy() const override; - - bool sameAs(const Statement* other) const override; + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "EyeOp"; } DataType dtype() const { - return dtype_; + return attribute(0)->as>()->value; } Val* getIndex1() const { - return index1_; - } - - void setIndex1(Val* index) { - index1_ = index; + return attribute(1); } Val* getIndex2() const { - return index2_; - } - - void setIndex2(Val* index) { - index2_ = index; + return attribute(2); } - - private: - const DataType dtype_; - Val* index1_ = nullptr; - Val* index2_ = nullptr; }; //! A specialization for Unary operations. Unary operations take in a single @@ -212,38 +173,26 @@ class TORCH_CUDA_CU_API EyeOp : public Expr { //! 4) split/merge class TORCH_CUDA_CU_API UnaryOp : public Expr { public: - UnaryOp( - IrBuilderPasskey, - UnaryOpType type, - Val* out, - Val* in, - int rng_offset = -1); + using Expr::Expr; - UnaryOp(const UnaryOp* src, IrCloner* ir_cloner); + UnaryOp(IrBuilderPasskey, UnaryOpType type, Val* out, Val* in); - Expr* shallowCopy() const override; + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "UnaryOp"; } Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } UnaryOpType getUnaryOpType() const { - return unary_op_type_; + return attribute(0)->as>()->value; } - - bool sameAs(const Statement* other) const override; - - private: - const UnaryOpType unary_op_type_; - Val* const out_ = nullptr; - Val* const in_ = nullptr; }; //! A specialization for Binary operations. Binary operations take in two inputs @@ -252,43 +201,89 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr { //! 
2) LT (A < B) class TORCH_CUDA_CU_API BinaryOp : public Expr { public: - BinaryOp(IrBuilderPasskey, BinaryOpType type, Val* out, Val* lhs, Val* rhs); + using Expr::Expr; - BinaryOp(const BinaryOp* src, IrCloner* ir_cloner); + BinaryOp(IrBuilderPasskey, BinaryOpType type, Val* out, Val* lhs, Val* rhs); - Expr* shallowCopy() const override; + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "BinaryOp"; } Val* out() const { - return out_; + return output(0); } Val* lhs() const { - return lhs_; + return input(0); } Val* rhs() const { - return rhs_; + return input(1); } BinaryOpType getBinaryOpType() const { - return binary_op_type_; + return attribute(0)->as>()->value; } +}; - bool sameAs(const Statement* other) const override; +class TORCH_CUDA_CU_API TernaryOp : public Expr { + public: + using Expr::Expr; - private: - const BinaryOpType binary_op_type_; - Val* const out_ = nullptr; - Val* const lhs_ = nullptr; - Val* const rhs_ = nullptr; + TernaryOp( + IrBuilderPasskey, + TernaryOpType type, + Val* out, + Val* in1, + Val* in2, + Val* in3); + + DECLARE_CLONE_AND_CREATE + + virtual const char* getOpString() const override { + return "TernaryOp"; + } + + Val* out() const { + return output(0); + } + + Val* in1() const { + return input(0); + } + Val* in2() const { + return input(1); + } + Val* in3() const { + return input(2); + } + + TernaryOpType getTernaryOpType() const { + return attribute(0)->as>()->value; + } }; //! A specialization for random number generator (RNG) operations. RNG //! operations take in no tensor input and produce a single output. class TORCH_CUDA_CU_API RNGOp : public Expr { + size_t getOutputDims() const; + public: + struct Attributes { + RNGOpType rtype; + DataType dtype; + int rng_offset; + + // TODO: Enable the following in C++20: + // bool operator==(const Attributes &other) const = default; + bool operator==(const Attributes& other) const { + return rtype == other.rtype && dtype == other.dtype && + rng_offset == other.rng_offset; + } + }; + + using Expr::Expr; + RNGOp( IrBuilderPasskey, RNGOpType type, @@ -298,62 +293,47 @@ class TORCH_CUDA_CU_API RNGOp : public Expr { int rng_offset = 0, Val* philox_index = nullptr); - RNGOp(const RNGOp* src, IrCloner* ir_cloner); - - Expr* shallowCopy() const override; + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "RNGOp"; } RNGOpType getRNGOpType() const { - return rng_op_type_; + return attribute(0)->as>()->value.rtype; } DataType dtype() const { - return dtype_; + return attribute(0)->as>()->value.dtype; } int getRNGOffset() const { - return rng_offset_; + return attribute(0)->as>()->value.rng_offset; } void setRNGOffset(int val) { - rng_offset_ = val; + attribute(0)->as>()->value.rng_offset = val; } - const std::vector& getParameters() const { - return parameters_; + std::vector getParameters() const { + return {inputs().begin() + getOutputDims(), inputs().end()}; } - const std::vector& getShape() const { - return shape_; + std::vector getShape() const { + return {inputs().begin(), inputs().begin() + getOutputDims()}; } Val* getPhiloxIndex() const { - return philox_index_; - } - - void setPhiloxIndex(Val* index) { - philox_index_ = index; + return attribute(1); } - - bool sameAs(const Statement* other) const override; - - private: - const RNGOpType rng_op_type_; - const DataType dtype_; - std::vector parameters_; - std::vector shape_; - int rng_offset_ = -1; - // The index used to feed philox's subsequence and component - Val* philox_index_ = nullptr; }; 
//! Broadcast in to match out. is_broadcast_dims are relative to out. Where //! is_broadcast_dims.size() == out->nDims(). class TORCH_CUDA_CU_API BroadcastOp : public Expr { public: + using Expr::Expr; + //! \param out The output tensor //! \param in The input tensor //! \param is_broadcast_dims True when output dim is a new broadcast domain @@ -363,42 +343,32 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr { Val* in, std::vector is_broadcast_dims); - BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "BroadcastOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } bool isBroadcastDim(size_t dim) const { - return is_broadcast_dims_.at(dim); + return getBroadcastDimFlags().at(dim); } - const std::vector& getBroadcastDimFlags() const { - return is_broadcast_dims_; - } - - bool sameAs(const Statement* other) const override; - - private: - Val* const out_ = nullptr; - Val* const in_ = nullptr; - //! The same list passed to the broadcast arithmetic op. Each //! element corresponds to an IterDomain of the output tensor and is //! true when the IterDomain is a new broadcast domain. Note //! that the output tensor may have other broadcast domains whose //! flags are false because the input tensor may already have //! broadcast domains. - const std::vector is_broadcast_dims_; + const std::vector& getBroadcastDimFlags() const { + return attribute(0)->as>>()->value; + } }; //! Squeeze in to match out. is_squeeze_dims are relative to in. Where @@ -406,6 +376,8 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr { //! broadcast. class TORCH_CUDA_CU_API SqueezeOp : public Expr { public: + using Expr::Expr; + //! \param out The output tensor //! \param in The input tensor //! \param is_squeeze_dims True when input dim is a removed broadcast domain @@ -415,42 +387,32 @@ class TORCH_CUDA_CU_API SqueezeOp : public Expr { Val* in, std::vector is_broadcast_dims); - SqueezeOp(const SqueezeOp* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "SqueezeOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } bool isSqueezeDim(size_t dim) const { - return is_squeeze_dims_.at(dim); - } - - const std::vector& getSqueezeDimFlags() const { - return is_squeeze_dims_; + return getSqueezeDimFlags().at(dim); } - bool sameAs(const Statement* other) const override; - - private: - Val* const out_ = nullptr; - Val* const in_ = nullptr; - //! The same list passed to the squeeze arithmetic op. Each //! element corresponds to an IterDomain of the input tensor and is //! true when the IterDomain is a broadcast domain that is removed in the //! output. Note that the output tensor may still contain broadcast domains //! because the input tensor may have broadcast domains that we don't want to //! remove (false flag). - const std::vector is_squeeze_dims_; + const std::vector& getSqueezeDimFlags() const { + return attribute(0)->as>>()->value; + } }; //! Reduction operation. Out is first initialized to _init. Then @@ -460,6 +422,8 @@ class TORCH_CUDA_CU_API SqueezeOp : public Expr { //! non-reduction/non-broadcast dimensions. 
class TORCH_CUDA_CU_API ReductionOp : public Expr { public: + using Expr::Expr; + ReductionOp( IrBuilderPasskey, BinaryOpType reduction_op_type, @@ -468,41 +432,29 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr { Val* in, bool is_allreduce = false); - ReductionOp(const ReductionOp* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "ReductionOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } Val* init() const { - return init_; + return attribute(0); } BinaryOpType getReductionOpType() const { - return reduction_op_type_; + return attribute(1)->as>()->value; } bool isAllreduce() const { - return is_allreduce_; + return attribute(2)->as>()->value; } - - bool sameAs(const Statement* other) const override; - - private: - const BinaryOpType reduction_op_type_; - Val* const init_ = nullptr; - Val* const out_ = nullptr; - Val* const in_ = nullptr; - //! True if broadcast is fused - bool is_allreduce_ = false; }; //! Grouped reduction operation for horizontal fusions. It works like @@ -513,61 +465,51 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr { //! significant performance impact. class TORCH_CUDA_CU_API GroupedReductionOp : public Expr { public: + using Expr::Expr; + GroupedReductionOp( IrBuilderPasskey, - std::vector reduction_op_type, + std::vector reduction_op_types, std::vector init, std::vector out, std::vector in, bool is_allreduce = false); - GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "GroupedReductionOp"; } - Expr* shallowCopy() const override; - //! Number of expressions grouped horizontally. It does not reflect //! iteration grouping. size_t numExprs() const { - return reduction_op_types_.size(); + return getReductionOpTypes().size(); } - const std::vector& initVals() const { - return init_vals_; + std::vector initVals() const { + return {attributes().begin() + 2, attributes().end()}; } Val* initVal(size_t index) const { - return init_vals_.at(index); + return attribute(2 + index); } const std::vector& getReductionOpTypes() const { - return reduction_op_types_; + return attribute(0)->as>>()->value; } BinaryOpType getReductionOpType(size_t index) const { - return reduction_op_types_.at(index); + return getReductionOpTypes().at(index); } bool isAllreduce() const { - return is_allreduce_; + return attribute(1)->as>()->value; } //! Return the index of the corresponding reduction expression for //! a given output val. int getExprIndexOfOutput(Val* output_val) const; - - bool sameAs(const Statement* other) const override; - - private: - //! Reduction ops of grouped reductions - const std::vector reduction_op_types_; - //! Initial values of grouped reductions - const std::vector init_vals_; - //! True if using the fused reduction kernel - bool is_allreduce_ = false; }; //! Average, variance and N (count) vals for Welford @@ -693,6 +635,8 @@ class TORCH_CUDA_CU_API WelfordTriplet { //! Welford Scan operation. 
class TORCH_CUDA_CU_API WelfordOp : public Expr { public: + using Expr::Expr; + WelfordOp( IrBuilderPasskey, const WelfordTriplet& output, @@ -713,70 +657,66 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { Val* init_N, bool is_fused = false); - WelfordOp(const WelfordOp* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "WelfordOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return output().avg(); + return outputTriplet().avg(); } Val* in() const { - return input().avg(); + return inputTriplet().avg(); } - bool sameAs(const Statement* const other) const override; - - const WelfordTriplet& output() const { - return output_; + WelfordTriplet outputTriplet() const { + return WelfordTriplet(outAvg(), outVar(), outN()); } Val* outAvg() const { - return output().avg(); + return output(0); } Val* outVar() const { - return output().var(); + return output(1); } Val* outN() const { - return output().N(); + return output(2); } - const WelfordTriplet& input() const { - return input_; + WelfordTriplet inputTriplet() const { + return WelfordTriplet(inAvg(), inVar(), inN()); } Val* inAvg() const { - return input().avg(); + return input(0); } Val* inVar() const { - return input().var(); + return input(1); } Val* inN() const { - return input().N(); + return input(2); } - const WelfordTriplet& init() const { - return init_; + WelfordTriplet initTriplet() const { + return WelfordTriplet(initAvg(), initVar(), initN()); } Val* initAvg() const { - return init().avg(); + return attribute(0); } Val* initVar() const { - return init().var(); + return attribute(1); } Val* initN() const { - return init().N(); + return attribute(2); } bool singleValue() const { @@ -787,25 +727,21 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { return !initN()->isZeroInt(); } + //! True if using the fused reduction kernel (not implemented yet) bool isAllreduce() const { - return is_allreduce_; + return attribute(3)->as>()->value; } std::vector getInitVals() const; //! Return the init val for an output val Val* getInitValOfOutput(Val* output_val) const; - - private: - const WelfordTriplet output_; - const WelfordTriplet input_; - const WelfordTriplet init_; - //! True if using the fused reduction kernel (not implemented yet) - bool is_allreduce_ = false; }; class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { public: + using Expr::Expr; + GroupedWelfordOp( IrBuilderPasskey, std::vector output_vals, @@ -813,14 +749,12 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { std::vector init_vals, bool is_allreduce = false); - GroupedWelfordOp(const GroupedWelfordOp* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "GroupedWelfordOp"; } - Expr* shallowCopy() const override; - //! Number of expressions grouped horizontally. It does not reflect //! iteration grouping. As horizontal grouping is not supported, //! this always returns 1. 
@@ -836,54 +770,70 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { return inAvg(index); } - bool sameAs(const Statement* const other) const override; - - const std::vector& outputVals() const { - return output_vals_; + std::vector outputVals() const { + std::vector result; + auto size = outputs().size() / 3; + result.reserve(size); + for (auto i : c10::irange(size)) { + result.emplace_back(outAvg(i), outVar(i), outN(i)); + } + return result; } - const std::vector& inputVals() const { - return input_vals_; + std::vector inputVals() const { + std::vector result; + auto size = inputs().size() / 3; + result.reserve(size); + for (auto i : c10::irange(size)) { + result.emplace_back(inAvg(i), inVar(i), inN(i)); + } + return result; } - const std::vector& initVals() const { - return init_vals_; + std::vector initVals() const { + std::vector result; + auto size = attributes().size() / 3; + result.reserve(size); + for (auto i : c10::irange(size)) { + result.emplace_back(initAvg(i), initVar(i), initN(i)); + } + return result; } Val* outAvg(size_t index) const { - return outputVals().at(index).avg(); + return output(index * 3); } Val* outVar(size_t index) const { - return outputVals().at(index).var(); + return output(index * 3 + 1); } Val* outN(size_t index) const { - return outputVals().at(index).N(); + return output(index * 3 + 2); } Val* inAvg(size_t index) const { - return inputVals().at(index).avg(); + return input(index * 3); } Val* inVar(size_t index) const { - return inputVals().at(index).var(); + return input(index * 3 + 1); } Val* inN(size_t index) const { - return inputVals().at(index).N(); + return input(index * 3 + 2); } Val* initAvg(size_t index) const { - return initVals().at(index).avg(); + return attribute(1 + index * 3); } Val* initVar(size_t index) const { - return initVals().at(index).var(); + return attribute(2 + index * 3); } Val* initN(size_t index) const { - return initVals().at(index).N(); + return attribute(3 + index * 3); } //! Return the index of the corresponding welford expression for @@ -902,15 +852,8 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { } bool isAllreduce() const { - return is_allreduce_; + return attribute(0)->as>()->value; } - - private: - const std::vector output_vals_; - const std::vector input_vals_; - const std::vector init_vals_; - //! True if using the fused reduction kernel - bool is_allreduce_ = false; }; //! 
Fused Matmul operation @@ -932,6 +875,8 @@ class TORCH_CUDA_CU_API MmaOp : public Expr { } }; + using Expr::Expr; + MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init); MmaOp( @@ -942,181 +887,104 @@ class TORCH_CUDA_CU_API MmaOp : public Expr { Val* init, OptionsInMma options); - MmaOp(const MmaOp* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "MmaOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* inA() const { - return in_a_; + return input(0); } Val* inB() const { - return in_b_; + return input(1); } Val* init() const { - return init_; + return attribute(0); } const auto& options() const { - TORCH_INTERNAL_ASSERT(options_.has_value(), "MmaOp not configured:", this); - return options_.value(); + return attribute(1)->as>()->value; } - bool sameAs(const Statement* const other) const override; - auto accStride() const { - TORCH_INTERNAL_ASSERT(options_.has_value(), "MmaOp not configured:", this); - return options_->accumulator_stride; + return options().accumulator_stride; } - void configureOptions(MmaOptions options) { - options_ = OptionsInMma(); - TORCH_INTERNAL_ASSERT( - options.macro != MmaOptions::MacroType::NoMMA, - "Un-configured mma type from options."); - TORCH_INTERNAL_ASSERT( - options.accumulator_stride > 0, "Un-configured accumulator stride."); - options_->accumulator_stride = options.accumulator_stride; - options_->macro = options.macro; - options_->operand_layout = options.operand_layout; - } - - private: - Val* const out_ = nullptr; - Val* const in_a_ = nullptr; - Val* const in_b_ = nullptr; - Val* const init_ = nullptr; - c10::optional options_ = c10::nullopt; + void configureOptions(MmaOptions options); }; class TORCH_CUDA_CU_API TransposeOp : public Expr { public: + using Expr::Expr; + TransposeOp( IrBuilderPasskey, TensorView* out, TensorView* in, std::vector new2old); - TransposeOp(const TransposeOp* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "TransposeOp"; } - Expr* shallowCopy() const override; - TensorView* out() const { - return out_; + return output(0)->as(); } TensorView* in() const { - return in_; + return input(0)->as(); } const std::vector& new2old() const { - return new2old_; + return attribute(0)->as>>()->value; } std::vector old2new() const; - - private: - TensorView* const out_ = nullptr; - TensorView* const in_ = nullptr; - const std::vector new2old_; }; class TORCH_CUDA_CU_API ExpandOp : public Expr { public: + using Expr::Expr; + ExpandOp( IrBuilderPasskey, TensorView* out, TensorView* in, std::vector _expanded_extents); - ExpandOp(const ExpandOp* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "ExpandOp"; } - Expr* shallowCopy() const override; - TensorView* out() const { - return out_; + return output(0)->as(); } TensorView* in() const { - return in_; - } - - const std::vector& expanded_extents() const { - return expanded_extents_; - } - - private: - TensorView* const out_ = nullptr; - TensorView* const in_ = nullptr; - std::vector expanded_extents_; -}; - -class TORCH_CUDA_CU_API TernaryOp : public Expr { - public: - TernaryOp( - IrBuilderPasskey, - TernaryOpType type, - Val* out, - Val* in1, - Val* in2, - Val* in3); - - TernaryOp(const TernaryOp* src, IrCloner* ir_cloner); - - virtual const char* getOpString() const override { - return "TernaryOp"; - } - - Expr* shallowCopy() const 
override; - - Val* out() const { - return out_; + return input(0)->as(); } - Val* in1() const { - return in1_; + std::vector expanded_extents() const { + return {inputs().begin() + 1, inputs().end()}; } - Val* in2() const { - return in2_; - } - Val* in3() const { - return in3_; - } - - TernaryOpType getTernaryOpType() const { - return ternary_op_type_; - } - - bool sameAs(const Statement* other) const override; - - private: - const TernaryOpType ternary_op_type_; - Val* const out_ = nullptr; - Val* const in1_ = nullptr; - Val* const in2_ = nullptr; - Val* const in3_ = nullptr; }; //! Shift class TORCH_CUDA_CU_API ShiftOp : public Expr { public: + using Expr::Expr; + //! \param out //! \param in //! \param offsets @@ -1127,54 +995,45 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr { std::vector offsets, std::vector pad_width); - ShiftOp(const ShiftOp* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "ShiftOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } int offset(size_t dim) const { - return offsets_.at(dim); + return offsets().at(dim); } + //! Each of the root axes is shifted by the corresponding value of + //! offsets. The sign of each value indicates the direction of shifting. const std::vector& offsets() const { - return offsets_; + return attribute(0)->as>>()->value; } const std::vector& padWidth() const { - return pad_width_; + return attribute(1)->as>>()->value; } bool hasPadding() const { - return std::any_of(pad_width_.begin(), pad_width_.end(), [](const auto p) { + return std::any_of(padWidth().begin(), padWidth().end(), [](const auto p) { return p > 0; }); } - - bool sameAs(const Statement* other) const override; - - private: - Val* const out_ = nullptr; - Val* const in_ = nullptr; - //! Each of the root axes is shifted by the corresponding value of - //! offsets_. The sign of each value indicates the direction of - //! shifting. - const std::vector offsets_; - const std::vector pad_width_; }; //! Gather a window around each element. class TORCH_CUDA_CU_API GatherOp : public Expr { public: + using Expr::Expr; + GatherOp( IrBuilderPasskey, Val* out, @@ -1182,51 +1041,43 @@ class TORCH_CUDA_CU_API GatherOp : public Expr { std::vector window_shape, std::vector> pad_width); - GatherOp(const GatherOp* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "GatherOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } + //! Shape of a window gathered for each element. const auto& windowShape() const { - return window_shape_; + return attribute(0)->as>>()->value; } //! Returns the gather axis that corresponds to an input axis int gatherAxis(int axis) const; + //! The size of zero-padding of each axis. const auto& padWidth() const { - return pad_width_; + return attribute(1)->as>>>()->value; } bool hasPadding() const { - return std::any_of(pad_width_.begin(), pad_width_.end(), [](const auto& p) { + return std::any_of(padWidth().begin(), padWidth().end(), [](const auto& p) { return p[0] > 0 || p[1] > 0; }); } - - bool sameAs(const Statement* other) const override; - - private: - Val* const out_ = nullptr; - Val* const in_ = nullptr; - //! Shape of a window gathered for each element. - std::vector window_shape_; - //! The size of zero-padding of each axis. 
- std::vector> pad_width_; }; class TORCH_CUDA_CU_API ViewAsScalar : public Expr { public: + using Expr::Expr; + ViewAsScalar( IrBuilderPasskey, Val* out, @@ -1234,64 +1085,50 @@ class TORCH_CUDA_CU_API ViewAsScalar : public Expr { IterDomain* vector_id, Val* index = nullptr); - ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "ViewAsScalar"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } + // The IterDomain of type VectorComponent newly appended to the output IterDomain* vector_id() const { - return vector_id_; + return attribute(0)->as(); } + // The index that vector_id_ is lowered into Val* index() const { - return index_; + return attribute(1); } - - private: - Val* const out_ = nullptr; - Val* const in_ = nullptr; - - // The IterDomain of type VectorComponent newly appended to the output - IterDomain* vector_id_ = nullptr; - - // The index that vector_id_ is lowered into - Val* index_ = nullptr; }; class TORCH_CUDA_CU_API ViewOp : public Expr { public: - ViewOp(IrBuilderPasskey, TensorView* out, TensorView* in); + using Expr::Expr; - ViewOp(const ViewOp* src, IrCloner* ir_cloner); + ViewOp(IrBuilderPasskey, Val* out, Val* in); + + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "ViewOp"; } - Expr* shallowCopy() const override; - - TensorView* out() const { - return out_; + Val* out() const { + return output(0); } - TensorView* in() const { - return in_; + Val* in() const { + return input(0); } - - private: - TensorView* const out_ = nullptr; - TensorView* const in_ = nullptr; }; //! This operator explicitly models data movement between @@ -1302,32 +1139,27 @@ class TORCH_CUDA_CU_API ViewOp : public Expr { //! accelerated memory ops, i.e. ldmatrix, cp.async and more to come. class TORCH_CUDA_CU_API LoadStoreOp : public Expr { public: + using Expr::Expr; + LoadStoreOp(IrBuilderPasskey, LoadStoreOpType op_type, Val* out, Val* in); - LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "LoadStoreOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } LoadStoreOpType opType() const { - return load_store_type_; + return attribute(0)->as>()->value; } - - private: - LoadStoreOpType load_store_type_ = LoadStoreOpType::LdMatrix; - Val* const out_ = nullptr; - Val* const in_ = nullptr; }; // Convenience utility to initialize IterDomain's without having to sort through @@ -1407,6 +1239,8 @@ class TORCH_CUDA_CU_API IterDomain : public Val { IterDomain(const IterDomain* src, IrCloner* ir_cloner); + DECLARE_CLONE + bool sameAs(const Statement* other) const override; //! Returns a new IterDomain matching properties of this @@ -1714,6 +1548,8 @@ class TORCH_CUDA_CU_API TensorDomain : public Val { TensorDomain(const TensorDomain* src, IrCloner* ir_cloner); + DECLARE_CLONE + bool operator==(const TensorDomain& other) const; bool operator!=(const TensorDomain& other) const { return !(*this == other); @@ -1902,6 +1738,8 @@ class TORCH_CUDA_CU_API TensorDomain : public Val { //! remainer or outside. class TORCH_CUDA_CU_API Split : public Expr { public: + using Expr::Expr; + // start_offset and stop_offset are used to express partial // split. 
Only the partial domain from start_offset to stop_offset // is split and the outer sub-regions are ignored. Note that both @@ -1917,58 +1755,45 @@ class TORCH_CUDA_CU_API Split : public Expr { Val* start_offset = nullptr, Val* stop_offset = nullptr); - Split(const Split* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "Split"; } - Expr* shallowCopy() const override; - IterDomain* outer() const { - return outer_; + return output(0)->as(); } IterDomain* inner() const { - return inner_; + return output(1)->as(); } IterDomain* in() const { - return in_; + return input(0)->as(); } Val* factor() const { - return factor_; + return attribute(0); } bool innerSplit() const { - return inner_split_; + return attribute(1)->as>()->value; } + //! Start position of the input domain. Non-zero means partial + //! split. Elements until this offset are ignored. Val* startOffset() const { - TORCH_INTERNAL_ASSERT(start_offset_ != nullptr); - return start_offset_; + TORCH_INTERNAL_ASSERT(attribute(2) != nullptr); + return attribute(2); } + //! Offset from extent of the input domain. Non-zero means partial + //! split. Elements after this offset are ignored. Val* stopOffset() const { - TORCH_INTERNAL_ASSERT(stop_offset_ != nullptr); - return stop_offset_; + TORCH_INTERNAL_ASSERT(attribute(3) != nullptr); + return attribute(3); } //! Utility function to compute the split extent. static Val* extent(Val* in_extent, Val* start_offset, Val* stop_offset); - - bool sameAs(const Statement* other) const override; - - private: - IterDomain* const outer_ = nullptr; - IterDomain* const inner_ = nullptr; - IterDomain* const in_ = nullptr; - Val* const factor_ = nullptr; - bool inner_split_ = true; - //! Start position of the input domain. Non-zero means partial - //! split. Elements until this offset are ignored. - Val* const start_offset_ = nullptr; - //! Offset from extent of the input domain. Non-zero means partial - //! split. Elements after this offset are ignored. - Val* const stop_offset_ = nullptr; }; //! Merge the IterDomains outer and inner into one domain, outer and inner @@ -1977,41 +1802,36 @@ class TORCH_CUDA_CU_API Split : public Expr { //! strategy if there is one class TORCH_CUDA_CU_API Merge : public Expr { public: + using Expr::Expr; + Merge( IrBuilderPasskey, IterDomain* out, IterDomain* outer, IterDomain* inner); - Merge(const Merge* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "Merge"; } - Expr* shallowCopy() const override; - IterDomain* out() const { - return out_; + return output(0)->as(); } IterDomain* outer() const { - return outer_; + return input(0)->as(); } IterDomain* inner() const { - return inner_; + return input(1)->as(); } - - bool sameAs(const Statement* other) const override; - - private: - IterDomain* const out_ = nullptr; - IterDomain* const outer_ = nullptr; - IterDomain* const inner_ = nullptr; }; //! Applies 2D swizzles on a rectangular tile defined by 2 iterdomains. 
class TORCH_CUDA_CU_API Swizzle2D : public Expr { public: + using Expr::Expr; + Swizzle2D( IrBuilderPasskey, IterDomain* out_x, @@ -2021,53 +1841,36 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr { Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle, SwizzleMode swizzle_mode = SwizzleMode::Data); - Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner); + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "Swizzle2D"; } - Expr* shallowCopy() const override; - + // Output iterdomain pair corresponding + // to the original input iterdomain pair. IterDomain* outX() const { - return out_x_; + return output(0)->as(); } IterDomain* outY() const { - return out_y_; + return output(1)->as(); } + // Input iterdomain pair. IterDomain* inX() const { - return in_x_; + return input(0)->as(); } IterDomain* inY() const { - return in_y_; - } - - auto swizzleType() const { - return swizzle_type_; + return input(1)->as(); } - auto swizzleMode() const { - return swizzle_mode_; - } - - bool sameAs(const Statement* other) const override; - - private: - // Output iterdomain pair corresponding - // to the original input iterdomain pair. - IterDomain* const out_x_ = nullptr; - IterDomain* const out_y_ = nullptr; - - // Input iterdomain pair. - IterDomain* const in_x_ = nullptr; - IterDomain* const in_y_ = nullptr; - // The type of predefined 1-to-1 functions // used for swizzling math. - Swizzle2DType swizzle_type_ = Swizzle2DType::NoSwizzle; + auto swizzleType() const { + return attribute(0)->as>()->value; + } // Swizzle mode of this swizzle instance. // [Note on swizzle mode] @@ -2110,7 +1913,9 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr { // } // TODO: Loop swizzles eventually will be piped through in all mappings // and replay of the fusion IR infrastructure. - SwizzleMode swizzle_mode_ = SwizzleMode::Data; + auto swizzleMode() const { + return attribute(1)->as>()->value; + } }; //! 
Integer value which has a special name @@ -2127,6 +1932,8 @@ class TORCH_CUDA_CU_API NamedScalar : public Val { NamedScalar(const NamedScalar* src, IrCloner* ir_cloner); + DECLARE_CLONE + const std::string& name() const { return name_; } diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 12c634d3e0af..f871da903fda 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -85,6 +85,8 @@ Bool::Bool(IrBuilderPasskey passkey, c10::optional value) Bool::Bool(const Bool* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} +DEFINE_CLONE(Bool) + bool Bool::sameAs(const Statement* other) const { if (this == other) { return true; @@ -125,6 +127,8 @@ bool Double::sameAs(const Statement* other) const { return false; } +DEFINE_CLONE(Double) + Int::Int(IrBuilderPasskey passkey) : Val(passkey, ValType::Scalar, DataType::Int), maybe_value_{c10::nullopt} {} @@ -138,6 +142,8 @@ Int::Int(IrBuilderPasskey passkey, c10::optional value) Int::Int(const Int* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} +DEFINE_CLONE(Int) + bool Int::sameAs(const Statement* other) const { if (this == other) { return true; @@ -169,6 +175,8 @@ ComplexDouble::ComplexDouble( ComplexDouble::ComplexDouble(const ComplexDouble* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} +DEFINE_CLONE(ComplexDouble) + bool ComplexDouble::sameAs(const Statement* other) const { if (this == other) { return true; @@ -187,7 +195,7 @@ FullOp::FullOp( Val* out, Val* fill_value, DataType dtype) - : Expr(passkey), dtype_(dtype), fill_value_(fill_value) { + : Expr(passkey) { if (out->isA()) { auto tv_root = out->as()->getRootDomain(); for (auto id : tv_root) { @@ -196,32 +204,11 @@ FullOp::FullOp( } addInput(fill_value); addOutput(out); + addAttribute( + IrBuilder::create>(passkey.ir_container_, dtype)); } -FullOp::FullOp(const FullOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - dtype_(src->dtype()), - fill_value_(ir_cloner->clone(src->fill_value_)) {} - -Expr* FullOp::shallowCopy() const { - auto result = IrBuilder::create(output(0), fill_value_, dtype_); - result->copyPredicatesFrom(this); - return result; -} - -bool FullOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (dtype_ != other_op->dtype_) { - return false; - } - return Expr::sameAs(other); -} +DEFINE_CLONE_AND_CREATE(FullOp) SelectOp::SelectOp( IrBuilderPasskey passkey, @@ -229,35 +216,15 @@ SelectOp::SelectOp( Val* in, IterDomain* select_id, Val* index) - : Expr(passkey), select_id_(select_id) { + : Expr(passkey) { addInput(in); addInput(index); addOutput(out); + addAttribute(select_id); + addAttribute(index); } -SelectOp::SelectOp(const SelectOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), select_id_(ir_cloner->clone(src->select_id_)) {} - -Expr* SelectOp::shallowCopy() const { - auto result = - IrBuilder::create(output(0), input(0), select_id_, input(1)); - result->copyPredicatesFrom(this); - return result; -} - -bool SelectOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (!select_id_->sameAs(other_op->select_id_)) { - return false; - } - return Expr::sameAs(other); -} +DEFINE_CLONE_AND_CREATE(SelectOp) ARangeOp::ARangeOp( IrBuilderPasskey 
passkey, @@ -267,62 +234,17 @@ ARangeOp::ARangeOp( Val* step, DataType dtype, Val* linear_index) - : Expr(passkey), - dtype_(dtype), - start_(start), - end_(end), - step_(step), - linear_index_(linear_index) { + : Expr(passkey) { addInput(start); addInput(end); addInput(step); addOutput(out); + addAttribute( + IrBuilder::create>(passkey.ir_container_, dtype)); + addAttribute(linear_index); } -ARangeOp::ARangeOp(const ARangeOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - dtype_(src->dtype()), - start_(ir_cloner->clone(src->start_)), - end_(ir_cloner->clone(src->end_)), - step_(ir_cloner->clone(src->step_)), - linear_index_(ir_cloner->clone(src->linear_index_)) {} - -Expr* ARangeOp::shallowCopy() const { - auto result = IrBuilder::create( - output(0), start_, end_, step_, dtype_, linear_index_); - result->copyPredicatesFrom(this); - return result; -} - -bool ARangeOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (dtype_ != other_op->dtype_) { - return false; - } - if (!start_->sameAs(other_op->start_)) { - return false; - } - if (!end_->sameAs(other_op->end_)) { - return false; - } - if (!step_->sameAs(other_op->step_)) { - return false; - } - if ((linear_index_ == nullptr) != (other_op->linear_index_ == nullptr)) { - return false; - } - if ((linear_index_ != nullptr) && - !linear_index_->sameAs(other_op->linear_index_)) { - return false; - } - return Expr::sameAs(other); -} +DEFINE_CLONE_AND_CREATE(ARangeOp) EyeOp::EyeOp( IrBuilderPasskey passkey, @@ -330,7 +252,7 @@ EyeOp::EyeOp( DataType dtype, Val* index1, Val* index2) - : Expr(passkey), dtype_(dtype), index1_(index1), index2_(index2) { + : Expr(passkey) { if (out->isA()) { addInput(out->as()->getRootDomain()[0]->extent()); if (out->as()->getRootDomain()[1] != @@ -339,82 +261,23 @@ EyeOp::EyeOp( } } addOutput(out); + addAttribute( + IrBuilder::create>(passkey.ir_container_, dtype)); + addAttribute(index1); + addAttribute(index2); } -EyeOp::EyeOp(const EyeOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - dtype_(src->dtype_), - index1_(ir_cloner->clone(src->index1_)), - index2_(ir_cloner->clone(src->index2_)) {} - -Expr* EyeOp::shallowCopy() const { - auto result = IrBuilder::create(output(0), dtype_, index1_, index2_); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(EyeOp) -bool EyeOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (dtype_ != other_op->dtype_) { - return false; - } - if ((index1_ == nullptr) != (other_op->index1_ == nullptr)) { - return false; - } - if ((index2_ == nullptr) != (other_op->index2_ == nullptr)) { - return false; - } - if ((index1_ != nullptr) && !index1_->sameAs(other_op->index1_)) { - return false; - } - if ((index2_ != nullptr) && !index2_->sameAs(other_op->index2_)) { - return false; - } - return Expr::sameAs(other); -} - -UnaryOp::UnaryOp( - IrBuilderPasskey passkey, - UnaryOpType type, - Val* out, - Val* in, - int rng_offset) - : Expr(passkey), unary_op_type_{type}, out_{out}, in_{in} { +UnaryOp::UnaryOp(IrBuilderPasskey passkey, UnaryOpType type, Val* out, Val* in) + : Expr(passkey) { addOutput(out); addInput(in); + addAttribute( + IrBuilder::create>(passkey.ir_container_, type)); } -UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - unary_op_type_(src->unary_op_type_), - 
out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)) {} - -Expr* UnaryOp::shallowCopy() const { - auto result = IrBuilder::create(unary_op_type_, out_, in_); - result->copyPredicatesFrom(this); - return result; -} - -bool UnaryOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (getUnaryOpType() != other_op->getUnaryOpType()) { - return false; - } - return Expr::sameAs(other); -} +DEFINE_CLONE_AND_CREATE(UnaryOp) BinaryOp::BinaryOp( IrBuilderPasskey passkey, @@ -422,38 +285,15 @@ BinaryOp::BinaryOp( Val* out, Val* lhs, Val* rhs) - : Expr(passkey), binary_op_type_{type}, out_{out}, lhs_{lhs}, rhs_{rhs} { + : Expr(passkey) { addOutput(out); addInput(lhs); addInput(rhs); + addAttribute( + IrBuilder::create>(passkey.ir_container_, type)); } -BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - binary_op_type_(src->binary_op_type_), - out_(ir_cloner->clone(src->out_)), - lhs_(ir_cloner->clone(src->lhs_)), - rhs_(ir_cloner->clone(src->rhs_)) {} - -Expr* BinaryOp::shallowCopy() const { - auto result = IrBuilder::create(binary_op_type_, out_, lhs_, rhs_); - result->copyPredicatesFrom(this); - return result; -} - -bool BinaryOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (getBinaryOpType() != other_op->getBinaryOpType()) { - return false; - } - return Expr::sameAs(other); -} +DEFINE_CLONE_AND_CREATE(BinaryOp) TernaryOp::TernaryOp( IrBuilderPasskey passkey, @@ -462,46 +302,16 @@ TernaryOp::TernaryOp( Val* in1, Val* in2, Val* in3) - : Expr(passkey), - ternary_op_type_{type}, - out_{out}, - in1_{in1}, - in2_{in2}, - in3_{in3} { + : Expr(passkey) { addOutput(out); addInput(in1); addInput(in2); addInput(in3); + addAttribute( + IrBuilder::create>(passkey.ir_container_, type)); } -TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - ternary_op_type_(src->ternary_op_type_), - out_(ir_cloner->clone(src->out_)), - in1_(ir_cloner->clone(src->in1_)), - in2_(ir_cloner->clone(src->in2_)), - in3_(ir_cloner->clone(src->in3_)) {} - -Expr* TernaryOp::shallowCopy() const { - auto result = - IrBuilder::create(ternary_op_type_, out_, in1_, in2_, in3_); - result->copyPredicatesFrom(this); - return result; -} - -bool TernaryOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (getTernaryOpType() != other_op->getTernaryOpType()) { - return false; - } - return Expr::sameAs(other); -} +DEFINE_CLONE_AND_CREATE(TernaryOp) RNGOp::RNGOp( IrBuilderPasskey passkey, @@ -511,74 +321,31 @@ RNGOp::RNGOp( std::vector parameters, int rng_offset, Val* philox_index) - : Expr(passkey), - rng_op_type_(type), - dtype_(dtype), - parameters_(std::move(parameters)), - rng_offset_(rng_offset), - philox_index_(philox_index) { - if (out->isA()) { - for (auto id : out->as()->getRootDomain()) { - shape_.emplace_back(id->extent()); + : Expr(passkey) { + if (auto tv_out = dynamic_cast(out)) { + for (auto id : tv_out->getRootDomain()) { + TORCH_CHECK(!id->isReduction(), "Output of RNGOp can not have reduction"); + addInput(id->extent()); } } - for (auto v : shape_) { - addInput(v); - } - for (auto v : parameters_) { + for (auto v : parameters) { addInput(v); } addOutput(out); + RNGOp::Attributes attr{type, 
dtype, rng_offset}; + addAttribute(IrBuilder::create>( + passkey.ir_container_, attr)); + addAttribute(philox_index); } -RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - rng_op_type_(src->rng_op_type_), - dtype_(src->dtype()), - parameters_(ir_cloner->clone(src->parameters_)), - rng_offset_(src->rng_offset_), - philox_index_(ir_cloner->clone(src->philox_index_)) {} - -Expr* RNGOp::shallowCopy() const { - auto result = IrBuilder::create( - rng_op_type_, output(0), dtype_, parameters_, rng_offset_, philox_index_); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(RNGOp) -bool RNGOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (getRNGOpType() != other_op->getRNGOpType()) { - return false; - } - if (dtype_ != other_op->dtype_) { - return false; - } - if (parameters_.size() != other_op->parameters_.size()) { - return false; - } - for (auto i : c10::irange(parameters_.size())) { - if (!parameters_[i]->sameAs(other_op->parameters_[i])) { - return false; - } - } - if (getRNGOffset() != other_op->getRNGOffset()) { - return false; - } - if ((philox_index_ == nullptr) != (other_op->philox_index_ == nullptr)) { - return false +size_t RNGOp::getOutputDims() const { + size_t ndims = 0; + if (auto tv_out = dynamic_cast(output(0))) { + ndims = tv_out->getRootDomain().size(); } - if ((philox_index_ != nullptr) && - !philox_index_->sameAs(other_op->philox_index_)) { - return false; - } - return Expr::sameAs(other); + return ndims; } BroadcastOp::BroadcastOp( IrBuilderPasskey passkey, Val* out, Val* in, std::vector is_broadcast_dims) - : Expr(passkey), - out_(out), - in_(in), - is_broadcast_dims_(std::move(is_broadcast_dims)) { + : Expr(passkey) { auto out_type = out->getValType().value(); auto in_type = in->getValType().value(); @@ -600,6 +364,8 @@ BroadcastOp::BroadcastOp( addOutput(out); addInput(in); + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(is_broadcast_dims))); if (!out->isA() || !in->isA()) { return; @@ -610,13 +376,13 @@ BroadcastOp::BroadcastOp( auto in_dom = TensorDomain::noReductions(in_tv->getMaybeRFactorDomain()); auto& out_dom = out_tv->getRootDomain(); TORCH_INTERNAL_ASSERT( - is_broadcast_dims_.size() == out_dom.size(), + is_broadcast_dims.size() == out_dom.size(), "The dimensions of the output tensor do not match is_broadcast_dims"); - auto out_size = is_broadcast_dims_.size(); + auto out_size = is_broadcast_dims.size(); auto num_new_broadcasts = 0; for (const auto i : c10::irange(out_size)) { - if (is_broadcast_dims_[i]) { + if (is_broadcast_dims[i]) { num_new_broadcasts++; auto id = out_dom[i]; TORCH_INTERNAL_ASSERT( @@ -640,41 +406,14 @@ BroadcastOp::BroadcastOp( "The dimensions of the output tensor do not match is_broadcast_dims and the input tensor"); } -BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - is_broadcast_dims_(src->is_broadcast_dims_) {} - -Expr* BroadcastOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_, is_broadcast_dims_); - result->copyPredicatesFrom(this); - return result; -} - -bool BroadcastOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (getBroadcastDimFlags() !=
other_op->getBroadcastDimFlags()) { - return false; - } - return Expr::sameAs(other); -} +DEFINE_CLONE_AND_CREATE(BroadcastOp) SqueezeOp::SqueezeOp( IrBuilderPasskey passkey, Val* out, Val* in, std::vector is_squeeze_dims) - : Expr(passkey), - out_(out), - in_(in), - is_squeeze_dims_(std::move(is_squeeze_dims)) { + : Expr(passkey) { auto out_type = out->getValType().value(); auto in_type = in->getValType().value(); @@ -685,6 +424,8 @@ SqueezeOp::SqueezeOp( addOutput(out); addInput(in); + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(is_squeeze_dims))); if (!out->isA() || !in->isA()) { return; @@ -695,13 +436,13 @@ SqueezeOp::SqueezeOp( auto in_dom = TensorDomain::noReductions(in_tv->getMaybeRFactorDomain()); auto& out_dom = out_tv->getRootDomain(); TORCH_INTERNAL_ASSERT( - is_squeeze_dims_.size() == in_dom.size(), + is_squeeze_dims.size() == in_dom.size(), "The dimensions of the input tensor do not match is_squeeze_dims"); - auto in_size = is_squeeze_dims_.size(); + auto in_size = is_squeeze_dims.size(); auto num_removed_broadcasts = 0; - for (const auto i : c10::irange(is_squeeze_dims_.size())) { - if (is_squeeze_dims_[i]) { + for (const auto i : c10::irange(is_squeeze_dims.size())) { + if (is_squeeze_dims[i]) { num_removed_broadcasts++; auto id = in_dom[i]; TORCH_INTERNAL_ASSERT( @@ -723,31 +464,7 @@ SqueezeOp::SqueezeOp( "The dimensions of the output tensor do not match is_squeeze_dims and the input tensor"); } -SqueezeOp::SqueezeOp(const SqueezeOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - is_squeeze_dims_(src->is_squeeze_dims_) {} - -Expr* SqueezeOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_, is_squeeze_dims_); - result->copyPredicatesFrom(this); - return result; -} - -bool SqueezeOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (getSqueezeDimFlags() != other_op->getSqueezeDimFlags()) { - return false; - } - return Expr::sameAs(other); -} +DEFINE_CLONE_AND_CREATE(SqueezeOp) ReductionOp::ReductionOp( IrBuilderPasskey passkey, BinaryOpType reduction_op_type, Val* init, Val* out, Val* in, bool is_allreduce) - : Expr(passkey), - reduction_op_type_(reduction_op_type), - init_(init), - out_(out), - in_(in), - is_allreduce_(is_allreduce) { + : Expr(passkey) { TORCH_CHECK( out->getValType().value() == ValType::TensorView || out->getValType().value() == ValType::TensorIndex); @@ -786,37 +498,14 @@ ReductionOp::ReductionOp( addOutput(out); addInput(in); + addAttribute(init); + addAttribute(IrBuilder::create>( + passkey.ir_container_, reduction_op_type)); + addAttribute( + IrBuilder::create>(passkey.ir_container_, is_allreduce)); } -ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - reduction_op_type_(src->reduction_op_type_), - init_(ir_cloner->clone(src->init_)), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - is_allreduce_(src->is_allreduce_) {} - -Expr* ReductionOp::shallowCopy() const { - auto result = IrBuilder::create( - reduction_op_type_, init_, out_, in_, is_allreduce_); - result->copyPredicatesFrom(this); - return result; -} - -bool ReductionOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - // Note that init is not part of input vals, so
it must be checked separately. - return ( - Expr::sameAs(other) && - getReductionOpType() == other_op->getReductionOpType() && - init()->sameAs(other_op->init())); -} +DEFINE_CLONE_AND_CREATE(ReductionOp) GroupedReductionOp::GroupedReductionOp( IrBuilderPasskey passkey, @@ -825,10 +514,7 @@ GroupedReductionOp::GroupedReductionOp( std::vector outputs, std::vector inputs, bool is_fused) - : Expr(passkey), - reduction_op_types_(std::move(reduction_op_types)), - init_vals_(std::move(init_vals)), - is_allreduce_(is_fused) { + : Expr(passkey) { for (auto out : outputs) { addOutput(out); } @@ -836,23 +522,19 @@ GroupedReductionOp::GroupedReductionOp( for (auto in : inputs) { addInput(in); } -} -GroupedReductionOp::GroupedReductionOp( - const GroupedReductionOp* src, - IrCloner* ir_cloner) - : Expr(src, ir_cloner), - reduction_op_types_(src->reduction_op_types_), - init_vals_(ir_cloner->clone(src->init_vals_)), - is_allreduce_(src->is_allreduce_) {} + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(reduction_op_types))); + addAttribute( + IrBuilder::create>(passkey.ir_container_, is_fused)); -Expr* GroupedReductionOp::shallowCopy() const { - auto result = IrBuilder::create( - reduction_op_types_, init_vals_, outputs(), inputs(), is_allreduce_); - result->copyPredicatesFrom(this); - return result; + for (auto init : init_vals) { + addAttribute(init); + } } +DEFINE_CLONE_AND_CREATE(GroupedReductionOp) + int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const { auto it = std::find(outputs().begin(), outputs().end(), output_val); if (it != outputs().end()) { @@ -863,28 +545,34 @@ int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const { false, "Not an output, ", output_val->toString(), ", of ", toString()); } -bool GroupedReductionOp::sameAs(const Statement* other) const { - if (this == other) { - return true; +c10::optional WelfordTriplet::getNameOf( + Val* val) const { + auto it = std::find(begin(), end(), val); + if (it != end()) { + return indexToValName(std::distance(begin(), it)); } - auto grouped_rop = dynamic_cast(other); - if (grouped_rop == nullptr) { - return false; - } + return c10::optional(); +} - if (!Expr::sameAs(other) || - getReductionOpTypes() != grouped_rop->getReductionOpTypes()) { - return false; - } +bool WelfordTriplet::sameAs(const WelfordTriplet& other) const { + return this == &other || + (avg()->sameAs(other.avg()) && var()->sameAs(other.var()) && + N()->sameAs(other.N())); +} - for (const auto i : c10::irange(numExprs())) { - if (!initVal(i)->sameAs(grouped_rop->initVal(i))) { - return false; - } - } +WelfordTriplet WelfordTriplet::clone(IrCloner* ir_cloner) const { + return transform([&](const Val* val) { return ir_cloner->clone(val); }); +} - return true; +std::vector WelfordTriplet::clone( + const std::vector& src, + IrCloner* ir_cloner) { + std::vector cloned; + for (const auto& triplet : src) { + cloned.emplace_back(triplet.clone(ir_cloner)); + } + return cloned; } WelfordOp::WelfordOp( @@ -893,11 +581,7 @@ WelfordOp::WelfordOp( const WelfordTriplet& input, const WelfordTriplet& init, bool is_fused) - : Expr(passkey), - output_(output), - input_(input), - init_(init), - is_allreduce_(is_fused) { + : Expr(passkey) { // Previously, nullptr was accepted and implicitly replaced by // default values. Looks like we always pass some non-null values, // so removed the implicit default behavior for code simplicity. 
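// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): every Expr subclass above is
// being migrated to the same convention. All state moves into inputs(),
// outputs(), and attributes(), so the hand-written copy constructor,
// shallowCopy(), and sameAs() can be replaced by the generic
// DECLARE_CLONE_AND_CREATE / DEFINE_CLONE_AND_CREATE pair. A hypothetical op
// following the pattern (MyScaleOp and its members are invented for
// illustration; assumes the usual nvFuser headers such as ir_builder.h):
//
//   class TORCH_CUDA_CU_API MyScaleOp : public Expr {
//    public:
//     using Expr::Expr;
//
//     MyScaleOp(IrBuilderPasskey passkey, Val* out, Val* in, double factor)
//         : Expr(passkey) {
//       addOutput(out);
//       addInput(in);
//       // Non-IR data rides along as an Attribute<T> node.
//       addAttribute(IrBuilder::create<Attribute<double>>(
//           passkey.ir_container_, factor));
//     }
//
//     DECLARE_CLONE_AND_CREATE
//
//     virtual const char* getOpString() const override {
//       return "MyScaleOp";
//     }
//
//     Val* out() const { return output(0); }
//     Val* in() const { return input(0); }
//     double factor() const {
//       return attribute(0)->as<Attribute<double>>()->value;
//     }
//   };
//
// With state expressed this way, a generic Expr::sameAs that compares op
// kind, inputs, and attributes subsumes the per-op overrides deleted in this
// diff.
// ---------------------------------------------------------------------------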
@@ -931,74 +615,50 @@ WelfordOp::WelfordOp( // initial value with a count of 1 is un-common enough that I'll push // the responsibility of creating all-zero var tensors to the user TORCH_INTERNAL_ASSERT( - init_.avg()->getValType().value() == ValType::TensorView || - init_.avg()->getValType().value() == ValType::TensorIndex); + init.avg()->getValType().value() == ValType::TensorView || + init.avg()->getValType().value() == ValType::TensorIndex); TORCH_INTERNAL_ASSERT( - init_.var()->getValType().value() == ValType::TensorView || - init_.var()->getValType().value() == ValType::TensorIndex, + init.var()->getValType().value() == ValType::TensorView || + init.var()->getValType().value() == ValType::TensorIndex, "Invalid initial var: ", - init_.var()->toString()); + init.var()->toString()); } // check input TORCH_INTERNAL_ASSERT( - input_.avg()->getValType().value() == ValType::TensorView || - input_.avg()->getValType().value() == ValType::TensorIndex, - input_.avg()->getValType().value()); + input.avg()->getValType().value() == ValType::TensorView || + input.avg()->getValType().value() == ValType::TensorIndex, + input.avg()->getValType().value()); TORCH_INTERNAL_ASSERT( - input_.N()->getValType().value() == ValType::Scalar || - input_.N()->getValType().value() == ValType::TensorView || - input_.N()->getValType().value() == ValType::TensorIndex); - TORCH_INTERNAL_ASSERT(isIntegralType(input_.N()->dtype())); - if (!input_.N()->isOneInt()) { + input.N()->getValType().value() == ValType::Scalar || + input.N()->getValType().value() == ValType::TensorView || + input.N()->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT(isIntegralType(input.N()->dtype())); + if (!input.N()->isOneInt()) { // when input is only one value, only the value is required through avg // input the var part is implicitly 0 and codegen will handle that. 
TORCH_INTERNAL_ASSERT( - input_.var()->getValType().value() == ValType::TensorView || - input_.var()->getValType().value() == ValType::TensorIndex); + input.var()->getValType().value() == ValType::TensorView || + input.var()->getValType().value() == ValType::TensorIndex); } else { TORCH_INTERNAL_ASSERT( - input_.var() == nullptr || input_.var()->isZeroInt(), + input.var() == nullptr || input.var()->isZeroInt(), "Invalid var input, which must be either nullptr or scalar zero when the N input is one."); } - addOutput(output_.avg()); - addOutput(output_.var()); - addOutput(output_.N()); + addOutput(output.avg()); + addOutput(output.var()); + addOutput(output.N()); - addInput(input_.avg()); - addInput(input_.var()); - addInput(input_.N()); -} - -c10::optional WelfordTriplet::getNameOf( - Val* val) const { - auto it = std::find(begin(), end(), val); - if (it != end()) { - return indexToValName(std::distance(begin(), it)); - } + addInput(input.avg()); + addInput(input.var()); + addInput(input.N()); - return c10::optional(); -} - -bool WelfordTriplet::sameAs(const WelfordTriplet& other) const { - return this == &other || - (avg()->sameAs(other.avg()) && var()->sameAs(other.var()) && - N()->sameAs(other.N())); -} - -WelfordTriplet WelfordTriplet::clone(IrCloner* ir_cloner) const { - return transform([&](const Val* val) { return ir_cloner->clone(val); }); -} - -std::vector WelfordTriplet::clone( - const std::vector& src, - IrCloner* ir_cloner) { - std::vector cloned; - for (const auto& triplet : src) { - cloned.emplace_back(triplet.clone(ir_cloner)); - } - return cloned; + addAttribute(init.avg()); + addAttribute(init.var()); + addAttribute(init.N()); + addAttribute( + IrBuilder::create>(passkey.ir_container_, is_fused)); } WelfordOp::WelfordOp( @@ -1020,22 +680,10 @@ WelfordOp::WelfordOp( WelfordTriplet(init_avg, init_var, init_N), is_fused) {} -WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - output_(src->output_.clone(ir_cloner)), - input_(src->input_.clone(ir_cloner)), - init_(src->init_.clone(ir_cloner)), - is_allreduce_(src->is_allreduce_) {} - -Expr* WelfordOp::shallowCopy() const { - auto result = - IrBuilder::create(output_, input_, init_, is_allreduce_); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(WelfordOp) Val* WelfordOp::getInitValOfOutput(Val* output_val) const { - auto val_name = output().getNameOf(output_val); + auto val_name = outputTriplet().getNameOf(output_val); TORCH_INTERNAL_ASSERT( val_name.has_value(), @@ -1044,21 +692,11 @@ Val* WelfordOp::getInitValOfOutput(Val* output_val) const { " of ", toString()); - return init().get(*val_name); -} - -bool WelfordOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (auto other_wop = dynamic_cast(other)) { - return input_.sameAs(other_wop->input_) && init_.sameAs(other_wop->init_); - } - return false; + return initTriplet().get(*val_name); } std::vector WelfordOp::getInitVals() const { - std::vector init_vals({init_.avg(), init_.var(), init_.N()}); + std::vector init_vals({initAvg(), initVar(), initN()}); return init_vals; } @@ -1068,43 +706,39 @@ GroupedWelfordOp::GroupedWelfordOp( std::vector input_vals, std::vector init_vals, bool is_allreduce) - : Expr(passkey), - output_vals_(std::move(output_vals)), - input_vals_(std::move(input_vals)), - init_vals_(std::move(init_vals)), - is_allreduce_(is_allreduce) { - const auto num_grouped_ops = output_vals_.size(); + : Expr(passkey) { + const auto num_grouped_ops = 
output_vals.size(); TORCH_INTERNAL_ASSERT( - input_vals_.size() == num_grouped_ops, + input_vals.size() == num_grouped_ops, "Invalid number of input arguments. Expected: ", num_grouped_ops, ", Given: ", - input_vals_.size()); + input_vals.size()); TORCH_INTERNAL_ASSERT( - init_vals_.size() == num_grouped_ops, + init_vals.size() == num_grouped_ops, "Invalid number of N arguments. Expected: ", num_grouped_ops, ", Given: ", - init_vals_.size()); + init_vals.size()); for (const auto i : c10::irange(num_grouped_ops)) { // Check output type TORCH_INTERNAL_ASSERT( - output_vals_[i].avg()->getValType().value() == ValType::TensorView || - output_vals_[i].avg()->getValType().value() == ValType::TensorIndex); + output_vals[i].avg()->getValType().value() == ValType::TensorView || + output_vals[i].avg()->getValType().value() == ValType::TensorIndex); TORCH_INTERNAL_ASSERT( - output_vals_[i].var()->getValType().value() == ValType::TensorView || - output_vals_[i].var()->getValType().value() == ValType::TensorIndex); + output_vals[i].var()->getValType().value() == ValType::TensorView || + output_vals[i].var()->getValType().value() == ValType::TensorIndex); TORCH_INTERNAL_ASSERT( - output_vals_[i].N()->getValType().value() == ValType::TensorView || - output_vals_[i].N()->getValType().value() == ValType::TensorIndex); - TORCH_INTERNAL_ASSERT(isIntegralType(output_vals_[i].N()->dtype())); + output_vals[i].N()->getValType().value() == ValType::TensorView || + output_vals[i].N()->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT(isIntegralType(output_vals[i].N()->dtype())); // check initial value - auto init_avg = init_vals_[i].avg(); - auto init_var = init_vals_[i].var(); - auto init_N = init_vals_[i].N(); + auto init_avg = init_vals[i].avg(); + auto init_var = init_vals[i].var(); + auto init_N = init_vals[i].N(); TORCH_INTERNAL_ASSERT( init_avg != nullptr && init_var != nullptr && init_N != nullptr, "nullptr init vals are not allowed"); @@ -1129,9 +763,9 @@ GroupedWelfordOp::GroupedWelfordOp( init_var->toString()); // check input - auto in_avg = input_vals_[i].avg(); - auto in_var = input_vals_[i].var(); - auto in_N = input_vals_[i].N(); + auto in_avg = input_vals[i].avg(); + auto in_var = input_vals[i].var(); + auto in_N = input_vals[i].N(); TORCH_INTERNAL_ASSERT( in_avg != nullptr && in_var != nullptr && in_N != nullptr, "nullptr input vals are not allowed"); @@ -1163,56 +797,22 @@ GroupedWelfordOp::GroupedWelfordOp( } } + addAttribute( + IrBuilder::create>(passkey.ir_container_, is_allreduce)); for (const auto i : c10::irange(num_grouped_ops)) { - addOutput(output_vals_[i].avg()); - addOutput(output_vals_[i].var()); - addOutput(output_vals_[i].N()); - addInput(input_vals_[i].avg()); - addInput(input_vals_[i].var()); - addInput(input_vals_[i].N()); + addOutput(output_vals[i].avg()); + addOutput(output_vals[i].var()); + addOutput(output_vals[i].N()); + addInput(input_vals[i].avg()); + addInput(input_vals[i].var()); + addInput(input_vals[i].N()); + addAttribute(init_vals[i].avg()); + addAttribute(init_vals[i].var()); + addAttribute(init_vals[i].N()); } } -GroupedWelfordOp::GroupedWelfordOp( - const GroupedWelfordOp* src, - IrCloner* ir_cloner) - : Expr(src, ir_cloner), - output_vals_(WelfordTriplet::clone(src->output_vals_, ir_cloner)), - input_vals_(WelfordTriplet::clone(src->input_vals_, ir_cloner)), - init_vals_(WelfordTriplet::clone(src->init_vals_, ir_cloner)), - is_allreduce_(src->is_allreduce_) {} - -Expr* GroupedWelfordOp::shallowCopy() const { - auto result = IrBuilder::create( 
- output_vals_, input_vals_, init_vals_, is_allreduce_); - result->copyPredicatesFrom(this); - return result; -} - -bool GroupedWelfordOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - - auto grouped_op = dynamic_cast(other); - if (grouped_op == nullptr) { - return false; - } - - if (!Expr::sameAs(other)) { - return false; - } - - for (const auto i : c10::irange(numExprs())) { - if (!initAvg(i)->sameAs(grouped_op->initAvg(i)) || - !initVar(i)->sameAs(grouped_op->initVar(i)) || - !initN(i)->sameAs(grouped_op->initN(i))) { - return false; - } - } - - return true; -} +DEFINE_CLONE_AND_CREATE(GroupedWelfordOp) int GroupedWelfordOp::getExprIndexOfOutput(Val* output_val) const { for (const auto expr_idx : c10::irange(numExprs())) { @@ -1239,7 +839,7 @@ MmaOp::MmaOp( Val* in_a, Val* in_b, Val* init) - : Expr(passkey), out_(out), in_a_(in_a), in_b_(in_b), init_(init) { + : Expr(passkey) { // Check output type TORCH_INTERNAL_ASSERT( out->getValType().value() == ValType::TensorView || @@ -1258,6 +858,9 @@ MmaOp::MmaOp( addOutput(out); addInput(in_a); addInput(in_b); + addAttribute(init); + addAttribute( + IrBuilder::create>(passkey.ir_container_)); } MmaOp::MmaOp( @@ -1268,34 +871,21 @@ MmaOp::MmaOp( Val* init, OptionsInMma options) : MmaOp(passkey, out, in_a, in_b, init) { - options_ = options; + attribute(1)->as>()->value = options; } -MmaOp::MmaOp(const MmaOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_a_(ir_cloner->clone(src->in_a_)), - in_b_(ir_cloner->clone(src->in_b_)), - init_(ir_cloner->clone(src->init_)), - options_(src->options_) {} - -Expr* MmaOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_a_, in_b_, init_); - result->options_ = options_; - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(MmaOp) -bool MmaOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (auto other_mma = dynamic_cast(other)) { - return out_->sameAs(other_mma->out_) && in_a_->sameAs(other_mma->in_a_) && - in_b_->sameAs(other_mma->in_b_) && init_->sameAs(other_mma->init_) && - options_ == other_mma->options_; - } - return false; +void MmaOp::configureOptions(MmaOptions options) { + OptionsInMma& opt = attribute(1)->as>()->value; + TORCH_INTERNAL_ASSERT( + options.macro != MmaOptions::MacroType::NoMMA, + "Un-configured mma type from options."); + TORCH_INTERNAL_ASSERT( + options.accumulator_stride > 0, "Un-configured accumulator stride."); + opt.accumulator_stride = options.accumulator_stride; + opt.macro = options.macro; + opt.operand_layout = options.operand_layout; } TransposeOp::TransposeOp( @@ -1303,7 +893,7 @@ TransposeOp::TransposeOp( TensorView* out, TensorView* in, std::vector new2old) - : Expr(passkey), out_(out), in_(in), new2old_(std::move(new2old)) { + : Expr(passkey) { // Sanity check of the input parameters. Maybe not necessary as they // should be checked at function transpose. @@ -1311,44 +901,36 @@ TransposeOp::TransposeOp( TensorDomain::noReductions(in->getMaybeRFactorDomain()).size() == out->getMaybeRFactorDomain().size()); - TORCH_INTERNAL_ASSERT(new2old_.size() == out->getMaybeRFactorDomain().size()); + TORCH_INTERNAL_ASSERT(new2old.size() == out->getMaybeRFactorDomain().size()); // Make sure the entries of new2old are unique and range from 0 to // N-1, where N == new2old.size(). 
- std::set old_positions(new2old_.begin(), new2old_.end()); - TORCH_INTERNAL_ASSERT(old_positions.size() == new2old_.size()); + std::set old_positions(new2old.begin(), new2old.end()); + TORCH_INTERNAL_ASSERT(old_positions.size() == new2old.size()); // old_positions is sorted, so the first entry must be 0. TORCH_INTERNAL_ASSERT( *(old_positions.begin()) == 0, "Invalid new2old vector detected: ", - new2old_); + new2old); // The last entry must be N-1, since old_positions is sorted, starts // with 0, and its length is N. TORCH_INTERNAL_ASSERT( - *(old_positions.rbegin()) == (int)(new2old_.size() - 1), + *(old_positions.rbegin()) == (int)(new2old.size() - 1), "Invalid new2old vector detected: ", - new2old_); + new2old); addOutput(out); addInput(in); + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(new2old))); } -TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - new2old_(src->new2old_) {} - -Expr* TransposeOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_, new2old_); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(TransposeOp) std::vector TransposeOp::old2new() const { - std::vector old2new(new2old_.size()); - for (auto new_axis : c10::irange(new2old_.size())) { - auto old_axis = new2old_.at(new_axis); + std::vector old2new(new2old().size()); + for (auto new_axis : c10::irange(new2old().size())) { + auto old_axis = new2old().at(new_axis); old2new[old_axis] = new_axis; } return old2new; @@ -1359,13 +941,10 @@ ExpandOp::ExpandOp( TensorView* out, TensorView* in, std::vector _expanded_extents) - : Expr(passkey), - out_(out), - in_(in), - expanded_extents_(std::move(_expanded_extents)) { + : Expr(passkey) { addOutput(out); addInput(in); - for (auto expanded_extent : expanded_extents_) { + for (auto expanded_extent : _expanded_extents) { TORCH_INTERNAL_ASSERT(expanded_extent != nullptr); TORCH_INTERNAL_ASSERT( expanded_extent->dtype() == DataType::Int, @@ -1374,21 +953,7 @@ ExpandOp::ExpandOp( } } -ExpandOp::ExpandOp(const ExpandOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)) { - expanded_extents_.reserve(src->expanded_extents_.size()); - for (const auto expanded_extent : src->expanded_extents_) { - expanded_extents_.push_back(ir_cloner->clone(expanded_extent)); - } -} - -Expr* ExpandOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_, expanded_extents_); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(ExpandOp) ShiftOp::ShiftOp( IrBuilderPasskey passkey, @@ -1396,14 +961,10 @@ ShiftOp::ShiftOp( Val* in, std::vector offsets, std::vector pad_width) - : Expr(passkey), - out_(out), - in_(in), - offsets_(std::move(offsets)), - pad_width_(std::move(pad_width)) { - // clang-tidy complains about out_ that it may be null. - TORCH_INTERNAL_ASSERT(out_ != nullptr); - TORCH_INTERNAL_ASSERT(in_ != nullptr); + : Expr(passkey) { + // clang-tidy complains about out that it may be null. 
+ TORCH_INTERNAL_ASSERT(out != nullptr); + TORCH_INTERNAL_ASSERT(in != nullptr); auto out_type = out->getValType().value(); auto in_type = in->getValType().value(); @@ -1413,49 +974,28 @@ ShiftOp::ShiftOp( "Cannot shift a non-tensor object."); TORCH_INTERNAL_ASSERT( - offsets_.size() == - TensorDomain::noReductions(in_->as()->getRootDomain()) + offsets.size() == + TensorDomain::noReductions(in->as()->getRootDomain()) .size(), "Invalid offset vector: ", - offsets_); + offsets); TORCH_INTERNAL_ASSERT( - pad_width_.size() == - TensorDomain::noReductions(in_->as()->getRootDomain()) + pad_width.size() == + TensorDomain::noReductions(in->as()->getRootDomain()) .size(), "Invalid padding width vector: ", - pad_width_); + pad_width); addOutput(out); addInput(in); + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(offsets))); + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(pad_width))); } -ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - offsets_(src->offsets_), - pad_width_(src->pad_width_) {} - -Expr* ShiftOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_, offsets_, pad_width_); - result->copyPredicatesFrom(this); - return result; -} - -bool ShiftOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (offsets() != other_op->offsets()) { - return false; - } - return Expr::sameAs(other); -} +DEFINE_CLONE_AND_CREATE(ShiftOp) GatherOp::GatherOp( IrBuilderPasskey passkey, @@ -1463,14 +1003,10 @@ GatherOp::GatherOp( Val* in, std::vector window_shape, std::vector> pad_width) - : Expr(passkey), - out_(out), - in_(in), - window_shape_(std::move(window_shape)), - pad_width_(std::move(pad_width)) { + : Expr(passkey) { // clang-tidy complains about out_ that it may be null. 
- TORCH_INTERNAL_ASSERT(out_ != nullptr); - TORCH_INTERNAL_ASSERT(in_ != nullptr); + TORCH_INTERNAL_ASSERT(out != nullptr); + TORCH_INTERNAL_ASSERT(in != nullptr); auto out_type = out->getValType().value(); auto in_type = in->getValType().value(); @@ -1480,52 +1016,29 @@ GatherOp::GatherOp( "Cannot shift a non-tensor object."); const auto ndims = - TensorDomain::noReductions(in_->as()->getRootDomain()).size(); + TensorDomain::noReductions(in->as()->getRootDomain()).size(); TORCH_INTERNAL_ASSERT( - window_shape_.size() == ndims, + window_shape.size() == ndims, "Invalid window_shape vector: ", - window_shape_); + window_shape); TORCH_INTERNAL_ASSERT( - pad_width_.size() == ndims, "Invalid pad_width vector: ", pad_width_); + pad_width.size() == ndims, "Invalid pad_width vector: ", pad_width); - for (const auto& pad : pad_width_) { + for (const auto& pad : pad_width) { TORCH_INTERNAL_ASSERT( pad.size() == 2, "Padding size for each axis must have two Int vals."); } addOutput(out); addInput(in); + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(window_shape))); + addAttribute(IrBuilder::create>>>( + passkey.ir_container_, std::move(pad_width))); } -GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - window_shape_(src->window_shape_), - pad_width_(src->pad_width_) {} - -Expr* GatherOp::shallowCopy() const { - auto result = - IrBuilder::create(out_, in_, window_shape_, pad_width_); - result->copyPredicatesFrom(this); - return result; -} - -bool GatherOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (windowShape() != other_op->windowShape() || - padWidth() != other_op->padWidth()) { - return false; - } - return Expr::sameAs(other); -} +DEFINE_CLONE_AND_CREATE(GatherOp) int GatherOp::gatherAxis(int axis) const { if (axis < 0) { @@ -1542,62 +1055,35 @@ ViewAsScalar::ViewAsScalar( Val* in, IterDomain* vector_id, Val* index) - : Expr(passkey), out_(out), in_(in), vector_id_(vector_id), index_(index) { + : Expr(passkey) { addOutput(out); addInput(in); + addAttribute(vector_id); + addAttribute(index); } -ViewAsScalar::ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - vector_id_(ir_cloner->clone(src->vector_id_)), - index_(ir_cloner->clone(src->index_)) {} - -Expr* ViewAsScalar::shallowCopy() const { - auto result = IrBuilder::create(out_, in_, vector_id_, index_); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(ViewAsScalar) -ViewOp::ViewOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in) - : Expr(passkey), out_(out), in_(in) { +ViewOp::ViewOp(IrBuilderPasskey passkey, Val* out, Val* in) : Expr(passkey) { addOutput(out); addInput(in); } -ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)) {} - -Expr* ViewOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(ViewOp) LoadStoreOp::LoadStoreOp( IrBuilderPasskey passkey, LoadStoreOpType op_type, Val* out, Val* in) - : Expr(passkey), load_store_type_(op_type), out_(out), in_(in) { + : Expr(passkey) { addOutput(out); addInput(in); + addAttribute(IrBuilder::create>( + 
passkey.ir_container_, op_type)); } -LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - load_store_type_(src->load_store_type_), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)) {} - -Expr* LoadStoreOp::shallowCopy() const { - auto result = IrBuilder::create(load_store_type_, out_, in_); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(LoadStoreOp) IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent) : start_(_start), extent_(_extent) { @@ -1765,6 +1251,8 @@ IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner) padded_to_size_(src->padded_to_size_), is_mma_swizzled_(src->is_mma_swizzled_) {} +DEFINE_CLONE(IterDomain) + bool IterDomain::sameAs(const Statement* other) const { if (other == this) { return true; @@ -2200,6 +1688,8 @@ TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner) contiguity_(src->contiguity()), has_reduction_(src->has_reduction_) {} +DEFINE_CLONE(TensorDomain) + bool TensorDomain::hasBlockBroadcast() const { return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { return id->isBroadcast() && id->isThreadDim(); @@ -2672,44 +2162,29 @@ Split::Split( bool inner_split, Val* start_offset, Val* stop_offset) - : Expr(passkey), - outer_{outer}, - inner_{inner}, - in_{in}, - factor_{factor}, - inner_split_{inner_split}, - start_offset_{ - start_offset != nullptr ? start_offset - : passkey.ir_container_->zeroVal()}, - stop_offset_{ - stop_offset != nullptr ? stop_offset - : passkey.ir_container_->zeroVal()} { + : Expr(passkey) { TORCH_INTERNAL_ASSERT( - factor_->isAnInt(), + factor->isAnInt(), "Attempted to create a Split node with a non-integer factor."); + if (start_offset == nullptr) { + start_offset = passkey.ir_container_->zeroVal(); + } + if (stop_offset == nullptr) { + stop_offset = passkey.ir_container_->zeroVal(); + } addOutput(outer); addOutput(inner); addInput(in); // TODO add factor as an input, need to check Split::Split during validation // and need to check BestEffortReplay::findFirstMismatchedID addInput(factor); + addAttribute(factor); + addAttribute( + IrBuilder::create>(passkey.ir_container_, inner_split)); + addAttribute(start_offset); + addAttribute(stop_offset); } -Split::Split(const Split* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - outer_(ir_cloner->clone(src->outer_)), - inner_(ir_cloner->clone(src->inner_)), - in_(ir_cloner->clone(src->in_)), - factor_(ir_cloner->clone(src->factor_)), - inner_split_(src->inner_split_), - start_offset_(ir_cloner->clone(src->start_offset_)), - stop_offset_(ir_cloner->clone(src->stop_offset_)) {} - -Expr* Split::shallowCopy() const { - auto result = IrBuilder::create( - outer_, inner_, in_, factor_, inner_split_, start_offset_, stop_offset_); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(Split) Val* Split::extent(Val* in_extent, Val* start_offset, Val* stop_offset) { TORCH_INTERNAL_ASSERT(in_extent != nullptr); @@ -2725,52 +2200,18 @@ Val* Split::extent(Val* in_extent, Val* start_offset, Val* stop_offset) { return in_extent; } -bool Split::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - return Expr::sameAs(other) && - factor()->sameAs(other->as()->factor()) && - innerSplit() == other->as()->innerSplit() && - startOffset()->sameAs(other->as()->startOffset()) && - stopOffset()->sameAs(other->as()->stopOffset()); -} - Merge::Merge( IrBuilderPasskey 
passkey, IterDomain* out, IterDomain* outer, IterDomain* inner) - : Expr(passkey), out_{out}, outer_{outer}, inner_{inner} { + : Expr(passkey) { addOutput(out); addInput(outer); addInput(inner); } -Merge::Merge(const Merge* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - outer_(ir_cloner->clone(src->outer_)), - inner_(ir_cloner->clone(src->inner_)) {} - -Expr* Merge::shallowCopy() const { - auto result = IrBuilder::create(out_, outer_, inner_); - result->copyPredicatesFrom(this); - return result; -} - -bool Merge::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - return Expr::sameAs(other); -} +DEFINE_CLONE_AND_CREATE(Merge) Swizzle2D::Swizzle2D( IrBuilderPasskey passkey, @@ -2780,47 +2221,18 @@ Swizzle2D::Swizzle2D( IterDomain* in_y, Swizzle2DType swizzle_type, SwizzleMode swizzle_mode) - : Expr(passkey), - out_x_{out_x}, - out_y_{out_y}, - in_x_{in_x}, - in_y_{in_y}, - swizzle_type_(swizzle_type), - swizzle_mode_(swizzle_mode) { + : Expr(passkey) { addOutput(out_x); addOutput(out_y); addInput(in_x); addInput(in_y); + addAttribute(IrBuilder::create>( + passkey.ir_container_, swizzle_type)); + addAttribute(IrBuilder::create>( + passkey.ir_container_, swizzle_mode)); } -Expr* Swizzle2D::shallowCopy() const { - auto result = IrBuilder::create( - out_x_, out_y_, in_x_, in_y_, swizzle_type_, swizzle_mode_); - result->copyPredicatesFrom(this); - return result; -} - -bool Swizzle2D::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - if (!(swizzle_type_ == other->as()->swizzle_type_)) { - return false; - } - return Expr::sameAs(other); -} - -Swizzle2D::Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_x_(ir_cloner->clone(src->out_x_)), - out_y_(ir_cloner->clone(src->out_y_)), - in_x_(ir_cloner->clone(src->in_x_)), - in_y_(ir_cloner->clone(src->in_y_)), - swizzle_type_(src->swizzle_type_), - swizzle_mode_(src->swizzle_mode_) {} +DEFINE_CLONE_AND_CREATE(Swizzle2D) NamedScalar::NamedScalar( IrBuilderPasskey passkey, @@ -2831,6 +2243,8 @@ NamedScalar::NamedScalar( NamedScalar::NamedScalar(const NamedScalar* src, IrCloner* ir_cloner) : Val(src, ir_cloner), name_(src->name_) {} +DEFINE_CLONE(NamedScalar) + bool NamedScalar::sameAs(const Statement* other) const { if (this == other) { return true; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index c1281e1a27a2..8f80f741fadc 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -87,43 +88,117 @@ Val* TensorIndex::index(int i) const { return indices_[i]; } -BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync) - : Expr(passkey), war_sync_(war_sync) { +Allocate::Allocate( + IrBuilderPasskey passkey, + Val* buffer, + MemoryType memory_type, + std::vector shape, + bool zero_init, + const Allocate* alias) + : Expr(passkey) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); + if (!shape.empty()) { + TORCH_INTERNAL_ASSERT( + (shape.size() == 1 && shape[0]->isOneInt()) || + buffer->isA()); + } else { + TORCH_INTERNAL_ASSERT(buffer->isA()); + TORCH_INTERNAL_ASSERT( + buffer->as()->getMemoryType() == memory_type); + const auto domain = buffer->as()->domain(); + for (auto axis : domain->noReductions()) { + 
shape.push_back(axis->extent()); + } + } + + Val* size = nullptr; + for (auto s : shape) { + if (size == nullptr) { + size = s; + } else { + size = IrBuilder::mulExpr(size, s); + } + } + + if (size == nullptr) { + size = FusionGuard::getCurFusion()->oneVal(); + } + + if (alias != nullptr) { + TORCH_INTERNAL_ASSERT(alias != this, "Invalid alias"); + TORCH_INTERNAL_ASSERT(alias->memoryType() == memory_type, "Invalid alias"); + } + + addInput(size); + addAttribute(buffer); + addAttribute(IrBuilder::create>( + passkey.ir_container_, memory_type)); + addAttribute( + IrBuilder::create>(passkey.ir_container_, zero_init)); + + // Storing IR nodes as PlainVal is not safe with IrCloner, but fortunately + // kernel IR does not need this feature. + addAttribute(IrBuilder::create>( + passkey.ir_container_, alias)); + + for (auto s : shape) { + addAttribute(s); + } +} + +Allocate::Allocate( + IrBuilderPasskey passkey, + Val* buffer, + MemoryType memory_type, + Val* size, + bool zero_init) + : Allocate( + passkey, + buffer, + memory_type, + size == nullptr ? std::vector{} : std::vector{size}, + zero_init) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); } -Expr* BlockSync::shallowCopy() const { - auto result = IrBuilder::create(war_sync_); - result->copyPredicatesFrom(this); - return result; +DEFINE_CLONE_AND_CREATE(Allocate) + +BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync) : Expr(passkey) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); + addAttribute( + IrBuilder::create>(passkey.ir_container_, war_sync)); } +DEFINE_CLONE_AND_CREATE(BlockSync) + GridSync::GridSync( IrBuilderPasskey passkey, ParallelTypeBitmap sync_dims, Val* sync_buffer) - : Expr(passkey), sync_dims_(sync_dims), sync_buffer_(sync_buffer) {} - -Expr* GridSync::shallowCopy() const { - auto result = IrBuilder::create(sync_dims_, sync_buffer_); - result->copyPredicatesFrom(this); - return result; + : Expr(passkey) { + addAttribute(IrBuilder::create>( + passkey.ir_container_, sync_dims)); + addAttribute(sync_buffer); } +DEFINE_CLONE_AND_CREATE(GridSync) + CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages) - : Expr(passkey), keep_stages_(keep_stages) { + : Expr(passkey) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); + addAttribute(IrBuilder::create>( + passkey.ir_container_, keep_stages)); } -Expr* CpAsyncWait::shallowCopy() const { - auto result = IrBuilder::create(keep_stages_); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(CpAsyncWait) CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey) : Expr(passkey) { TORCH_INTERNAL_ASSERT( @@ -131,11 +206,7 @@ CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey) : Expr(passkey) { "IR type only valid for Kernel container."); } -Expr* CpAsyncCommit::shallowCopy() const { - auto result = IrBuilder::create(); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(CpAsyncCommit) InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) : Expr(passkey) { TORCH_INTERNAL_ASSERT( @@ -143,11 +214,7 @@ InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) : Expr(passkey) { "IR type only valid for Kernel container."); } -Expr* InitMagicZero::shallowCopy() const { - auto result = IrBuilder::create(); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(InitMagicZero) UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) 
: Expr(passkey) { TORCH_INTERNAL_ASSERT( @@ -155,11 +222,7 @@ UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) : Expr(passkey) { "IR type only valid for Kernel container."); } -Expr* UpdateMagicZero::shallowCopy() const { - auto result = IrBuilder::create(); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(UpdateMagicZero) void Scope::insert(std::vector::const_iterator pos, Expr* expr) { exprs_.insert(pos, expr); @@ -232,33 +295,36 @@ ForLoop::ForLoop( Val* vectorize_shift, bool unroll_required, DoubleBufferLoopStage double_buffer_loop_stage) - : Expr(passkey), - iter_domain_{iter_domain}, - index_(index), - start_(start), - stop_(stop), - step_(step), - vectorize_(vectorize), - vectorize_shift_(vectorize_shift), - unroll_required_(unroll_required), - body_(this), - double_buffer_loop_stage_(double_buffer_loop_stage) { + : Expr(passkey) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); TORCH_INTERNAL_ASSERT(index->dtype() == DataType::Int); addInput(index); addInput(iter_domain); - if (start_ == nullptr && iter_domain->isThread()) { - start_ = NamedScalar::getParallelIndex(iter_domain->getParallelType()); + if (start == nullptr && iter_domain->isThread()) { + start = NamedScalar::getParallelIndex(iter_domain->getParallelType()); } - if (step_ == nullptr) { + if (step == nullptr) { if (iter_domain->isThread()) { - step_ = NamedScalar::getParallelDim(iter_domain->getParallelType()); + step = NamedScalar::getParallelDim(iter_domain->getParallelType()); } else { - step_ = FusionGuard::getCurFusion()->oneVal(); + step = FusionGuard::getCurFusion()->oneVal(); } } + addAttribute(start); + addAttribute(stop); + addAttribute(step); + addAttribute( + IrBuilder::create>(passkey.ir_container_, vectorize)); + addAttribute(vectorize_shift); + addAttribute(IrBuilder::create>( + passkey.ir_container_, unroll_required)); + addAttribute(IrBuilder::create>( + passkey.ir_container_, double_buffer_loop_stage)); + // Storing IR nodes as PlainVal is not safe with IrCloner, but fortunately + // kernel IR does not need this feature. + addAttribute(IrBuilder::create>(passkey.ir_container_, this)); } ForLoop::ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain) @@ -296,21 +362,7 @@ ForLoop::ForLoop(IrBuilderPasskey passkey, const ForLoop* other) "IR type only valid for Kernel container."); } -Expr* ForLoop::shallowCopy() const { - auto result = IrBuilder::create( - iter_domain_, - index_, - start_, - stop_, - step_, - vectorize_, - vectorize_shift_, - unroll_required_, - double_buffer_loop_stage_); - result->body_ = body_; - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(ForLoop) bool ForLoop::isUnrollable() const { // Start and stop must be constant, must not be a broadcast @@ -325,7 +377,7 @@ bool ForLoop::isUnrolled() const { if (isUnrollRequired() && !isUnrollable()) { TORCH_WARN( "Unroll required but not possible. Register allocation disabled. 
Loop index: ", - index_->toString()); + index()->toString()); return false; } @@ -356,28 +408,28 @@ bool ForLoop::isUnrolled() const { } Val* ForLoop::start() const { - if (start_ != nullptr) { - return start_; + if (attribute(0) != nullptr) { + return attribute(0); } else { // clang-tidy complains without this - TORCH_INTERNAL_ASSERT(iter_domain_ != nullptr); - return iter_domain_->start(); + TORCH_INTERNAL_ASSERT(iter_domain() != nullptr); + return iter_domain()->start(); } } Val* ForLoop::stop() const { - if (stop_ != nullptr) { - return stop_; + if (attribute(1) != nullptr) { + return attribute(1); } else { // clang-tidy complains without this - TORCH_INTERNAL_ASSERT(iter_domain_ != nullptr); - return iter_domain_->extent(); + TORCH_INTERNAL_ASSERT(iter_domain() != nullptr); + return iter_domain()->extent(); } } Val* ForLoop::step() const { - TORCH_INTERNAL_ASSERT(step_ != nullptr); - return step_; + TORCH_INTERNAL_ASSERT(attribute(2) != nullptr); + return attribute(2); } bool ForLoop::isTrivial() const { @@ -426,93 +478,16 @@ bool ForLoop::isTrivial() const { } IfThenElse::IfThenElse(IrBuilderPasskey passkey, Predicate* cond) - : Expr(passkey), then_body_(this), else_body_(this) { + : Expr(passkey) { setPredicate(cond); addInput(cond); + // Storing IR nodes as PlainVal is not safe with IrCloner, but fortunately + // kernel IR does not need this feature. + addAttribute(IrBuilder::create>(passkey.ir_container_, this)); + addAttribute(IrBuilder::create>(passkey.ir_container_, this)); } -Expr* IfThenElse::shallowCopy() const { - auto result = IrBuilder::create(predicate()); - result->then_body_ = then_body_; - result->else_body_ = else_body_; - result->setWritePredicate(writePredicate()); - return result; -} - -Allocate::Allocate( - IrBuilderPasskey passkey, - Val* buffer, - MemoryType memory_type, - std::vector shape, - bool zero_init, - const Allocate* alias) - : Expr(passkey), - buffer_(buffer), - memory_type_(memory_type), - shape_(std::move(shape)), - zero_init_(zero_init), - alias_(alias) { - TORCH_INTERNAL_ASSERT( - passkey.ir_container_->isA(), - "IR type only valid for Kernel container."); - if (!shape_.empty()) { - TORCH_INTERNAL_ASSERT( - (shape_.size() == 1 && shape_[0]->isOneInt()) || - buffer_->isA()); - } else { - TORCH_INTERNAL_ASSERT(buffer_->isA()); - TORCH_INTERNAL_ASSERT( - buffer_->as()->getMemoryType() == memory_type_); - const auto domain = buffer_->as()->domain(); - for (auto axis : domain->noReductions()) { - shape_.push_back(axis->extent()); - } - } - - for (auto s : shape_) { - if (size_ == nullptr) { - size_ = s; - } else { - size_ = IrBuilder::mulExpr(size_, s); - } - } - - if (size_ == nullptr) { - size_ = FusionGuard::getCurFusion()->oneVal(); - } - - if (alias_ != nullptr) { - TORCH_INTERNAL_ASSERT(alias_ != this, "Invalid alias"); - TORCH_INTERNAL_ASSERT( - alias_->memoryType() == memory_type_, "Invalid alias"); - } - - addInput(size_); -} - -Allocate::Allocate( - IrBuilderPasskey passkey, - Val* buffer, - MemoryType memory_type, - Val* size, - bool zero_init) - : Allocate( - passkey, - buffer, - memory_type, - size == nullptr ? 
-          std::vector<Val*>{} : std::vector<Val*>{size},
-          zero_init) {
-  TORCH_INTERNAL_ASSERT(
-      passkey.ir_container_->isA<kir::Kernel>(),
-      "IR type only valid for Kernel container.");
-}
-
-Expr* Allocate::shallowCopy() const {
-  auto result =
-      IrBuilder::create<Allocate>(buffer_, memory_type_, shape_, zero_init_);
-  result->copyPredicatesFrom(this);
-  return result;
-}
+DEFINE_CLONE_AND_CREATE(IfThenElse)
 
 GridReduction::GridReduction(
     IrBuilderPasskey passkey,
@@ -525,31 +500,27 @@ GridReduction::GridReduction(
     Val* entrance_index,
     Val* entrances,
     bool is_allreduce)
-    : ReductionOp(passkey, reduction_op_type, init, out, in, is_allreduce),
-      reduction_buffer_(reduction_buffer),
-      sync_buffer_(sync_buffer),
-      entrance_index_(entrance_index),
-      entrances_(entrances) {
+    : ReductionOp(passkey, reduction_op_type, init, out, in, is_allreduce) {
   TORCH_INTERNAL_ASSERT(
       passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
-}
-
-Expr* GridReduction::shallowCopy() const {
-  auto result = IrBuilder::create<GridReduction>(
-      getReductionOpType(),
-      init(),
-      out(),
-      in(),
-      reduction_buffer_,
-      sync_buffer_,
-      entrance_index_,
-      entrances_,
-      isAllreduce());
-  result->copyPredicatesFrom(this);
-  result->thread_predicate_ = thread_predicate_;
-  return result;
-}
+  TORCH_INTERNAL_ASSERT(
+      attributes().size() == num_reduction_op_attr,
+      "The num_reduction_op_attr does not match the number of attributes ReductionOp has. "
+      "If you changed ReductionOp, please change num_reduction_op_attr accordingly.");
+  // Storing IR nodes as PlainVal is not safe with IrCloner, but fortunately
+  // kernel IR does not need this feature.
+  addAttribute(IrBuilder::create<Attribute<Allocate*>>(
+      passkey.ir_container_, reduction_buffer));
+  addAttribute(IrBuilder::create<Attribute<Allocate*>>(
+      passkey.ir_container_, sync_buffer));
+  addAttribute(entrance_index);
+  addAttribute(entrances);
+  addAttribute(
+      IrBuilder::create<Attribute<ParallelTypeBitmap>>(passkey.ir_container_));
+}
+
+DEFINE_CLONE_AND_CREATE(GridReduction)
 
 GroupedGridReduction::GroupedGridReduction(
     IrBuilderPasskey passkey,
@@ -569,54 +540,49 @@ GroupedGridReduction::GroupedGridReduction(
         std::move(init_vals),
         std::move(outputs),
         std::move(inputs),
-        is_allreduce),
-      reduction_buffers_(std::move(reduction_buffers)),
-      sync_buffer_(sync_buffer),
-      entrance_index_(entrance_index),
-      entrances_(entrances),
-      buffer_stride_(buffer_stride) {
+        is_allreduce) {
   TORCH_INTERNAL_ASSERT(
       passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
-}
-
-Expr* GroupedGridReduction::shallowCopy() const {
-  auto result = IrBuilder::create<GroupedGridReduction>(
-      getReductionOpTypes(),
-      initVals(),
-      outputs(),
-      inputs(),
-      reduction_buffers_,
-      sync_buffer_,
-      entrance_index_,
-      entrances_,
-      buffer_stride_,
-      isAllreduce());
-  result->copyPredicatesFrom(this);
-  result->thread_predicate_ = thread_predicate_;
-  return result;
-}
+  TORCH_INTERNAL_ASSERT(
+      attributes().size() == numGroupedReductionOpAttr(),
+      "The numGroupedReductionOpAttr() does not match the number of attributes GroupedReductionOp has. "
+      "If you changed GroupedReductionOp, please change numGroupedReductionOpAttr() accordingly.");
+  // Storing IR nodes as PlainVal is not safe with IrCloner, but fortunately
+  // kernel IR does not need this feature.
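+  // Attributes are appended after the GroupedReductionOp attributes, in the
+  // order the kernel_ir.h accessors expect: reduction_buffers, sync_buffer,
+  // entrance_index, entrances, buffer_stride, thread_predicate.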
+ addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(reduction_buffers))); + addAttribute(IrBuilder::create>( + passkey.ir_container_, sync_buffer)); + addAttribute(entrance_index); + addAttribute(entrances); + addAttribute(buffer_stride); + addAttribute( + IrBuilder::create>(passkey.ir_container_)); +} + +DEFINE_CLONE_AND_CREATE(GroupedGridReduction) GridBroadcast::GridBroadcast( IrBuilderPasskey passkey, BroadcastOp* broadcast_op, Allocate* broadcast_buffer, Allocate* sync_buffer) - : Expr(passkey), - broadcast_op_(broadcast_op), - broadcast_buffer_(broadcast_buffer), - sync_buffer_(sync_buffer) { + : Expr(passkey) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); + // Storing IR nodes as PlainVal is not safe with IrCloner, but fortunately + // kernel IR does not need this feature. + addAttribute(IrBuilder::create>( + passkey.ir_container_, broadcast_op)); + addAttribute(IrBuilder::create>( + passkey.ir_container_, broadcast_buffer)); + addAttribute(IrBuilder::create>( + passkey.ir_container_, sync_buffer)); } -Expr* GridBroadcast::shallowCopy() const { - auto result = IrBuilder::create( - broadcast_op_, broadcast_buffer_, sync_buffer_); - result->copyPredicatesFrom(this); - return result; -} +DEFINE_CLONE_AND_CREATE(GridBroadcast) GridWelford::GridWelford( IrBuilderPasskey passkey, @@ -627,32 +593,29 @@ GridWelford::GridWelford( Allocate* sync_buffer, Val* entrance_index, Val* entrances) - : Expr(passkey), - welford_op_(welford_op), - var_buffer_(var_buffer), - avg_buffer_(avg_buffer), - n_buffer_(n_buffer), - sync_buffer_(sync_buffer), - entrance_index_(entrance_index), - entrances_(entrances) { + : Expr(passkey) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); -} - -Expr* GridWelford::shallowCopy() const { - auto result = IrBuilder::create( - welford_op_, - var_buffer_, - avg_buffer_, - n_buffer_, - sync_buffer_, - entrance_index_, - entrances_); - result->copyPredicatesFrom(this); - result->thread_predicate_ = thread_predicate_; - return result; -} + // Storing IR nodes as PlainVal is not safe with IrCloner, but fortunately + // kernel IR does not need this feature. 
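+  // Attribute indices, matching the accessors in kernel_ir.h:
+  // 0: welford_op, 1: var_buffer, 2: avg_buffer, 3: n_buffer,
+  // 4: sync_buffer, 5: entrance_index, 6: entrances, 7: thread_predicate.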
+  addAttribute(IrBuilder::create<Attribute<WelfordOp*>>(
+      passkey.ir_container_, welford_op));
+  addAttribute(IrBuilder::create<Attribute<Allocate*>>(
+      passkey.ir_container_, var_buffer));
+  addAttribute(IrBuilder::create<Attribute<Allocate*>>(
+      passkey.ir_container_, avg_buffer));
+  addAttribute(
+      IrBuilder::create<Attribute<Allocate*>>(passkey.ir_container_, n_buffer));
+  addAttribute(IrBuilder::create<Attribute<Allocate*>>(
+      passkey.ir_container_, sync_buffer));
+  addAttribute(entrance_index);
+  addAttribute(entrances);
+  addAttribute(
+      IrBuilder::create<Attribute<ParallelTypeBitmap>>(passkey.ir_container_));
+}
+
+DEFINE_CLONE_AND_CREATE(GridWelford)
 
 GroupedGridWelford::GroupedGridWelford(
     IrBuilderPasskey passkey,
@@ -670,133 +633,79 @@ GroupedGridWelford::GroupedGridWelford(
         std::move(output_vals),
         std::move(input_vals),
         std::move(init_vals),
-        is_allreduce),
-      reduction_buffers_(std::move(reduction_buffers)),
-      sync_buffer_(sync_buffer),
-      entrance_index_(entrance_index),
-      entrances_(entrances),
-      buffer_stride_(buffer_stride) {
+        is_allreduce) {
   TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
-}
-
-Expr* GroupedGridWelford::shallowCopy() const {
-  auto result = IrBuilder::create<GroupedGridWelford>(
-      outputVals(),
-      inputVals(),
-      initVals(),
-      reduction_buffers_,
-      sync_buffer_,
-      entrance_index_,
-      entrances_,
-      buffer_stride_,
-      isAllreduce());
-  result->copyPredicatesFrom(this);
-  result->thread_predicate_ = thread_predicate_;
-  return result;
-}
-
-AllocateFusedReduction::AllocateFusedReduction(
-    IrBuilderPasskey passkey,
-    GridReduction* grid_reduction)
-    : Expr(passkey), grid_expr_(grid_reduction) {
   TORCH_INTERNAL_ASSERT(
-      passkey.ir_container_->isA<kir::Kernel>(),
-      "IR type only valid for Kernel container.");
-}
+      attributes().size() == numGroupedWelfordOpAttr(),
+      "The numGroupedWelfordOpAttr() does not match the number of attributes GroupedWelfordOp has. "
+      "If you changed GroupedWelfordOp, please change numGroupedWelfordOpAttr() accordingly.");
+  // Storing IR nodes as PlainVal is not safe with IrCloner, but fortunately
+  // kernel IR does not need this feature.
+  addAttribute(
+      IrBuilder::create<Attribute<std::array<std::vector<Allocate*>, 3>>>(
+          passkey.ir_container_, std::move(reduction_buffers)));
+  addAttribute(IrBuilder::create<Attribute<Allocate*>>(
+      passkey.ir_container_, sync_buffer));
+  addAttribute(entrance_index);
+  addAttribute(entrances);
+  addAttribute(buffer_stride);
+  addAttribute(
+      IrBuilder::create<Attribute<ParallelTypeBitmap>>(passkey.ir_container_));
+}
+
+DEFINE_CLONE_AND_CREATE(GroupedGridWelford)
 
 AllocateFusedReduction::AllocateFusedReduction(
     IrBuilderPasskey passkey,
-    GridWelford* grid_welford)
-    : Expr(passkey), grid_expr_(grid_welford) {
+    Expr* grid_expr)
+    : Expr(passkey) {
   TORCH_INTERNAL_ASSERT(
       passkey.ir_container_->isA<kir::Kernel>(),
       "IR type only valid for Kernel container.");
+  // Storing IR nodes as PlainVal is not safe with IrCloner, but fortunately
+  // kernel IR does not need this feature.
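+  // attribute(0): the grid expression (GridReduction, GridWelford,
+  // GroupedGridReduction or GroupedGridWelford) this allocation serves.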
+ addAttribute( + IrBuilder::create>(passkey.ir_container_, grid_expr)); } -AllocateFusedReduction::AllocateFusedReduction( - IrBuilderPasskey passkey, - GroupedGridReduction* grouped_grid_reduction) - : Expr(passkey), grid_expr_(grouped_grid_reduction) { - TORCH_INTERNAL_ASSERT( - passkey.ir_container_->isA(), - "IR type only valid for Kernel container."); -} - -AllocateFusedReduction::AllocateFusedReduction( - IrBuilderPasskey passkey, - GroupedGridWelford* grouped_grid_welford) - : Expr(passkey), grid_expr_(grouped_grid_welford) { - TORCH_INTERNAL_ASSERT( - passkey.ir_container_->isA(), - "IR type only valid for Kernel container."); -} - -Expr* AllocateFusedReduction::shallowCopy() const { - if (grid_expr_->isA()) { - auto result = IrBuilder::create( - grid_expr_->as()); - result->setPredicate(predicate()); - result->setWritePredicate(writePredicate()); - return result; - } else if (grid_expr_->isA()) { - auto result = IrBuilder::create( - grid_expr_->as()); - result->setPredicate(predicate()); - result->setWritePredicate(writePredicate()); - return result; - } else if (grid_expr_->isA()) { - auto result = IrBuilder::create( - grid_expr_->as()); - result->setPredicate(predicate()); - result->setWritePredicate(writePredicate()); - return result; - } else if (grid_expr_->isA()) { - auto result = IrBuilder::create( - grid_expr_->as()); - result->setPredicate(predicate()); - result->setWritePredicate(writePredicate()); - return result; - } - TORCH_INTERNAL_ASSERT( - false, "Unknown reduction type in AllocateFusedReduction::shallowCopy"); -} +DEFINE_CLONE_AND_CREATE(AllocateFusedReduction) TensorIndex* AllocateFusedReduction::out() const { - TORCH_INTERNAL_ASSERT(grid_expr_ != nullptr); - if (grid_expr_->isA() || - grid_expr_->isA()) { - return grid_expr_->outputs().at(0)->as(); - } else if (auto grid_welford = dynamic_cast(grid_expr_)) { + TORCH_INTERNAL_ASSERT(gridExpr() != nullptr); + if (gridExpr()->isA() || + gridExpr()->isA()) { + return gridExpr()->outputs().at(0)->as(); + } else if (auto grid_welford = dynamic_cast(gridExpr())) { return grid_welford->welford_op()->out()->as(); } else if ( auto grouped_grid_welford = - dynamic_cast(grid_expr_)) { + dynamic_cast(gridExpr())) { return grouped_grid_welford->out(0)->as(); } else { TORCH_INTERNAL_ASSERT( - false, "Invalid grid expression: ", grid_expr_->toString()); + false, "Invalid grid expression: ", gridExpr()->toString()); } } const ParallelTypeBitmap& AllocateFusedReduction::threadPredicate() const { - TORCH_INTERNAL_ASSERT(grid_expr_ != nullptr); - if (auto grid_reduction = dynamic_cast(grid_expr_)) { + TORCH_INTERNAL_ASSERT(gridExpr() != nullptr); + if (auto grid_reduction = dynamic_cast(gridExpr())) { return grid_reduction->threadPredicate(); - } else if (auto grid_welford = dynamic_cast(grid_expr_)) { + } else if (auto grid_welford = dynamic_cast(gridExpr())) { return grid_welford->threadPredicate(); } else if ( auto grouped_grid_reduction = - dynamic_cast(grid_expr_)) { + dynamic_cast(gridExpr())) { return grouped_grid_reduction->threadPredicate(); } else if ( auto grouped_grid_welford = - dynamic_cast(grid_expr_)) { + dynamic_cast(gridExpr())) { return grouped_grid_welford->threadPredicate(); } else { TORCH_INTERNAL_ASSERT( - false, "Invalid grid expression: ", grid_expr_->toString()); + false, "Invalid grid expression: ", gridExpr()->toString()); } } diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 572900445ed3..37fabbf97c2b 100644 --- 
a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -179,6 +179,8 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { //! describes the output of an operation. class TORCH_CUDA_CU_API Allocate final : public Expr { public: + using Expr::Expr; + //! Allocation of a multi-dimensional buffer //! //! param shape Size of each dimension @@ -204,44 +206,35 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { return "Allocate"; } - Expr* shallowCopy() const override; + DECLARE_CLONE_AND_CREATE Val* buffer() const { - return buffer_; + return attribute(0); } MemoryType memoryType() const { - return memory_type_; + return attribute(1)->as>()->value; } + //! Total size Val* size() const { - return size_; + return input(0); } - const std::vector& shape() const { - return shape_; + //! Size of each dimension + std::vector shape() const { + return {attributes().begin() + 4, attributes().end()}; } bool zeroInit() const { - return zero_init_; - } - - const Allocate* alias() const { - return alias_; + return attribute(2)->as>()->value; } - private: - Val* buffer_ = nullptr; - MemoryType memory_type_ = MemoryType::Local; - //! Size of each dimension - std::vector shape_; - bool zero_init_ = false; - //! Total size - Val* size_ = nullptr; - // This alias tracks the next Allocate node in a linked chain of aliases // If the alias is nullptr, then the Allocate node uses memory in the kernel - const Allocate* alias_ = nullptr; + const Allocate* alias() const { + return attribute(3)->as>()->value; + } }; // Sync represents __syncthreads barrier for block level coordination. @@ -250,43 +243,66 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { // class TORCH_CUDA_CU_API BlockSync final : public Expr { public: + using Expr::Expr; + explicit BlockSync(IrBuilderPasskey passkey, bool war_sync = false); virtual const char* getOpString() const override { return "BlockSync"; } - Expr* shallowCopy() const override; + DECLARE_CLONE_AND_CREATE + // TODO: war_sync_ is only used for testing/validation purposes. bool isWarHazardSync() const { - return war_sync_; + return attribute(0)->as>()->value; } +}; - private: - // TODO: war_sync_ is only used for testing/validation purposes. - bool war_sync_ = false; +// Synchronize all blocks in device, implies cooperative group launch is +// required. +class TORCH_CUDA_CU_API GridSync final : public Expr { + public: + using Expr::Expr; + + explicit GridSync( + IrBuilderPasskey passkey, + ParallelTypeBitmap sync_dims, + Val* sync_buffer); + + DECLARE_CLONE_AND_CREATE + + virtual const char* getOpString() const override { + return "GridSync"; + } + + ParallelTypeBitmap syncDims() const { + return attribute(0)->as>()->value; + } + + Val* syncBuffer() const { + return attribute(1); + } }; // CpAsyncWait represents wait intrinsics for cp.async class TORCH_CUDA_CU_API CpAsyncWait final : public Expr { public: + using Expr::Expr; + explicit CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages = 0); + DECLARE_CLONE_AND_CREATE + virtual const char* getOpString() const override { return "CpAsyncWait"; } - Expr* shallowCopy() const override; - //! Returns the remaining number of stages that are not synchronized //! after this op. unsigned int keepStages() const { - return keep_stages_; + return attribute(0)->as>()->value; } - - private: - //! Number of stage to leave un-sync'ed by this op. 
-  unsigned int keep_stages_ = 0;
 };
 
 // CpAsyncCommit represents commit intrinsics for cp.async
@@ -294,67 +310,45 @@
 // to the async load hardware. Example usage see [Circular buffer].
 class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr {
  public:
-  explicit CpAsyncCommit(IrBuilderPasskey passkey);
+  using Expr::Expr;
 
-  virtual const char* getOpString() const override {
-    return "CpAsyncCommit";
-  }
-
-  Expr* shallowCopy() const override;
-};
+  explicit CpAsyncCommit(IrBuilderPasskey passkey);
 
-// Synchronize all blocks in device, implies cooperative group launch is
-// required.
-class TORCH_CUDA_CU_API GridSync final : public Expr {
- public:
-  explicit GridSync(
-      IrBuilderPasskey passkey,
-      ParallelTypeBitmap sync_dims,
-      Val* sync_buffer);
+  DECLARE_CLONE_AND_CREATE
 
   virtual const char* getOpString() const override {
-    return "GridSync";
-  }
-
-  Expr* shallowCopy() const override;
-
-  ParallelTypeBitmap syncDims() const {
-    return sync_dims_;
-  }
-
-  Val* syncBuffer() const {
-    return sync_buffer_;
+    return "CpAsyncCommit";
   }
-
- private:
-  ParallelTypeBitmap sync_dims_;
-  Val* sync_buffer_ = nullptr;
 };
 
 // Simply prints "DEFINE_MAGIC_ZERO" in the code in accordance with magic_zero
 // in helpers.cu
 class TORCH_CUDA_CU_API InitMagicZero final : public Expr {
  public:
+  using Expr::Expr;
+
   explicit InitMagicZero(IrBuilderPasskey passkey);
 
+  DECLARE_CLONE_AND_CREATE
+
   virtual const char* getOpString() const override {
     return "InitMagicZero";
   }
-
-  Expr* shallowCopy() const override;
 };
 
 // Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero
 // in helpers.cu
 class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr {
  public:
+  using Expr::Expr;
+
   explicit UpdateMagicZero(IrBuilderPasskey passkey);
 
+  DECLARE_CLONE_AND_CREATE
+
   virtual const char* getOpString() const override {
     return "UpdateMagicZero";
   }
-
-  Expr* shallowCopy() const override;
 };
 
 // TODO(kir): promote to IR node
@@ -409,6 +403,10 @@ class TORCH_CUDA_CU_API Scope {
     return owner_;
   }
 
+  bool operator==(const Scope&) const {
+    TORCH_INTERNAL_ASSERT(false, "Should not reach here");
+  }
+
  private:
   // Insert expr before pos
   void insert(std::vector<Expr*>::const_iterator pos, Expr* expr);
@@ -435,6 +433,8 @@
 //! be smaller than the extent of iter_domain_.
 class TORCH_CUDA_CU_API ForLoop final : public Expr {
  public:
+  using Expr::Expr;
+
   //! By default, start and stop are the same as those of iter_domain.
   //! Step is one by default.
   //!
@@ -455,14 +455,14 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr {
 
   ForLoop(IrBuilderPasskey passkey, const ForLoop* other);
 
+  DECLARE_CLONE_AND_CREATE
+
   virtual const char* getOpString() const override {
     return "ForLoop";
   }
 
-  Expr* shallowCopy() const override;
-
   Val* index() const {
-    return index_;
+    return input(0);
   }
 
   Val* start() const;
@@ -471,38 +471,42 @@
 
   Val* step() const;
 
+  // [pre | vectorize | post] <= inner-most, merged root domain
+  // shift_ is applied to vectorize and post sections.
   Val* vectorize_shift() const {
-    return vectorize_shift_;
+    return attribute(4);
   }
 
   IterDomain* iter_domain() const {
-    return iter_domain_;
+    return input(1)->as<IterDomain>();
   }
 
   // TODO: Return pointer instead of reference to be more consistent
   Scope& body() {
-    return body_;
+    return attribute(7)->as<Attribute<Scope>>()->value;
  }
 
   const Scope& body() const {
-    return body_;
+    return attribute(7)->as<Attribute<Scope>>()->value;
   }
 
+  // vectorize is true when the for-loop contains a vectorize set
+  // the flag is used to omit the for-loop from the kernel
   bool vectorize() const {
-    return vectorize_;
+    return attribute(3)->as<Attribute<bool>>()->value;
   }
 
   //! True if unrolled (i.e., "#pragma unroll" is attached)
   bool isUnrolled() const;
 
-  //! True if unrolling is required
+  //! True if unroll is required for avoiding stack allocation
   bool isUnrollRequired() const {
-    return unroll_required_;
+    return attribute(5)->as<Attribute<bool>>()->value;
   }
 
   //! Set unrolling required
   void requireUnroll() {
-    unroll_required_ = true;
+    attribute(5)->as<Attribute<bool>>()->value = true;
   }
 
   //! True if no actual for-loop is materialized
@@ -511,37 +515,12 @@
   //! Returns the stage of a double buffered iterdomain
   //! that this for loop materializes.
   auto doubleBufferLoopStage() const {
-    return double_buffer_loop_stage_;
+    return attribute(6)->as<Attribute<DoubleBufferLoopStage>>()->value;
   }
 
  private:
   //! Returns if a loop could be unrolled.
   bool isUnrollable() const;
-
- private:
-  IterDomain* const iter_domain_ = nullptr;
-
-  Val* index_ = nullptr;
-  Val* start_ = nullptr;
-  Val* stop_ = nullptr;
-  Val* step_ = nullptr;
-
-  // vectorize is true when the for-loop contains a vectorize set
-  // the flag is used to omit the for-loop from the kernel
-  bool vectorize_ = false;
-  // [pre | vectorize | post] <= inner-most, merged root domain
-  // shift_ is applied to vectorize and post sections.
-  Val* vectorize_shift_ = nullptr;
-
-  //! True if unroll is required for avoiding stack allocation
-  bool unroll_required_ = false;
-
-  Scope body_;
-
-  //! Tracks if this for loop is implementing a stage of
-  //! a double buffered iterdomain.
-  DoubleBufferLoopStage double_buffer_loop_stage_ =
-      DoubleBufferLoopStage::NotApplicable;
 };
 
 //! IfThenElse provides scoping for a boolean operator. Exprs placed in its
@@ -553,36 +532,34 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr {
 //!
 class TORCH_CUDA_CU_API IfThenElse final : public Expr {
  public:
+  using Expr::Expr;
+
   explicit IfThenElse(IrBuilderPasskey passkey, Predicate* cond);
 
+  DECLARE_CLONE_AND_CREATE
+
   virtual const char* getOpString() const override {
     return "IfThenElse";
   }
 
-  Expr* shallowCopy() const override;
-
   Scope& thenBody() {
-    return then_body_;
+    return attribute(0)->as<Attribute<Scope>>()->value;
   }
 
   const Scope& thenBody() const {
-    return then_body_;
+    return attribute(0)->as<Attribute<Scope>>()->value;
   }
 
   Scope& elseBody() {
-    return else_body_;
+    return attribute(1)->as<Attribute<Scope>>()->value;
   }
 
   const Scope& elseBody() const {
-    return else_body_;
+    return attribute(1)->as<Attribute<Scope>>()->value;
   }
 
   bool hasElse() const {
-    return !else_body_.empty();
+    return !elseBody().empty();
   }
-
- private:
-  Scope then_body_;
-  Scope else_body_;
 };
 
 //! Grid reduction operation
@@ -593,7 +570,11 @@ class TORCH_CUDA_CU_API IfThenElse final : public Expr {
 //! This node provides FusionExecutor the information it needs to allocate the
 //! reduction and sync buffers.
class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { + static constexpr int num_reduction_op_attr = 3; + public: + using ReductionOp::ReductionOp; + GridReduction( IrBuilderPasskey passkey, BinaryOpType reduction_op_type, @@ -606,54 +587,59 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { Val* entrances, bool is_allreduce = false); + DECLARE_CLONE_AND_CREATE + virtual const char* getOpString() const override { return "GridReduction"; } - Expr* shallowCopy() const override; - Allocate* reduction_buffer() const { - return reduction_buffer_; + return attribute(num_reduction_op_attr)->as>()->value; } Allocate* sync_buffer() const { - return sync_buffer_; + return attribute(num_reduction_op_attr + 1) + ->as>() + ->value; } // Which instance of entering this grid reduction is this iteration? Val* entrance_index() const { - return entrance_index_; + return attribute(num_reduction_op_attr + 2); } // How many times will this grid reduction be entered Val* entrances() const { - return entrances_; + return attribute(num_reduction_op_attr + 3); } + // gridReduce has template flags for thread predicates. In order to + // use them, the thread predicate is held here separately from + // Expr::predicate_. const ParallelTypeBitmap& threadPredicate() const { - return thread_predicate_; + return attribute(num_reduction_op_attr + 4) + ->as>() + ->value; + } + + ParallelTypeBitmap& threadPredicate() { + return attribute(num_reduction_op_attr + 4) + ->as>() + ->value; } GridReduction* withThreadPredicate( const ParallelTypeBitmap& thread_predicate) { auto result = shallowCopy()->as(); - result->thread_predicate_ = thread_predicate; + result->threadPredicate() = thread_predicate; return result; } - - private: - Allocate* reduction_buffer_ = nullptr; - Allocate* sync_buffer_ = nullptr; - // gridReduce has template flags for thread predicates. In order to - // use them, the thread predicate is held here separately from - // Expr::predicate_. - ParallelTypeBitmap thread_predicate_; - Val* entrance_index_ = nullptr; - Val* entrances_ = nullptr; }; class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { public: + using GroupedReductionOp::GroupedReductionOp; + GroupedGridReduction( IrBuilderPasskey passkey, std::vector reduction_op_type, @@ -667,60 +653,69 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { Val* buffer_stride, bool is_allreduce = false); + DECLARE_CLONE_AND_CREATE + + // number of attributes in the parent class + int numGroupedReductionOpAttr() const { + return 2 + outputs().size(); + } + virtual const char* getOpString() const override { return "GroupedGridReduction"; } - Expr* shallowCopy() const override; - const std::vector& reduction_buffers() const { - return reduction_buffers_; + return attribute(numGroupedReductionOpAttr()) + ->as>>() + ->value; } Allocate* reduction_buffer(size_t i) const { - return reduction_buffers_.at(i); + return reduction_buffers().at(i); } Allocate* sync_buffer() const { - return sync_buffer_; + return attribute(numGroupedReductionOpAttr() + 1) + ->as>() + ->value; } // Which instance of entering this grid reduction is this iteration? 
Val* entrance_index() const { - return entrance_index_; + return attribute(numGroupedReductionOpAttr() + 2); } // How many times will this grid reduction be entered Val* entrances() const { - return entrances_; + return attribute(numGroupedReductionOpAttr() + 3); } + // Stride of reduction buffers Val* buffer_stride() const { - return buffer_stride_; + return attribute(numGroupedReductionOpAttr() + 4); } + // gridReduce has template flags for thread predicates. In order to + // use them, the thread predicate is held here separately from + // Expr::predicate_. const ParallelTypeBitmap& threadPredicate() const { - return thread_predicate_; + return attribute(numGroupedReductionOpAttr() + 5) + ->as>() + ->value; + } + + ParallelTypeBitmap& threadPredicate() { + return attribute(numGroupedReductionOpAttr() + 5) + ->as>() + ->value; } GroupedGridReduction* withThreadPredicate( const ParallelTypeBitmap& thread_predicate) { auto result = shallowCopy()->as(); - result->thread_predicate_ = thread_predicate; + result->threadPredicate() = thread_predicate; return result; } - - private: - std::vector reduction_buffers_; - Allocate* sync_buffer_ = nullptr; - // gridReduce has template flags for thread predicates. In order to - // use them, the thread predicate is held here separately from - // Expr::predicate_. - ParallelTypeBitmap thread_predicate_; - Val* entrance_index_ = nullptr; - Val* entrances_ = nullptr; - // Stride of reduction buffers - Val* buffer_stride_ = nullptr; }; //! Grid broadcast operation @@ -732,34 +727,31 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { //! broadcast and sync buffers. class TORCH_CUDA_CU_API GridBroadcast final : public Expr { public: + using Expr::Expr; + GridBroadcast( IrBuilderPasskey passkey, BroadcastOp* broadcast_op, Allocate* broadcast_buffer, Allocate* sync_buffer); + DECLARE_CLONE_AND_CREATE + virtual const char* getOpString() const override { return "GridBroadcast"; } - Expr* shallowCopy() const override; - BroadcastOp* broadcast_op() const { - return broadcast_op_; + return attribute(0)->as>()->value; } Allocate* broadcast_buffer() const { - return broadcast_buffer_; + return attribute(1)->as>()->value; } Allocate* sync_buffer() const { - return sync_buffer_; + return attribute(2)->as>()->value; } - - private: - BroadcastOp* broadcast_op_ = nullptr; - Allocate* broadcast_buffer_ = nullptr; - Allocate* sync_buffer_ = nullptr; }; //! Grid welford operation @@ -773,6 +765,8 @@ class TORCH_CUDA_CU_API GridBroadcast final : public Expr { //! TODO: Make this a subclass of WelfordOp class TORCH_CUDA_CU_API GridWelford final : public Expr { public: + using Expr::Expr; + GridWelford( IrBuilderPasskey passkey, WelfordOp* welford_op, @@ -783,68 +777,63 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr { Val* entrance_index, Val* entrances); + DECLARE_CLONE_AND_CREATE + virtual const char* getOpString() const override { return "GridWelford"; } - Expr* shallowCopy() const override; - WelfordOp* welford_op() const { - return welford_op_; + return attribute(0)->as>()->value; } Allocate* var_buffer() const { - return var_buffer_; + return attribute(1)->as>()->value; } Allocate* avg_buffer() const { - return avg_buffer_; + return attribute(2)->as>()->value; } Allocate* N_buffer() const { - return n_buffer_; + return attribute(3)->as>()->value; } Allocate* sync_buffer() const { - return sync_buffer_; + return attribute(4)->as>()->value; } // Which instance of entering this grid reduction is this iteration? 
Val* entrance_index() const { - return entrance_index_; + return attribute(5); } // How many times will this grid reduction be entered Val* entrances() const { - return entrances_; + return attribute(6); } + // gridReduce has template flags for thread predicates. In order to + // use them, the thread predicate is held here separately from + // Expr::predicate_. const ParallelTypeBitmap& threadPredicate() const { - return thread_predicate_; + return attribute(7)->as>()->value; + } + ParallelTypeBitmap& threadPredicate() { + return attribute(7)->as>()->value; } GridWelford* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) { auto result = shallowCopy()->as(); - result->thread_predicate_ = thread_predicate; + result->threadPredicate() = thread_predicate; return result; } - - private: - WelfordOp* welford_op_ = nullptr; - Allocate* var_buffer_ = nullptr; - Allocate* avg_buffer_ = nullptr; - Allocate* n_buffer_ = nullptr; - Allocate* sync_buffer_ = nullptr; - Val* entrance_index_ = nullptr; - Val* entrances_ = nullptr; - // gridReduce has template flags for thread predicates. In order to - // use them, the thread predicate is held here separately from - // Expr::predicate_. - ParallelTypeBitmap thread_predicate_; }; class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp { public: + using GroupedWelfordOp::GroupedWelfordOp; + // input, output and init vals are vectors of triplets GroupedGridWelford( IrBuilderPasskey passkey, @@ -858,94 +847,110 @@ class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp { Val* buffer_stride, bool is_allreduce = false); + DECLARE_CLONE_AND_CREATE + + int numGroupedWelfordOpAttr() const { + return 1 + outputs().size(); + } + virtual const char* getOpString() const override { return "GroupedGridWelford"; } - Expr* shallowCopy() const override; - const std::array, 3>& reduction_buffers() const { - return reduction_buffers_; + return attribute(numGroupedWelfordOpAttr()) + ->as, 3>>>() + ->value; } Allocate* sync_buffer() const { - return sync_buffer_; + return attribute(numGroupedWelfordOpAttr() + 1) + ->as>() + ->value; } // Which instance of entering this grid reduction is this iteration? Val* entrance_index() const { - return entrance_index_; + return attribute(numGroupedWelfordOpAttr() + 2); } // How many times will this grid reduction be entered Val* entrances() const { - return entrances_; + return attribute(numGroupedWelfordOpAttr() + 3); } + // Stride of reduction buffers Val* buffer_stride() const { - return buffer_stride_; + return attribute(numGroupedWelfordOpAttr() + 4); } + // gridReduce has template flags for thread predicates. In order to + // use them, the thread predicate is held here separately from + // Expr::predicate_. const ParallelTypeBitmap& threadPredicate() const { - return thread_predicate_; + return attribute(numGroupedWelfordOpAttr() + 5) + ->as>() + ->value; + } + ParallelTypeBitmap& threadPredicate() { + return attribute(numGroupedWelfordOpAttr() + 5) + ->as>() + ->value; } GroupedGridWelford* withThreadPredicate( const ParallelTypeBitmap& thread_predicate) { auto result = shallowCopy()->as(); - result->thread_predicate_ = thread_predicate; + result->threadPredicate() = thread_predicate; return result; } - - private: - std::array, 3> reduction_buffers_; - Allocate* sync_buffer_ = nullptr; - // gridReduce has template flags for thread predicates. In order to - // use them, the thread predicate is held here separately from - // Expr::predicate_. 
- ParallelTypeBitmap thread_predicate_; - Val* entrance_index_ = nullptr; - Val* entrances_ = nullptr; - // Stride of reduction buffers - Val* buffer_stride_ = nullptr; }; // Allocate an instance of the fused reduction class. class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr { + explicit AllocateFusedReduction(IrBuilderPasskey passkey, Expr* grid_expr); + public: + using Expr::Expr; + explicit AllocateFusedReduction( IrBuilderPasskey passkey, - GridReduction* grid_reduction); + GridReduction* grid_reduction) + : AllocateFusedReduction(passkey, dynamic_cast(grid_reduction)) {} explicit AllocateFusedReduction( IrBuilderPasskey passkey, - GridWelford* grid_welford); + GridWelford* grid_welford) + : AllocateFusedReduction(passkey, dynamic_cast(grid_welford)) {} explicit AllocateFusedReduction( IrBuilderPasskey passkey, - GroupedGridReduction* grouped_grid_reduction); + GroupedGridReduction* grouped_grid_reduction) + : AllocateFusedReduction( + passkey, + dynamic_cast(grouped_grid_reduction)) {} explicit AllocateFusedReduction( IrBuilderPasskey passkey, - GroupedGridWelford* grouped_grid_welford); + GroupedGridWelford* grouped_grid_welford) + : AllocateFusedReduction( + passkey, + dynamic_cast(grouped_grid_welford)) {} + + DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "AllocateFusedReduction"; } - Expr* shallowCopy() const override; - + //! GridReduction, GridWelford, GroupedGridReduction or GroupedGridWelford Expr* gridExpr() const { - return grid_expr_; + return attribute(0)->as>()->value; } TensorIndex* out() const; const ParallelTypeBitmap& threadPredicate() const; - - private: - //! GridReduction, GridWelford, GroupedGridReduction or GroupedGridWelford - Expr* grid_expr_ = nullptr; }; } // namespace kir diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 430cb31f7774..fdea5857fb4a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -856,9 +856,12 @@ void IndexLowering::handle(const GroupedWelfordOp* grouped_wop) { std::vector indexed_outputs(grouped_wop->numExprs()); std::vector indexed_inputs(grouped_wop->numExprs()); + auto output_vals = grouped_wop->outputVals(); + auto input_vals = grouped_wop->inputVals(); + for (const auto i : c10::irange(grouped_wop->numExprs())) { - const auto& output = grouped_wop->outputVals().at(i); - const auto& input = grouped_wop->inputVals().at(i); + const auto& output = output_vals.at(i); + const auto& input = input_vals.at(i); WelfordTriplet indexed_output; WelfordTriplet indexed_input; for (const auto j : c10::irange(3)) { diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 3735e74080ee..5d4c11b40737 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -234,7 +234,7 @@ void OptOutMutator::mutate(RNGOp* rop) { Val* out = maybeMutated(rop->output(0)); Val* philox_idx = maybeMutated(rop->getPhiloxIndex()); - auto& parameters = rop->getParameters(); + auto parameters = rop->getParameters(); std::vector mutated_parameters; bool all_mutated_same = true; for (auto v : parameters) { diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 93c24a2f5068..fe46e48f0240 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -140,6 +140,8 @@ TensorView::TensorView( "Function invalid for kernel 
container."); } +DEFINE_CLONE(TensorView) + void TensorView::convertRfactorToRootDomain() { // For a given TensorView, does its domain (root / rfactor) contain any // concrete sized extents? diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp index ab7aa93f5130..25c3669e787f 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp @@ -758,6 +758,7 @@ TEST_F(NVFuserTest, FusionRegister_CUDA) { // dummy expr with 2 outputs only for toposort test. struct DummyExpr : public Expr { + using Expr::Expr; ~DummyExpr() = default; DummyExpr( IrBuilderPasskey passkey, @@ -765,13 +766,13 @@ struct DummyExpr : public Expr { Val* _outrhs, Val* _lhs, Val* _rhs) - : Expr(passkey) // terribly safe :-D - { + : Expr(passkey) { addOutput(_outlhs); addOutput(_outrhs); addInput(_lhs); addInput(_rhs); } + DECLARE_CLONE_AND_CREATE DummyExpr(const DummyExpr& other) = delete; DummyExpr& operator=(const DummyExpr& other) = delete; DummyExpr(DummyExpr&& other) = delete; @@ -779,11 +780,10 @@ struct DummyExpr : public Expr { virtual const char* getOpString() const override { return "DummyExpr"; } - Expr* shallowCopy() const override { - return nullptr; - } }; +DEFINE_CLONE_AND_CREATE(DummyExpr) + TEST_F(NVFuserTest, FusionTopoSort_CUDA) { Fusion fusion; FusionGuard fg(&fusion); diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 7ec0b8ef9fd9..bd548e60b4a8 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -33,7 +33,8 @@ enum class ValType { Scalar, NamedScalar, Predicate, - TensorIndex + TensorIndex, + Plain }; // Manual - The user provides the Bool value. Predicate generation is bypassed. 
From 24b1a52d2de23a9111884d668742a57af01e1907 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 18 Nov 2022 12:15:52 -0800 Subject: [PATCH 02/10] save --- torch/csrc/jit/codegen/cuda/dispatch.cpp | 167 ------ torch/csrc/jit/codegen/cuda/dispatch.h | 44 -- torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 2 +- torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 25 +- torch/csrc/jit/codegen/cuda/mutator.cpp | 558 ++---------------- 5 files changed, 47 insertions(+), 749 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 40fdf11a48cf..eaef01f7b636 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -549,171 +549,6 @@ void Val::mutatorDispatch(T mutator, Val* val) { TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!"); } -template -void Expr::mutatorDispatch(T mutator, Expr* expr) { - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - 
ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - if (expr->isStrictlyA()) { - ptr(mutator)->mutate(expr->as()); - return; - } - TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); -} - template void Statement::mutatorDispatch(T mutator, Statement* stmt) { if (stmt->isVal()) { @@ -764,8 +599,6 @@ template void Statement::mutatorDispatch(OptOutMutator&, Statement*); template void Statement::mutatorDispatch(OptOutMutator*, Statement*); template void Val::mutatorDispatch(OptOutMutator&, Val*); template void Val::mutatorDispatch(OptOutMutator*, Val*); -template void Expr::mutatorDispatch(OptOutMutator&, Expr*); -template void Expr::mutatorDispatch(OptOutMutator*, Expr*); void OptOutDispatch::handle(Statement* s) { Statement::dispatch(this, s); diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index 49295c16532c..eb47c609ca5d 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -320,50 +320,6 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { virtual void mutate(kir::Predicate*); virtual void mutate(kir::TensorIndex*); - // Exprs - virtual void mutate(FullOp*); - virtual void mutate(ARangeOp*); - virtual void mutate(EyeOp*); - virtual void mutate(UnaryOp*); - virtual void mutate(BinaryOp*); - virtual void mutate(TernaryOp*); - virtual void mutate(SelectOp*); - virtual void mutate(RNGOp*); - virtual void mutate(ReductionOp*); - virtual void mutate(GroupedReductionOp*); - virtual void mutate(WelfordOp*); - virtual void mutate(GroupedWelfordOp*); - virtual void mutate(LoadStoreOp*); - virtual void mutate(MmaOp*); - virtual void mutate(BroadcastOp*); - virtual void mutate(SqueezeOp*); - - virtual void mutate(Split*); - virtual void mutate(Merge*); - virtual void mutate(Swizzle2D*); - virtual void mutate(TransposeOp*); - virtual void mutate(ExpandOp*); - virtual void mutate(ShiftOp*); - virtual void mutate(GatherOp*); - virtual void mutate(ViewAsScalar*); - virtual void mutate(ViewOp*); - - virtual void mutate(kir::Allocate*); - virtual void mutate(kir::BlockSync*); - virtual void mutate(kir::GridSync*); - virtual void mutate(kir::CpAsyncWait*); - virtual void mutate(kir::CpAsyncCommit*); - virtual void mutate(kir::InitMagicZero*); - virtual void mutate(kir::UpdateMagicZero*); - virtual void mutate(kir::ForLoop*); - virtual void mutate(kir::IfThenElse*); - virtual void mutate(kir::GridReduction*); - virtual void mutate(kir::GroupedGridReduction*); - virtual void mutate(kir::GridBroadcast*); - virtual void mutate(kir::GridWelford*); - virtual void mutate(kir::GroupedGridWelford*); - virtual void mutate(kir::AllocateFusedReduction*); - protected: void removeExpr(IrContainer*, Expr*); }; diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index d78fa5b9623c..efd087289cda 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -333,7 +333,7 @@ Expr::Expr( attributes_(std::move(attributes)) {} Expr* Expr::shallowCopy() const { - auto result = newObject(inputs(), outputs(), attributes()); + 
auto result = newObject(ir_container_, inputs(), outputs(), attributes()); if (container()->isA()) { result->predicate_ = predicate_; result->write_predicate_ = write_predicate_; diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 3f39e2204703..442cedd480b4 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -489,6 +489,7 @@ class TORCH_CUDA_CU_API Expr : public Statement { // Creates a new instance of the same expression type with the given inputs, // outputs, and attributes. virtual Expr* newObject( + IrContainer* container, std::vector inputs, std::vector outputs, std::vector attributes) const = 0; @@ -525,9 +526,6 @@ class TORCH_CUDA_CU_API Expr : public Statement { template static void constDispatch(T handler, const Expr* const); - template - static void mutatorDispatch(T mutator, Expr*); - // TODO: Protect based on being in kernel container kir::Predicate* predicate() const; @@ -596,19 +594,22 @@ bool Val::isDefinitionType() const { #define DECLARE_CLONE_AND_CREATE \ virtual Statement* clone(IrCloner* ir_cloner) const override; \ virtual Expr* newObject( \ + IrContainer* container, \ std::vector inputs, \ std::vector outputs, \ std::vector attributes) const override; -#define DEFINE_CLONE_AND_CREATE(ClassName) \ - Statement* ClassName::clone(IrCloner* ir_cloner) const { \ - return IrBuilder::clone(this, ir_cloner); \ - } \ - Expr* ClassName::newObject( \ - std::vector inputs, \ - std::vector outputs, \ - std::vector attributes) const { \ - return IrBuilder::create(inputs, outputs, attributes); \ +#define DEFINE_CLONE_AND_CREATE(ClassName) \ + Statement* ClassName::clone(IrCloner* ir_cloner) const { \ + return IrBuilder::clone(this, ir_cloner); \ + } \ + Expr* ClassName::newObject( \ + IrContainer* container, \ + std::vector inputs, \ + std::vector outputs, \ + std::vector attributes) const { \ + return IrBuilder::create( \ + container, inputs, outputs, attributes); \ } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 5d4c11b40737..c07363290e18 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -15,10 +15,6 @@ void OptOutMutator::mutate(Statement* s) { Statement::mutatorDispatch(this, s); } -void OptOutMutator::mutate(Expr* e) { - Expr::mutatorDispatch(this, e); -} - void OptOutMutator::mutate(Val* v) { Val::mutatorDispatch(this, v); } @@ -125,543 +121,55 @@ void OptOutMutator::mutate(kir::TensorIndex*) { TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } -void OptOutMutator::mutate(FullOp* fop) { - Val* out = maybeMutated(fop->output(0)); - Val* fill_value = maybeMutated(fop->getFillValue()); - - if (out->sameAs(fop->output(0))) { - return; - } - auto container = fop->container(); - container->removeExpr(fop); - IrBuilder::create(container, out, fill_value, fop->dtype()); -} - -void OptOutMutator::mutate(SelectOp* sop) { - Val* out = maybeMutated(sop->output(0)); - Val* in = maybeMutated(sop->input(0)); - Val* index = maybeMutated(sop->input(1)); - IterDomain* select_axis = - maybeMutated(sop->getSelectAxis())->as(); - - if (out->sameAs(sop->output(0)) && in->sameAs(sop->output(0)) && - index->sameAs(sop->output(1)) && - select_axis->sameAs(sop->getSelectAxis())) { - return; - } - auto container = sop->container(); - container->removeExpr(sop); - IrBuilder::create(container, out, in, select_axis, index); -} - -void 
OptOutMutator::mutate(ARangeOp* aop) {
-  Val* out = maybeMutated(aop->output(0));
-
-  if (out->sameAs(aop->output(0))) {
-    return;
-  }
-  auto container = aop->container();
-  container->removeExpr(aop);
-  IrBuilder::create<ARangeOp>(
-      container,
-      out,
-      aop->start(),
-      aop->end(),
-      aop->step(),
-      aop->dtype(),
-      aop->getLinearLogicalIndex());
-}
-
-void OptOutMutator::mutate(EyeOp* eop) {
-  Val* out = maybeMutated(eop->output(0));
-
-  if (out->sameAs(eop->output(0))) {
-    return;
-  }
-  auto container = eop->container();
-  container->removeExpr(eop);
-  IrBuilder::create<EyeOp>(
-      container, out, eop->dtype(), eop->getIndex1(), eop->getIndex2());
-}
-
-void OptOutMutator::mutate(UnaryOp* uop) {
-  Val* out = maybeMutated(uop->out());
-  Val* in = maybeMutated(uop->in());
-
-  if (out->sameAs(uop->out()) && in->sameAs(uop->in())) {
-    return;
-  }
-  auto container = uop->container();
-  auto uop_type = uop->getUnaryOpType();
-  container->removeExpr(uop);
-  IrBuilder::create<UnaryOp>(container, uop_type, out, in);
-}
-
-void OptOutMutator::mutate(BinaryOp* bop) {
-  Val* out = maybeMutated(bop->out());
-  Val* lhs = maybeMutated(bop->lhs());
-  Val* rhs = maybeMutated(bop->rhs());
-
-  if (out->sameAs(bop->out()) && lhs->sameAs(bop->lhs()) &&
-      rhs->sameAs(bop->rhs())) {
-    return;
-  }
-
-  auto container = bop->container();
-  auto bop_type = bop->getBinaryOpType();
-  container->removeExpr(bop);
-  IrBuilder::create<BinaryOp>(container, bop_type, out, lhs, rhs);
-}
-
-void OptOutMutator::mutate(TernaryOp* top) {
-  Val* out = maybeMutated(top->out());
-  Val* in1 = maybeMutated(top->in1());
-  Val* in2 = maybeMutated(top->in2());
-  Val* in3 = maybeMutated(top->in3());
-
-  if (out->sameAs(top->out()) && in1->sameAs(top->in1()) &&
-      in2->sameAs(top->in2()) && in3->sameAs(top->in3())) {
-    return;
-  }
-
-  auto container = top->container();
-  auto top_type = top->getTernaryOpType();
-  container->removeExpr(top);
-  IrBuilder::create<TernaryOp>(container, top_type, out, in1, in2, in3);
-}
-
-void OptOutMutator::mutate(RNGOp* rop) {
-  Val* out = maybeMutated(rop->output(0));
-  Val* philox_idx = maybeMutated(rop->getPhiloxIndex());
-
-  auto parameters = rop->getParameters();
-  std::vector<Val*> mutated_parameters;
-  bool all_mutated_same = true;
-  for (auto v : parameters) {
-    mutated_parameters.emplace_back(maybeMutated(v));
-    all_mutated_same = all_mutated_same && mutated_parameters.back()->sameAs(v);
+void OptOutMutator::mutate(Expr* op) {
+  std::vector<Val*> mutated_inputs;
+  mutated_inputs.reserve(op->inputs().size());
+  for (auto input : op->inputs()) {
+    mutated_inputs.emplace_back(maybeMutated(input));
   }
-  if (out->sameAs(rop->output(0)) &&
-      ((philox_idx == nullptr && rop->getPhiloxIndex() == nullptr) ||
-       philox_idx->sameAs(rop->getPhiloxIndex())) &&
-      all_mutated_same) {
-    return;
+  std::vector<Val*> mutated_outputs;
+  mutated_outputs.reserve(op->outputs().size());
+  for (auto output : op->outputs()) {
+    mutated_outputs.emplace_back(maybeMutated(output));
   }
-  auto container = rop->container();
-  auto rop_type = rop->getRNGOpType();
-  container->removeExpr(rop);
-  IrBuilder::create<RNGOp>(
-      container,
-      rop_type,
-      out,
-      rop->dtype(),
-      mutated_parameters,
-      rop->getRNGOffset(),
-      philox_idx);
-}
-
-void OptOutMutator::mutate(ReductionOp* rop) {
-  Val* out = maybeMutated(rop->out());
-  Val* in = maybeMutated(rop->in());
-  Val* init = rop->init();
-  if (out->sameAs(rop->out()) && in->sameAs(rop->in()) &&
-      init->sameAs(rop->init())) {
-    return;
+  std::vector<Val*> mutated_attrs;
+  mutated_attrs.reserve(op->attributes().size());
+  for (auto attr : op->attributes()) {
+    mutated_attrs.emplace_back(maybeMutated(attr));
   }
-  auto container = rop->container();
-  auto rop_type = rop->getReductionOpType();
-  container->removeExpr(rop);
-  IrBuilder::create<ReductionOp>(
-      container, rop_type, init, out, in, rop->isAllreduce());
-}
-
-void OptOutMutator::mutate(GroupedReductionOp* rop) {
-  bool is_same = true;
-
-  std::vector<Val*> outputs;
-  for (auto out : rop->outputs()) {
-    auto maybe_mutated = maybeMutated(out);
-    is_same = is_same && maybe_mutated->sameAs(out);
-    outputs.push_back(maybe_mutated);
-  }
-
-  std::vector<Val*> inputs;
-  for (auto in : rop->inputs()) {
-    auto maybe_mutated = maybeMutated(in);
-    is_same = is_same && maybe_mutated->sameAs(in);
-    inputs.push_back(maybe_mutated);
-  }
-
-  std::vector<Val*> init_vals;
-  for (auto init : rop->initVals()) {
-    auto maybe_mutated = maybeMutated(init);
-    is_same = is_same && maybe_mutated->sameAs(init);
-    init_vals.push_back(maybe_mutated);
-  }
-
-  if (is_same) {
-    return;
-  }
-
-  auto container = rop->container();
-  const auto& rop_types = rop->getReductionOpTypes();
-  container->removeExpr(rop);
-  IrBuilder::create<GroupedReductionOp>(
-      container, rop_types, init_vals, outputs, inputs, rop->isAllreduce());
-}
-
-namespace {
-inline bool compareOptional(Val* a, Val* b) {
-  if (!a || !b) {
-    return (!a && !b);
-  }
-  return a->sameAs(b);
-}
-
-} // namespace
-
-void OptOutMutator::mutate(WelfordOp* wop) {
-  Val* out_avg = maybeMutated(wop->outAvg());
-  Val* out_var = maybeMutated(wop->outVar());
-  Val* out_N = maybeMutated(wop->outN());
-
-  Val* in_avg = maybeMutated(wop->inAvg());
-  Val* in_var = wop->inVar() ? maybeMutated(wop->inVar()) : nullptr;
-  Val* in_N = maybeMutated(wop->inN());
-
-  Val* init_avg = wop->initAvg() ? maybeMutated(wop->initAvg()) : nullptr;
-  Val* init_var = wop->initVar() ? maybeMutated(wop->initVar()) : nullptr;
-  Val* init_N = maybeMutated(wop->initN());
-
-  const bool out_compare = out_avg->sameAs(wop->outAvg()) &&
-      out_var->sameAs(wop->outVar()) && out_N->sameAs(wop->outN());
-  const bool in_compare = in_avg->sameAs(wop->inAvg()) &&
-      compareOptional(in_var, wop->inVar()) && in_N->sameAs(wop->inN());
-  const bool init_compare = compareOptional(init_avg, wop->initAvg()) &&
-      compareOptional(init_var, wop->initVar()) && init_N->sameAs(wop->initN());
-
-  if (out_compare && init_compare && in_compare) {
-    return;
-  }
-
-  auto container = wop->container();
-  container->removeExpr(wop);
-  IrBuilder::create<WelfordOp>(
-      container,
-      out_avg,
-      out_var,
-      out_N,
-      in_avg,
-      in_var,
-      in_N,
-      init_avg,
-      init_var,
-      init_N,
-      wop->isAllreduce());
-}
-
-void OptOutMutator::mutate(GroupedWelfordOp* wop) {
-  bool is_same = true;
-
-  std::vector<WelfordTriplet> output_vals;
-  for (const auto& out : wop->outputVals()) {
-    auto maybe_mutated =
-        out.transform([&](Val* val) { return maybeMutated(val); });
-    is_same = is_same && maybe_mutated.sameAs(out);
-    output_vals.push_back(maybe_mutated);
-  }
-
-  std::vector<WelfordTriplet> input_vals;
-  for (const auto& inp : wop->inputVals()) {
-    auto maybe_mutated =
-        inp.transform([&](Val* val) { return maybeMutated(val); });
-    is_same = is_same && maybe_mutated.sameAs(inp);
-    input_vals.push_back(maybe_mutated);
-  }
-
-  std::vector<WelfordTriplet> init_vals;
-  for (const auto& init : wop->initVals()) {
-    auto maybe_mutated =
-        init.transform([&](Val* val) { return maybeMutated(val); });
-    is_same = is_same && maybe_mutated.sameAs(init);
-    init_vals.push_back(maybe_mutated);
-  }
-
-  if (is_same) {
-    return;
-  }
-
-  auto container = wop->container();
-  container->removeExpr(wop);
-  IrBuilder::create<GroupedWelfordOp>(
-      container, output_vals, input_vals, init_vals, wop->isAllreduce());
-}
-
-void OptOutMutator::mutate(MmaOp* mma) {
-  Val* out = maybeMutated(mma->out());
-  Val* in_a = maybeMutated(mma->inA());
-  Val* in_b = maybeMutated(mma->inB());
-  Val* init = mma->init();
-
-  if (out->sameAs(mma->out()) && in_a->sameAs(mma->inA()) &&
-      in_b->sameAs(mma->inB())) {
-    return;
-  }
-
-  auto container = mma->container();
-  auto options = mma->options();
-  container->removeExpr(mma);
-  C10_UNUSED auto new_mma =
-      IrBuilder::create<MmaOp>(container, out, in_a, in_b, init, options);
-}
-
-void OptOutMutator::mutate(LoadStoreOp* ldst) {
-  Val* out = maybeMutated(ldst->out());
-  Val* in = maybeMutated(ldst->in());
-  auto op_type = ldst->opType();
-
-  if (out->sameAs(ldst->out()) && in->sameAs(ldst->in())) {
-    return;
-  }
-
-  auto container = ldst->container();
-  container->removeExpr(ldst);
-  IrBuilder::create<LoadStoreOp>(container, op_type, out, in);
-}
-
-void OptOutMutator::mutate(BroadcastOp* bop) {
-  Val* out = maybeMutated(bop->out());
-  Val* in = maybeMutated(bop->in());
-
-  if (out->sameAs(bop->out()) && in->sameAs(bop->in())) {
-    return;
-  }
-
-  auto container = bop->container();
-  auto flags = bop->getBroadcastDimFlags();
-  container->removeExpr(bop);
-  IrBuilder::create<BroadcastOp>(container, out, in, flags);
-}
-
-void OptOutMutator::mutate(SqueezeOp* sop) {
-  Val* out = maybeMutated(sop->out());
-  Val* in = maybeMutated(sop->in());
-
-  if (out->sameAs(sop->out()) && in->sameAs(sop->in())) {
-    return;
-  }
-
-  auto container = sop->container();
-  auto flags = sop->getSqueezeDimFlags();
-  container->removeExpr(sop);
-  IrBuilder::create<SqueezeOp>(container, out, in, flags);
-}
-
-void OptOutMutator::mutate(TransposeOp* top) {
-  TensorView* out = maybeMutated(top->out())->as<TensorView>();
-  TensorView* in = maybeMutated(top->in())->as<TensorView>();
-
-  if (out->sameAs(top->out()) && in->sameAs(top->in())) {
-    return;
-  }
-
-  auto container = top->container();
-  auto new2old = top->new2old();
-  container->removeExpr(top);
-  IrBuilder::create<TransposeOp>(container, out, in, new2old);
-}
-
-void OptOutMutator::mutate(ExpandOp* eop) {
-  bool is_same = true;
-
-  TensorView* out = maybeMutated(eop->out())->as<TensorView>();
-  is_same = is_same && out->sameAs(eop->out());
-  TensorView* in = maybeMutated(eop->in())->as<TensorView>();
-  is_same = is_same && in->sameAs(eop->in());
-
-  std::vector<Val*> expanded_extents;
-  expanded_extents.reserve(eop->expanded_extents().size());
-  for (auto expanded_extent : eop->expanded_extents()) {
-    expanded_extents.push_back(maybeMutated(expanded_extent));
-    if (!expanded_extents.back()->sameAs(expanded_extent)) {
-      is_same = false;
+  bool all_same = true;
+  for (auto i : c10::irange(op->outputs().size())) {
+    if (!all_same) {
+      break;
     }
+    all_same = all_same && mutated_outputs[i]->sameAs(op->output(i));
   }
-
-  if (is_same) {
-    return;
+  for (auto i : c10::irange(op->inputs().size())) {
+    if (!all_same) {
+      break;
+    }
+    all_same = all_same && mutated_inputs[i]->sameAs(op->input(i));
   }
-
-  auto container = eop->container();
-  container->removeExpr(eop);
-  IrBuilder::create<ExpandOp>(container, out, in, expanded_extents);
-}
-
-void OptOutMutator::mutate(ShiftOp* sop) {
-  Val* out = maybeMutated(sop->out())->asVal();
-  Val* in = maybeMutated(sop->in())->asVal();
-
-  if (out->sameAs(sop->out()) && in->sameAs(sop->in())) {
-    return;
+  for (auto i : c10::irange(op->attributes().size())) {
+    if (!all_same) {
+      break;
+    }
+    all_same = all_same && mutated_attrs[i]->sameAs(op->attribute(i));
   }
-  auto offsets = sop->offsets();
-  auto pad_width = sop->padWidth();
-  auto container = sop->container();
-  container->removeExpr(sop);
-  IrBuilder::create<ShiftOp>(container, out, in, offsets, pad_width);
-}
-
-void OptOutMutator::mutate(GatherOp* op) {
-  Val* out = maybeMutated(op->out())->asVal();
-  Val* in = maybeMutated(op->in())->asVal();
-
-  if (out->sameAs(op->out()) && in->sameAs(op->in())) {
+  if (all_same) {
     return;
   }
-  auto window_shape = op->windowShape();
-  auto pad_width = op->padWidth();
   auto container = op->container();
+  std::cout << op->toString() << std::endl;
+  std::cout << container << std::endl;
+  auto newObject = op->newObject;
   container->removeExpr(op);
-  IrBuilder::create<GatherOp>(container, out, in, window_shape, pad_width);
-}
-
-void OptOutMutator::mutate(ViewAsScalar* vop) {
-  TensorView* out = maybeMutated(vop->out())->as<TensorView>();
-  TensorView* in = maybeMutated(vop->in())->as<TensorView>();
-  IterDomain* vid = maybeMutated(vop->vector_id())->as<IterDomain>();
-  Val* idx = maybeMutated(vop->index());
-
-  if (out->sameAs(vop->out()) && in->sameAs(vop->in()) &&
-      vid->sameAs(vop->vector_id()) &&
-      ((idx == nullptr && vop->index() == nullptr) ||
-       idx->sameAs(vop->index()))) {
-    return;
-  }
-
-  auto container = vop->container();
-  container->removeExpr(vop);
-  IrBuilder::create<ViewAsScalar>(container, out, in, vid, idx);
-}
-
-void OptOutMutator::mutate(ViewOp* vop) {
-  TensorView* out = maybeMutated(vop->out())->as<TensorView>();
-  TensorView* in = maybeMutated(vop->in())->as<TensorView>();
-
-  if (out->sameAs(vop->out()) && in->sameAs(vop->in())) {
-    return;
-  }
-
-  auto container = vop->container();
-  container->removeExpr(vop);
-  IrBuilder::create<ViewOp>(container, out, in);
-}
-
-void OptOutMutator::mutate(Split* s) {
-  IterDomain* ot = maybeMutated(s->outer())->as<IterDomain>();
-  IterDomain* inr = maybeMutated(s->inner())->as<IterDomain>();
-  IterDomain* in = maybeMutated(s->in())->as<IterDomain>();
-  Val* fact = maybeMutated(s->factor())->as<Int>();
-  Val* start_offset = maybeMutated(s->startOffset());
-  Val* stop_offset = maybeMutated(s->stopOffset());
-
-  if (ot->sameAs(s->outer()) && inr->sameAs(s->inner()) &&
-      in->sameAs(s->in()) && areEqualScalars(fact, s->factor()) &&
-      start_offset->sameAs(s->startOffset()) &&
-      stop_offset->sameAs(s->stopOffset())) {
-    return;
-  }
-
-  auto container = s->container();
-  auto inner_split = s->innerSplit();
-  container->removeExpr(s);
-  C10_UNUSED auto new_node = IrBuilder::create<Split>(
-      container, ot, inr, in, fact, inner_split, start_offset, stop_offset);
-}
-
-void OptOutMutator::mutate(Merge* m) {
-  IterDomain* ot = maybeMutated(m->out())->as<IterDomain>();
-  IterDomain* otr = maybeMutated(m->outer())->as<IterDomain>();
-  IterDomain* in = maybeMutated(m->inner())->as<IterDomain>();
-
-  if (ot->sameAs(m->out()) && otr->sameAs(m->outer()) &&
-      in->sameAs(m->inner())) {
-    return;
-  }
-
-  auto container = m->container();
-  container->removeExpr(m);
-  C10_UNUSED auto new_node = IrBuilder::create<Merge>(container, ot, otr, in);
-}
-
-void OptOutMutator::mutate(Swizzle2D* m) {
-  IterDomain* outx = maybeMutated(m->outX())->as<IterDomain>();
-  IterDomain* outy = maybeMutated(m->outY())->as<IterDomain>();
-
-  IterDomain* inx = maybeMutated(m->inX())->as<IterDomain>();
-  IterDomain* iny = maybeMutated(m->inY())->as<IterDomain>();
-
-  auto swizzle_type = m->swizzleType();
-
-  if (outx->sameAs(m->outX()) && outy->sameAs(m->outY()) &&
-      inx->sameAs(m->inX()) && iny->sameAs(m->inY())) {
-    return;
-  }
-  auto container = m->container();
-  container->removeExpr(m);
-  FusionGuard::getCurFusion()->removeExpr(m);
-  C10_UNUSED auto new_node = IrBuilder::create<Swizzle2D>(
-      container, outx, outy, inx, iny, swizzle_type);
-}
-
-void OptOutMutator::mutate(kir::Allocate*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::BlockSync*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::GridSync*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::CpAsyncWait*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::CpAsyncCommit*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::InitMagicZero*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::UpdateMagicZero*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::ForLoop*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::IfThenElse*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::GridReduction*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::GroupedGridReduction*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::GridBroadcast*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::GridWelford*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::GroupedGridWelford*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-}
-void OptOutMutator::mutate(kir::AllocateFusedReduction*) {
-  TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
+  newObject(container, mutated_inputs, mutated_outputs, mutated_attrs);
 }
 
 void OptOutMutator::removeExpr(IrContainer* container, Expr* expr) {
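[Editor's note, not part of the series: the hunk above is the core of the refactor. Roughly twenty hand-written OptOutMutator::mutate() overloads collapse into one generic mutate(Expr*): run maybeMutated() over every input, output, and attribute; return early when nothing changed; otherwise remove the old expression and rebuild it through the type's newObject factory. The leftover std::cout debugging and the ill-formed `op->newObject` member access are cleaned up by later commits in this series. A self-contained sketch of the final control flow is appended at the end of this file.]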
From c8244e6e9139611879d41082dd2db78274467a56 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Fri, 18 Nov 2022 12:46:07 -0800
Subject: [PATCH 03/10] fix

---
 torch/csrc/jit/codegen/cuda/ir_base_nodes.h      | 5 +++--
 torch/csrc/jit/codegen/cuda/ir_interface_nodes.h | 2 --
 torch/csrc/jit/codegen/cuda/mutator.cpp          | 8 ++++++--
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
index ef60a79fb4bf..a72329ca659a 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -489,6 +489,7 @@ class TORCH_CUDA_CU_API Expr : public Statement {
   // Creates a new instance of the same expression type with the given inputs,
   // outputs, and attributes.
   virtual Expr* newObject(
+      IrContainer* container,
       std::vector<Val*> inputs,
       std::vector<Val*> outputs,
       std::vector<Statement*> attributes) const = 0;
@@ -594,7 +595,7 @@ bool Val::isDefinitionType() const {
   return false;
 }
 
-#define DECLARE_CLONE_AND_CREATE                                  \
+#define NVFUSER_DECLARE_CLONE_AND_CREATE                          \
   virtual Statement* clone(IrCloner* ir_cloner) const override;   \
   virtual Expr* newObject(                                        \
       IrContainer* container,                                     \
@@ -602,7 +603,7 @@ bool Val::isDefinitionType() const {
       std::vector<Val*> outputs,                                  \
       std::vector<Statement*> attributes) const override;
 
-#define DEFINE_CLONE_AND_CREATE(ClassName)                        \
+#define NVFUSER_DEFINE_CLONE_AND_CREATE(ClassName)                \
   Statement* ClassName::clone(IrCloner* ir_cloner) const {        \
     return IrBuilder::clone(this, ir_cloner);                     \
   }                                                               \
diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
index 5c30f0a9479f..d323b0c590ec 100644
--- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -80,8 +80,6 @@ class TORCH_CUDA_CU_API FloatingPoint : public Val {
 
   NVFUSER_DECLARE_CLONE
 
-  DECLARE_CLONE
-
   bool isSymbolic() const {
     return !(maybe_value_.has_value());
   }
diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp
index 051e526d8b76..c1f7f90b6ee1 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.cpp
+++ b/torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -136,10 +136,14 @@ void OptOutMutator::mutate(Expr* op) {
     mutated_outputs.emplace_back(maybeMutated(output));
   }
 
-  std::vector<Val*> mutated_attrs;
+  std::vector<Statement*> mutated_attrs;
   mutated_attrs.reserve(op->attributes().size());
   for (auto attr : op->attributes()) {
-    mutated_attrs.emplace_back(maybeMutated(attr));
+    if (auto attr_val = dynamic_cast<Val*>(attr)) {
+      mutated_attrs.emplace_back(maybeMutated(attr_val));
+    } else {
+      mutated_attrs.emplace_back(attr);
+    }
   }
 
   bool all_same = true;
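[Editor's note, illustrative only: besides making the generic path compile — attributes are Statements, not necessarily Vals, hence the dynamic_cast guard — this commit namespaces the macros. With them, an IR node opts into cloning and generic reconstruction with one declaration each; a hedged sketch using UnaryOp as the example class:]

    // In the class body (e.g. ir_internal_nodes.h):
    //   NVFUSER_DECLARE_CLONE_AND_CREATE
    // In the matching .cpp (e.g. ir_nodes.cpp):
    //   NVFUSER_DEFINE_CLONE_AND_CREATE(UnaryOp)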
From ee62179bcd41158941e1e7869c8f4368c9494281 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Fri, 18 Nov 2022 13:30:18 -0800
Subject: [PATCH 04/10] fix

---
 torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp |  2 +-
 torch/csrc/jit/codegen/cuda/ir_base_nodes.h   | 25 +++++++++++--------
 torch/csrc/jit/codegen/cuda/mutator.cpp       |  6 ++---
 3 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
index b4fc6f7d7ca3..99dd5e433d08 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -337,7 +337,7 @@ Expr::Expr(
     attributes_(std::move(attributes)) {}
 
 Expr* Expr::shallowCopy() const {
-  auto result = newObject(ir_container_, inputs(), outputs(), attributes());
+  auto result = newObjectFunc()(ir_container_, inputs(), outputs(), attributes());
   if (container()->isA<kir::Kernel>()) {
     result->predicate_ = predicate_;
     result->write_predicate_ = write_predicate_;
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
index a72329ca659a..013034e1ecff 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -429,6 +429,12 @@ class TORCH_CUDA_CU_API Attribute : public Val {
   }
 };
 
+using newObjectFuncType = Expr*(
+    IrContainer*,
+    std::vector<Val*>,
+    std::vector<Val*>,
+    std::vector<Statement*>);
+
 //! A Expr represents a "computation." These are functions that takes inputs
 //! and produce outputs, inputs and outputs all being Vals. There are
 //! specializations of BinaryOp which takes 2 inputs and produces 1 output, and
@@ -480,20 +486,14 @@ class TORCH_CUDA_CU_API Expr : public Statement {
       std::vector<Val*> outputs,
       std::vector<Statement*> attributes);
 
+  virtual newObjectFuncType* newObjectFunc() const = 0;
+
   // Creates a new instance of the expression with all its field copied.
   // Note that unlike IrCloner, this function only do a shallow copy
   Expr* shallowCopy() const;
 
   bool sameAs(const Statement* other) const override;
 
-  // Creates a new instance of the same expression type with the given inputs,
-  // outputs, and attributes.
-  virtual Expr* newObject(
-      IrContainer* container,
-      std::vector<Val*> inputs,
-      std::vector<Val*> outputs,
-      std::vector<Statement*> attributes) const = 0;
-
   // Input/output accessors
   const auto& inputs() const {
     return inputs_;
@@ -597,11 +597,14 @@ bool Val::isDefinitionType() const {
 
 #define NVFUSER_DECLARE_CLONE_AND_CREATE                          \
   virtual Statement* clone(IrCloner* ir_cloner) const override;   \
-  virtual Expr* newObject(                                        \
+  static Expr* newObject(                                         \
       IrContainer* container,                                     \
       std::vector<Val*> inputs,                                   \
       std::vector<Val*> outputs,                                  \
-      std::vector<Statement*> attributes) const override;
+      std::vector<Statement*> attributes);                        \
+  virtual newObjectFuncType* newObjectFunc() const override {     \
+    return newObject;                                             \
+  }
 
 #define NVFUSER_DEFINE_CLONE_AND_CREATE(ClassName)                \
   Statement* ClassName::clone(IrCloner* ir_cloner) const {        \
@@ -611,7 +614,7 @@ bool Val::isDefinitionType() const {
       IrContainer* container,                                     \
       std::vector<Val*> inputs,                                   \
       std::vector<Val*> outputs,                                  \
-      std::vector<Statement*> attributes) const {                 \
+      std::vector<Statement*> attributes) {                       \
     return IrBuilder::create<ClassName>(                          \
         container, inputs, outputs, attributes);                  \
   }
diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp
index c1f7f90b6ee1..1bf9ce1f1336 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.cpp
+++ b/torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -171,11 +171,9 @@ void OptOutMutator::mutate(Expr* op) {
   }
 
   auto container = op->container();
-  std::cout << op->toString() << std::endl;
-  std::cout << container << std::endl;
-  auto newObject = op->newObject;
+  auto newObjectFunc = op->newObjectFunc();
   container->removeExpr(op);
-  newObject(container, mutated_inputs, mutated_outputs, mutated_attrs);
+  newObjectFunc(container, mutated_inputs, mutated_outputs, mutated_attrs);
 }
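[Editor's note, illustrative only: after this commit newObject is a plain static function per class, reachable through the virtual newObjectFunc() accessor, so a caller can rebuild a node without naming its concrete type. A hedged sketch of the call site this enables, mirroring the mutator.cpp hunk above; new_ins/new_outs/new_attrs stand for hypothetical operand vectors produced by maybeMutated():]

    newObjectFuncType* factory = op->newObjectFunc();
    Expr* rebuilt = factory(op->container(), new_ins, new_outs, new_attrs);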
From 376dc8858e37dea5a8ab3aae97668ed6456ede8e Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Fri, 18 Nov 2022 13:41:00 -0800
Subject: [PATCH 05/10] fix

---
 torch/csrc/jit/codegen/cuda/ir_builder_key.h | 29 --------------------
 torch/csrc/jit/codegen/cuda/mutator.cpp      |  5 +++-
 2 files changed, 4 insertions(+), 30 deletions(-)
 delete mode 100644 torch/csrc/jit/codegen/cuda/ir_builder_key.h

diff --git a/torch/csrc/jit/codegen/cuda/ir_builder_key.h b/torch/csrc/jit/codegen/cuda/ir_builder_key.h
deleted file mode 100644
index 95c8f21ea4f1..000000000000
--- a/torch/csrc/jit/codegen/cuda/ir_builder_key.h
+++ /dev/null
@@ -1,29 +0,0 @@
-#pragma once
-
-#include <c10/macros/Export.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class IrContainer;
-
-// Passkey for builder to register properties with statements, and to call
-// functions in IrContainer
-class TORCH_CUDA_CU_API IrBuilderPasskey {
-  friend class IrBuilder;
-
- public:
-  // TODO: Collapse ir_container and Kernel once Kernel inherits from
-  // IrContainer
-  IrContainer* const ir_container_ = nullptr;
-
- private:
-  explicit IrBuilderPasskey(IrContainer* ir_container);
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp
index 1bf9ce1f1336..c7f28912e872 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.cpp
+++ b/torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -163,7 +163,10 @@ void OptOutMutator::mutate(Expr* op) {
     if (!all_same) {
       break;
     }
-    all_same = all_same && mutated_attrs[i]->sameAs(op->attribute(i));
+    bool same =
+        ((mutated_attrs[i] == nullptr) && (op->attribute(i) == nullptr)) ||
+        mutated_attrs[i]->sameAs(op->attribute(i));
+    all_same = all_same && same;
   }
 
   if (all_same) {

From db9686d134afa6d102fd6b2627e4e22edb66f79b Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Fri, 18 Nov 2022 13:48:27 -0800
Subject: [PATCH 06/10] no print

---
 torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
index 42e4925d1a03..2855eb95abac 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
@@ -2500,8 +2500,6 @@ TEST_F(NVFuserTest, FusionGeluBwdReduction_CUDA) {
   fusion.addOutput(t26);
   fusion.addOutput(t27);
 
-  fusion.printMath();
-
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
   at::manual_seed(1);

From 7d9c94e956dfc3e107d45c288710d39c85cb371b Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Fri, 18 Nov 2022 14:09:09 -0800
Subject: [PATCH 07/10] save

---
 torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
index 2855eb95abac..413cdb902da9 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
@@ -2522,7 +2522,7 @@ TEST_F(NVFuserTest, FusionGeluBwdReduction_CUDA) {
   auto reduction_params = getReductionHeuristics(&fusion, {at_grad, at_xvar});
   TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
   scheduleReduction(&fusion, *reduction_params);
-  fusion.printKernel();
+
   FusionExecutor fe;
   fe.compileFusion(&fusion, {at_grad, at_xvar}, reduction_params->lparams);
   auto cg_outputs = fe.runFusion({at_grad, at_xvar}, reduction_params->lparams);

From 305135fc7bb12fee0c9f7392b0ed91c08101899c Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Fri, 18 Nov 2022 14:39:44 -0800
Subject: [PATCH 08/10] save

---
 torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
index 99dd5e433d08..623b49c48389 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -337,7 +337,8 @@ Expr::Expr(
     attributes_(std::move(attributes)) {}
 
 Expr* Expr::shallowCopy() const {
-  auto result = newObjectFunc()(ir_container_, inputs(), outputs(), attributes());
+  auto result =
+      newObjectFunc()(ir_container_, inputs(), outputs(), attributes());
   if (container()->isA<kir::Kernel>()) {
     result->predicate_ = predicate_;
     result->write_predicate_ = write_predicate_;

From e585dcbcd87b3a509cc6884567494f5b64cb0298 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Fri, 18 Nov 2022 16:40:28 -0800
Subject: [PATCH 09/10] rewrite SubstituteInExpr with OptOutMutator

---
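[Editor's note on this commit, placed after the '---' separator where free-form patch notes conventionally go: the ad-hoc SubstituteInExpr visitor — one handle() overload per expression type, each rebuilding the node by hand — becomes a thin shim over OptOutMutator. Seeding the mutation map with reference→substitute and running the generic mutate(Expr*) does all the work; two new protected hooks, a no-op removeExpr() and registerNewExpr(), let the shim capture the rebuilt expression without deleting the original from its container.]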
 torch/csrc/jit/codegen/cuda/dispatch.h   |   3 +-
 torch/csrc/jit/codegen/cuda/ir_utils.cpp | 388 +----------------------
 torch/csrc/jit/codegen/cuda/mutator.cpp  |   9 +-
 3 files changed, 19 insertions(+), 381 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h
index 84704f23d5c2..3e10172be08a 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.h
+++ b/torch/csrc/jit/codegen/cuda/dispatch.h
@@ -327,7 +327,8 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   virtual void mutate(kir::TensorIndex*);
 
  protected:
-  void removeExpr(IrContainer*, Expr*);
+  virtual void removeExpr(IrContainer*, Expr*) const;
+  virtual void registerNewExpr(Expr*) {}
 };
 
 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
index dc434ef48a1a..5bcfda4b6abe 100644
--- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
@@ -155,397 +155,31 @@ std::vector<int> normalizeOld2New(
 namespace ValReplacement {
 // Create New Expr given producer - [an input for the expression]
 // Creates a new Expr substituting current with producer
-struct SubstituteInExpr : public OptInDispatch {
+struct SubstituteInExpr : public OptOutMutator {
  public:
   static Expr* subsitute(Expr* expr, Val* reference, Val* substitute) {
     TORCH_INTERNAL_ASSERT(
         expr != nullptr && reference != nullptr && substitute != nullptr,
         "Nullptr arg found.");
     SubstituteInExpr sie(reference, substitute);
-    sie.handle(expr);
-    TORCH_INTERNAL_ASSERT(
-        sie.expr_ != nullptr,
-        "Substitution failed of ",
-        reference,
-        " with ",
-        substitute);
-    return sie.expr_;
-  }
-
- private:
-  explicit SubstituteInExpr(Val* reference, Val* substitute)
-      : reference_(reference), substitute_(substitute) {}
-
-  void handle(Expr* expr) final {
-    OptInDispatch::handle(expr);
-  }
-
-  void handle(FullOp* full_expr) final {
-    auto out = reference_->sameAs(full_expr->output(0)) ? substitute_
-                                                        : full_expr->output(0);
-    expr_ = IrBuilder::create<FullOp>(
-        full_expr->container(),
-        out,
-        full_expr->getFillValue(),
-        full_expr->dtype());
-  }
-
-  void handle(ARangeOp* arange_expr) final {
-    auto start = reference_->sameAs(arange_expr->start())
-        ? substitute_
-        : arange_expr->start();
-    auto end = reference_->sameAs(arange_expr->end()) ? substitute_
-                                                      : arange_expr->end();
-    auto step = reference_->sameAs(arange_expr->step()) ? substitute_
-                                                        : arange_expr->step();
-    auto out = reference_->sameAs(arange_expr->output(0))
-        ? substitute_
-        : arange_expr->output(0);
-    expr_ = IrBuilder::create<ARangeOp>(
-        arange_expr->container(),
-        out,
-        start,
-        end,
-        step,
-        arange_expr->dtype(),
-        arange_expr->getLinearLogicalIndex());
-  }
-
-  void handle(EyeOp* eye_expr) final {
-    auto out = reference_->sameAs(eye_expr->output(0)) ? substitute_
-                                                       : eye_expr->output(0);
-    expr_ = IrBuilder::create<EyeOp>(
-        eye_expr->container(),
-        out,
-        eye_expr->dtype(),
-        eye_expr->getIndex1(),
-        eye_expr->getIndex2());
-  }
-
-  void handle(UnaryOp* unary_expr) final {
-    auto in =
-        reference_->sameAs(unary_expr->in()) ? substitute_ : unary_expr->in();
-    auto out =
-        reference_->sameAs(unary_expr->out()) ? substitute_ : unary_expr->out();
-    expr_ = IrBuilder::create<UnaryOp>(
-        unary_expr->container(), unary_expr->getUnaryOpType(), out, in);
-  }
-
-  void handle(BinaryOp* binary_expr) final {
-    auto lhs = reference_->sameAs(binary_expr->lhs()) ? substitute_
-                                                      : binary_expr->lhs();
-    auto rhs = reference_->sameAs(binary_expr->rhs()) ? substitute_
-                                                      : binary_expr->rhs();
-    auto out = reference_->sameAs(binary_expr->out()) ? substitute_
-                                                      : binary_expr->out();
-
-    expr_ = IrBuilder::create<BinaryOp>(
-        binary_expr->container(),
-        binary_expr->getBinaryOpType(),
-        out,
-        lhs,
-        rhs);
-  }
-
-  void handle(TernaryOp* ternary_expr) final {
-    auto in1 = reference_->sameAs(ternary_expr->in1()) ? substitute_
-                                                       : ternary_expr->in1();
-    auto in2 = reference_->sameAs(ternary_expr->in2()) ? substitute_
-                                                       : ternary_expr->in2();
-    auto in3 = reference_->sameAs(ternary_expr->in3()) ? substitute_
-                                                       : ternary_expr->in3();
-    auto out = reference_->sameAs(ternary_expr->out()) ? substitute_
-                                                       : ternary_expr->out();
-    expr_ = IrBuilder::create<TernaryOp>(
-        ternary_expr->container(),
-        ternary_expr->getTernaryOpType(),
-        out,
-        in1,
-        in2,
-        in3);
-  }
-
-  void handle(SelectOp* select_expr) final {
-    auto input = reference_->sameAs(select_expr->input(0))
-        ? substitute_
-        : select_expr->input(0);
-    auto index = reference_->sameAs(select_expr->input(1))
-        ? substitute_
-        : select_expr->input(1);
-    auto out = reference_->sameAs(select_expr->output(0))
-        ? substitute_
-        : select_expr->output(0);
-    expr_ = IrBuilder::create<SelectOp>(
-        select_expr->container(),
-        out,
-        input,
-        select_expr->getSelectAxis(),
-        index);
-  }
-
-  void handle(RNGOp* rng_expr) final {
-    std::vector<Val*> subsituted_params;
-    for (auto v : rng_expr->getParameters()) {
-      subsituted_params.emplace_back(reference_->sameAs(v) ? substitute_ : v);
-    }
-    auto out = reference_->sameAs(rng_expr->output(0)) ? substitute_
-                                                       : rng_expr->output(0);
-    expr_ = IrBuilder::create<RNGOp>(
-        rng_expr->container(),
-        rng_expr->getRNGOpType(),
-        out,
-        rng_expr->dtype(),
-        subsituted_params,
-        rng_expr->getRNGOffset(),
-        rng_expr->getPhiloxIndex());
-  }
-
-  void handle(ReductionOp* reduction_expr) final {
-    auto init = reference_->sameAs(reduction_expr->init())
-        ? substitute_
-        : reduction_expr->init();
-    auto out = reference_->sameAs(reduction_expr->out())
-        ? substitute_
-        : reduction_expr->out();
-    auto in = reference_->sameAs(reduction_expr->in()) ? substitute_
-                                                       : reduction_expr->in();
-
-    expr_ = IrBuilder::create<ReductionOp>(
-        reduction_expr->container(),
-        reduction_expr->getReductionOpType(),
-        init,
-        out,
-        in);
-  }
-
-  void handle(GroupedReductionOp* grouped_reduction_expr) final {
-    std::vector<Val*> outputs;
-    std::transform(
-        grouped_reduction_expr->outputs().begin(),
-        grouped_reduction_expr->outputs().end(),
-        std::back_inserter(outputs),
-        [&](Val* val) { return reference_->sameAs(val) ? substitute_ : val; });
-
-    std::vector<Val*> inputs;
-    std::transform(
-        grouped_reduction_expr->inputs().begin(),
-        grouped_reduction_expr->inputs().end(),
-        std::back_inserter(inputs),
-        [&](Val* val) { return reference_->sameAs(val) ? substitute_ : val; });
-
-    std::vector<Val*> init_vals;
-    std::transform(
-        grouped_reduction_expr->initVals().begin(),
-        grouped_reduction_expr->initVals().end(),
-        std::back_inserter(init_vals),
-        [&](Val* val) { return reference_->sameAs(val) ? substitute_ : val; });
-
-    expr_ = IrBuilder::create<GroupedReductionOp>(
-        grouped_reduction_expr->container(),
-        grouped_reduction_expr->getReductionOpTypes(),
-        init_vals,
-        outputs,
-        inputs);
+    sie.mutate(expr);
+    // if nothing substituted, then return the original expr
+    return sie.expr_ == nullptr ? expr : sie.expr_;
   }
 
-  void handle(BroadcastOp* broadcast_expr) final {
-    auto out = reference_->sameAs(broadcast_expr->out())
-        ? substitute_
-        : broadcast_expr->out();
-    auto in = reference_->sameAs(broadcast_expr->in()) ? substitute_
-                                                       : broadcast_expr->in();
-
-    expr_ = IrBuilder::create<BroadcastOp>(
-        broadcast_expr->container(),
-        out,
-        in,
-        broadcast_expr->getBroadcastDimFlags());
-  }
-
-  void handle(SqueezeOp* squeeze_expr) final {
-    auto out = reference_->sameAs(squeeze_expr->out()) ? substitute_
-                                                       : squeeze_expr->out();
-    auto in = reference_->sameAs(squeeze_expr->in()) ? substitute_
-                                                     : squeeze_expr->in();
-
-    expr_ = IrBuilder::create<SqueezeOp>(
-        squeeze_expr->container(), out, in, squeeze_expr->getSqueezeDimFlags());
-  }
-
-  void handle(TransposeOp* transpose_expr) final {
-    TORCH_INTERNAL_ASSERT(
-        substitute_->isA<TensorView>(),
-        "All args to transpose must be tensor view, but received a non-TensorView for replacement: ",
-        substitute_);
-    auto out = reference_->sameAs(transpose_expr->out())
-        ? substitute_->as<TensorView>()
-        : transpose_expr->out();
-    auto in = reference_->sameAs(transpose_expr->in())
-        ? substitute_->as<TensorView>()
-        : transpose_expr->in();
-    expr_ = IrBuilder::create<TransposeOp>(
-        transpose_expr->container(), out, in, transpose_expr->new2old());
-  }
-
-  void handle(ExpandOp* expand_expr) final {
-    auto out = reference_->sameAs(expand_expr->out())
-        ? substitute_->as<TensorView>()
-        : expand_expr->out();
-    auto in = reference_->sameAs(expand_expr->in())
-        ? substitute_->as<TensorView>()
-        : expand_expr->in();
-
-    auto expanded_extents = expand_expr->expanded_extents();
-    if (substitute_->isA<Int>()) {
-      for (auto i : c10::irange(expanded_extents.size())) {
-        if (!expanded_extents[i]->sameAs(substitute_)) {
-          expanded_extents[i] = substitute_;
-        }
-      }
-    }
-    expr_ = IrBuilder::create<ExpandOp>(
-        expand_expr->container(), out, in, expanded_extents);
-  }
-
-  void handle(ShiftOp* shift_expr) final {
-    auto out =
-        reference_->sameAs(shift_expr->out()) ? substitute_ : shift_expr->out();
-    auto in =
-        reference_->sameAs(shift_expr->in()) ? substitute_ : shift_expr->in();
-
-    expr_ = IrBuilder::create<ShiftOp>(
-        shift_expr->container(),
-        out,
-        in,
-        shift_expr->offsets(),
-        shift_expr->padWidth());
-  }
+ protected:
+  virtual void removeExpr(IrContainer*, Expr*) const override {}
 
-  void handle(GatherOp* gather_expr) final {
-    auto out = reference_->sameAs(gather_expr->out()) ? substitute_
-                                                      : gather_expr->out();
-    auto in =
-        reference_->sameAs(gather_expr->in()) ? substitute_ : gather_expr->in();
-
-    expr_ = IrBuilder::create<GatherOp>(
-        gather_expr->container(),
-        out,
-        in,
-        gather_expr->windowShape(),
-        gather_expr->padWidth());
+  virtual void registerNewExpr(Expr* expr) override {
+    expr_ = expr;
   }
 
-  void handle(ViewAsScalar* expr) final {
-    TORCH_INTERNAL_ASSERT(
-        substitute_->isA<TensorView>(),
-        "All args to view must be TensorView, but received a non-TensorView for replacement: ",
-        substitute_);
-    auto in = reference_->sameAs(expr->in()) ? substitute_->as<TensorView>()
-                                             : expr->in();
-    auto out = reference_->sameAs(expr->out()) ? substitute_->as<TensorView>()
-                                               : expr->out();
-    expr_ = IrBuilder::create<ViewAsScalar>(
-        expr->container(), out, in, expr->vector_id(), expr->index());
-  }
-
-  void handle(ViewOp* view_expr) final {
-    TORCH_INTERNAL_ASSERT(
-        substitute_->isA<TensorView>(),
-        "All args to view must be TensorView, but received a non-TensorView for replacement: ",
-        substitute_);
-    auto in = reference_->sameAs(view_expr->in())
-        ? substitute_->as<TensorView>()
-        : view_expr->in();
-    auto out = reference_->sameAs(view_expr->out())
-        ? substitute_->as<TensorView>()
-        : view_expr->out();
-    expr_ = IrBuilder::create<ViewOp>(view_expr->container(), out, in);
-  }
-
-  void handle(WelfordOp* welford_expr) final {
-    auto out_avg = reference_->sameAs(welford_expr->outAvg())
-        ? substitute_->as<TensorView>()
-        : welford_expr->outAvg();
-    auto out_var = reference_->sameAs(welford_expr->outVar())
-        ? substitute_->as<TensorView>()
-        : welford_expr->outVar();
-    auto out_N = reference_->sameAs(welford_expr->outN())
-        ? substitute_->as<TensorView>()
-        : welford_expr->outN();
-    auto in_avg = reference_->sameAs(welford_expr->inAvg())
-        ? substitute_->as<TensorView>()
-        : welford_expr->inAvg();
-    auto in_var =
-        welford_expr->inVar() && reference_->sameAs(welford_expr->inVar())
-        ? substitute_->as<TensorView>()
-        : welford_expr->inVar();
-    auto in_N = reference_->sameAs(welford_expr->inN()) ? substitute_
-                                                        : welford_expr->inN();
-    auto init_avg =
-        welford_expr->initAvg() && reference_->sameAs(welford_expr->initAvg())
-        ? substitute_->as<TensorView>()
-        : welford_expr->initAvg();
-    auto init_var =
-        welford_expr->initVar() && reference_->sameAs(welford_expr->initVar())
-        ? substitute_->as<TensorView>()
-        : welford_expr->initVar();
-    auto init_N =
-        welford_expr->initN() && reference_->sameAs(welford_expr->initN())
-        ? substitute_
-        : welford_expr->initN();
-    expr_ = IrBuilder::create<WelfordOp>(
-        welford_expr->container(),
-        out_avg,
-        out_var,
-        out_N,
-        in_avg,
-        in_var,
-        in_N,
-        init_avg,
-        init_var,
-        init_N,
-        welford_expr->isAllreduce());
-  }
-
-  void handle(LoadStoreOp* ldst_expr) final {
-    TORCH_INTERNAL_ASSERT(
-        substitute_->isA<TensorView>(),
-        "All args to view must be TensorView, but received a non-TensorView for replacement: ",
-        substitute_);
-    auto in = reference_->sameAs(ldst_expr->in())
-        ? substitute_->as<TensorView>()
-        : ldst_expr->in();
-    auto out = reference_->sameAs(ldst_expr->out())
-        ? substitute_->as<TensorView>()
-        : ldst_expr->out();
-    expr_ = IrBuilder::create<LoadStoreOp>(
-        ldst_expr->container(), ldst_expr->opType(), out, in);
-  }
-
-  void handle(MmaOp* mma_expr) final {
-    TORCH_INTERNAL_ASSERT(
-        substitute_->isA<TensorView>(),
-        "All args to MmaOp must be TensorView, but received a non-TensorView for replacement: ",
-        substitute_);
-    auto in_a = reference_->sameAs(mma_expr->inA())
-        ? substitute_->as<TensorView>()
-        : mma_expr->inA();
-    auto in_b = reference_->sameAs(mma_expr->inB())
-        ? substitute_->as<TensorView>()
-        : mma_expr->inB();
-    auto out = reference_->sameAs(mma_expr->out())
-        ? substitute_->as<TensorView>()
-        : mma_expr->out();
-    auto init = reference_->sameAs(mma_expr->init())
-        ? substitute_->as<TensorView>()
-        : mma_expr->init();
-    expr_ = IrBuilder::create<MmaOp>(
-        mma_expr->container(), out, in_a, in_b, init, mma_expr->options());
+ private:
+  explicit SubstituteInExpr(Val* reference, Val* substitute) {
+    mutations[reference] = substitute;
   }
 
  private:
-  Val* reference_ = nullptr;
-  Val* substitute_ = nullptr;
   Expr* expr_ = nullptr;
 };
diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp
index c7f28912e872..afc5222c4348 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.cpp
+++ b/torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -175,13 +175,16 @@ void OptOutMutator::mutate(Expr* op) {
 
   auto container = op->container();
   auto newObjectFunc = op->newObjectFunc();
-  container->removeExpr(op);
-  newObjectFunc(container, mutated_inputs, mutated_outputs, mutated_attrs);
+  removeExpr(container, op);
+  auto new_expr =
+      newObjectFunc(container, mutated_inputs, mutated_outputs, mutated_attrs);
+  registerNewExpr(new_expr);
 }
 
-void OptOutMutator::removeExpr(IrContainer* container, Expr* expr) {
+void OptOutMutator::removeExpr(IrContainer* container, Expr* expr) const {
   container->removeExpr(expr);
 }
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
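[Editor's note — a hedged usage sketch of the rewritten substitution path, assuming the existing ir_utils::replaceValInExpr() wrapper that forwards to SubstituteInExpr::subsitute(); illustrative only:]

    // Replace every use of old_val in expr with new_val; when old_val does
    // not appear in expr, the original expr is returned unchanged.
    Expr* new_expr = ir_utils::replaceValInExpr(expr, old_val, new_val);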
From a97b850d91219edff1122906f11e73dd7cf8d499 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Mon, 21 Nov 2022 11:27:10 -0800
Subject: [PATCH 10/10] rename

---
 torch/csrc/jit/codegen/cuda/dispatch.h   | 6 +++---
 torch/csrc/jit/codegen/cuda/ir_utils.cpp | 2 +-
 torch/csrc/jit/codegen/cuda/mutator.cpp  | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h
index 3e10172be08a..26a7bf509761 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.h
+++ b/torch/csrc/jit/codegen/cuda/dispatch.h
@@ -302,13 +302,13 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   void registerMutation(Val* val, Val* mutation);
 
   Val* maybeMutated(Val* val) {
-    if (mutations.find(val) == mutations.end()) {
+    if (mutations_.find(val) == mutations_.end()) {
       return val;
     }
-    return mutations.at(val);
+    return mutations_.at(val);
   }
 
-  std::unordered_map<Val*, Val*> mutations;
+  std::unordered_map<Val*, Val*> mutations_;
 
   //****Functions below defined in mutator.cpp*****
 
diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
index 5bcfda4b6abe..65ff22df69d1 100644
--- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
@@ -176,7 +176,7 @@ struct SubstituteInExpr : public OptOutMutator {
 
  private:
   explicit SubstituteInExpr(Val* reference, Val* substitute) {
-    mutations[reference] = substitute;
+    mutations_[reference] = substitute;
   }
 
  private:
diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp
index afc5222c4348..3d49b558a386 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.cpp
+++ b/torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -38,7 +38,7 @@ void OptOutMutator::registerMutation(Val* val, Val* mutation) {
       ", ",
       mutation->dtype(),
       ")");
-  mutations[val] = mutation;
+  mutations_[val] = mutation;
 }
 
 void OptOutMutator::mutate(Bool* b) {}
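[Editor's addendum — a minimal, self-contained C++ sketch of the design this series converges on. All names here (Val, Expr, Add, NewObjectFn, mutate) are simplified stand-ins, not the nvfuser API; container/arena ownership is omitted:]

    #include <iostream>
    #include <unordered_map>
    #include <vector>

    struct Val { int id; };

    struct Expr {
      std::vector<Val*> inputs;
      // Analogue of newObjectFuncType*: a plain function that rebuilds the
      // node from (possibly substituted) operands, so no per-type mutate()
      // overloads are needed anywhere.
      using NewObjectFn = Expr* (*)(std::vector<Val*>);
      virtual NewObjectFn newObjectFunc() const = 0;
      virtual ~Expr() = default;
    };

    struct Add : Expr {
      // Analogue of the static newObject the macros declare per class.
      static Expr* newObject(std::vector<Val*> ins) {
        auto* e = new Add;
        e->inputs = std::move(ins);
        return e;
      }
      NewObjectFn newObjectFunc() const override { return newObject; }
    };

    // Analogue of OptOutMutator::mutate(Expr*): map each operand through the
    // mutation table and rebuild only when something actually changed.
    Expr* mutate(Expr* e, const std::unordered_map<Val*, Val*>& mutations) {
      std::vector<Val*> mutated;
      bool all_same = true;
      for (Val* v : e->inputs) {
        auto it = mutations.find(v);
        Val* m = (it == mutations.end()) ? v : it->second;
        all_same = all_same && (m == v);
        mutated.push_back(m);
      }
      if (all_same) {
        return e;  // nothing substituted: keep the original expression
      }
      return e->newObjectFunc()(std::move(mutated));
    }

    int main() {
      Val a{0}, b{1}, c{2};
      Expr* add = Add::newObject({&a, &b});
      Expr* replaced = mutate(add, {{&a, &c}});  // substitute a -> c
      std::cout << (replaced != add) << "\n";    // prints 1: a new Add built
    }

[The sketch mirrors the series' key choice: because newObject is a static function exposed through a virtual accessor, the generic mutate() needs no knowledge of concrete expression types, and a substitution pass like SubstituteInExpr reduces to pre-seeding the mutation map.]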