forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
153 additions
and
139 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h> | ||
|
||
namespace torch { | ||
namespace jit { | ||
namespace fuser { | ||
namespace cuda { | ||
|
||
//! A Expr represents a "computation." These are functions that takes inputs | ||
//! and produce outputs, inputs and outputs all being Vals. There are | ||
//! specializations of BinaryOp which takes 2 inputs and produces 1 output, and | ||
//! UnaryOp which takes 1 input and produces 1 output. Exprs are unique and | ||
//! immutable. Conceptually, Exprs could always be manipulated using unique | ||
//! pointers, and we could add this later. However, for now Exprs can be | ||
//! replaced in a fusion, but they cannot be modified in place. | ||
//! | ||
//! The IR is static single assignment (SSA). Values can only be defined as an | ||
//! output of an Expr once. If they are re-defined the original definition is | ||
//! deleted from the program, as opposed to an ordered redefinition of the | ||
//! value in the program. | ||
//! | ||
//! Note: Registering an Expr with a Fusion is actually 2 parts, one part is | ||
//! done in the Expr constructor, so that should be called on anything that | ||
//! inherits Expr. The issue with having registration in Expr's constructor, is | ||
//! that the constructor of an Expr will set ouputs and inputs. This | ||
//! information is important for registration with Fuser, so it can track the | ||
//! dependency chain. | ||
//! | ||
//! Adding an Expr: | ||
//! Right now adding an Expr is quite involved. Expr's can be defined in ir.h | ||
//! or in their own header file. The following is what is currently needed for | ||
//! Expr definitions: | ||
//! | ||
//! 1) Definition inheriting from Expr. | ||
//! - Members must be private or protected | ||
//! - Accessor functions for members | ||
//! - Constructors need to register with the Fusion after inputs/outputs | ||
//! are defined | ||
//! - Implementation of bool sameAs(...) | ||
//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val | ||
//! 3) Default mutator function should be added to mutator.h/.cpp | ||
//! 4) Printing functions should be added to ir_iostream.h/.cpp | ||
//! 5) Lower case convenience functions should be added to arith.h/.cpp (If | ||
//! user facing) | ||
//! 6) An enum value must be added to ExprType in type.h | ||
//! 7) A string entry must be added in expr_type_string_map | ||
//! 8) Entry added to ir_graphviz .cpp/.h | ||
//! | ||
class TORCH_CUDA_CU_API Expr : public Statement { | ||
public: | ||
explicit Expr(IrBuilderPasskey, ExprType type); | ||
|
||
Expr(const Expr* src, IrCloner* ir_cloner); | ||
|
||
// Creates a new instance of the expression with all its field copied. | ||
// Note that unlike IrCloner, this function only do a shallow copy | ||
virtual Expr* shallowCopy() const = 0; | ||
|
||
c10::optional<ExprType> getExprType() const override { | ||
return etype_; | ||
} | ||
|
||
ExprType etype() const { | ||
return etype_; | ||
} | ||
|
||
bool sameAs(const Statement* other) const override; | ||
|
||
// Input/output accessors | ||
const auto& inputs() const { | ||
return inputs_; | ||
} | ||
|
||
const auto& outputs() const { | ||
return outputs_; | ||
} | ||
|
||
auto input(size_t index) const { | ||
return inputs_[index]; | ||
} | ||
|
||
auto output(size_t index) const { | ||
return outputs_[index]; | ||
} | ||
|
||
// Dispatch functions, definitions in dispatch.cpp | ||
template <typename T> | ||
static void dispatch(T handler, Expr*); | ||
|
||
template <typename T> | ||
static void constDispatch(T handler, const Expr* const); | ||
|
||
template <typename T> | ||
static void mutatorDispatch(T mutator, Expr*); | ||
|
||
// TODO: Protect based on being in kernel container | ||
kir::Predicate* predicate() const; | ||
|
||
// Creates a shallow copy the expression with the given predicate attached. | ||
// TODO: Protect based on being in kernel container | ||
Expr* withPredicate(kir::Predicate* predicate); | ||
|
||
// TODO: Protect based on being in kernel container | ||
kir::Predicate* writePredicate() const; | ||
|
||
// Creates a shallow copy the expression with the given write-predicate | ||
// attached. | ||
// TODO: Protect based on being in kernel container | ||
Expr* withWritePredicate(kir::Predicate* write_predicate); | ||
|
||
protected: | ||
// TODO: Protect based on being in kernel container | ||
void setPredicate(kir::Predicate* predicate); | ||
|
||
// TODO: Protect based on being in kernel container | ||
void setWritePredicate(kir::Predicate* write_predicate); | ||
|
||
void copyPredicatesFrom(const Expr* expr); | ||
|
||
// TODO: Add Fusion passkey | ||
void addInput(Val* input) { | ||
TORCH_INTERNAL_ASSERT(input != nullptr); | ||
inputs_.push_back(input); | ||
} | ||
|
||
// TODO: Add Fusion passkey | ||
void addOutput(Val* output) { | ||
TORCH_INTERNAL_ASSERT(output != nullptr); | ||
outputs_.push_back(output); | ||
} | ||
|
||
ExprPasskey exprPasskey() { | ||
return ExprPasskey(); | ||
} | ||
|
||
private: | ||
ExprType etype_ = ExprType::Invalid; | ||
std::vector<Val*> inputs_; | ||
std::vector<Val*> outputs_; | ||
|
||
kir::Predicate* predicate_ = nullptr; | ||
|
||
// Only used for reduction-related expressions | ||
kir::Predicate* write_predicate_ = nullptr; | ||
}; | ||
|
||
class TORCH_CUDA_CU_API StructuredExpr : public Expr { | ||
}; | ||
|
||
} // namespace cuda | ||
} // namespace fuser | ||
} // namespace jit | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters