Skip to content

Commit

Permalink
save structured expr
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm committed Nov 15, 2022
1 parent c9f8c1d commit c0a5864
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 139 deletions.
152 changes: 152 additions & 0 deletions torch/csrc/jit/codegen/cuda/expr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! A Expr represents a "computation." These are functions that takes inputs
//! and produce outputs, inputs and outputs all being Vals. There are
//! specializations of BinaryOp which takes 2 inputs and produces 1 output, and
//! UnaryOp which takes 1 input and produces 1 output. Exprs are unique and
//! immutable. Conceptually, Exprs could always be manipulated using unique
//! pointers, and we could add this later. However, for now Exprs can be
//! replaced in a fusion, but they cannot be modified in place.
//!
//! The IR is static single assignment (SSA). Values can only be defined as an
//! output of an Expr once. If they are re-defined the original definition is
//! deleted from the program, as opposed to an ordered redefinition of the
//! value in the program.
//!
//! Note: Registering an Expr with a Fusion is actually 2 parts, one part is
//! done in the Expr constructor, so that should be called on anything that
//! inherits Expr. The issue with having registration in Expr's constructor, is
//! that the constructor of an Expr will set ouputs and inputs. This
//! information is important for registration with Fuser, so it can track the
//! dependency chain.
//!
//! Adding an Expr:
//! Right now adding an Expr is quite involved. Expr's can be defined in ir.h
//! or in their own header file. The following is what is currently needed for
//! Expr definitions:
//!
//! 1) Definition inheriting from Expr.
//! - Members must be private or protected
//! - Accessor functions for members
//! - Constructors need to register with the Fusion after inputs/outputs
//! are defined
//! - Implementation of bool sameAs(...)
//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
//! 3) Default mutator function should be added to mutator.h/.cpp
//! 4) Printing functions should be added to ir_iostream.h/.cpp
//! 5) Lower case convenience functions should be added to arith.h/.cpp (If
//! user facing)
//! 6) An enum value must be added to ExprType in type.h
//! 7) A string entry must be added in expr_type_string_map
//! 8) Entry added to ir_graphviz .cpp/.h
//!
class TORCH_CUDA_CU_API Expr : public Statement {
public:
explicit Expr(IrBuilderPasskey, ExprType type);

Expr(const Expr* src, IrCloner* ir_cloner);

// Creates a new instance of the expression with all its field copied.
// Note that unlike IrCloner, this function only do a shallow copy
virtual Expr* shallowCopy() const = 0;

c10::optional<ExprType> getExprType() const override {
return etype_;
}

ExprType etype() const {
return etype_;
}

bool sameAs(const Statement* other) const override;

// Input/output accessors
const auto& inputs() const {
return inputs_;
}

const auto& outputs() const {
return outputs_;
}

auto input(size_t index) const {
return inputs_[index];
}

auto output(size_t index) const {
return outputs_[index];
}

// Dispatch functions, definitions in dispatch.cpp
template <typename T>
static void dispatch(T handler, Expr*);

template <typename T>
static void constDispatch(T handler, const Expr* const);

template <typename T>
static void mutatorDispatch(T mutator, Expr*);

// TODO: Protect based on being in kernel container
kir::Predicate* predicate() const;

// Creates a shallow copy the expression with the given predicate attached.
// TODO: Protect based on being in kernel container
Expr* withPredicate(kir::Predicate* predicate);

// TODO: Protect based on being in kernel container
kir::Predicate* writePredicate() const;

// Creates a shallow copy the expression with the given write-predicate
// attached.
// TODO: Protect based on being in kernel container
Expr* withWritePredicate(kir::Predicate* write_predicate);

protected:
// TODO: Protect based on being in kernel container
void setPredicate(kir::Predicate* predicate);

// TODO: Protect based on being in kernel container
void setWritePredicate(kir::Predicate* write_predicate);

void copyPredicatesFrom(const Expr* expr);

// TODO: Add Fusion passkey
void addInput(Val* input) {
TORCH_INTERNAL_ASSERT(input != nullptr);
inputs_.push_back(input);
}

// TODO: Add Fusion passkey
void addOutput(Val* output) {
TORCH_INTERNAL_ASSERT(output != nullptr);
outputs_.push_back(output);
}

ExprPasskey exprPasskey() {
return ExprPasskey();
}

private:
ExprType etype_ = ExprType::Invalid;
std::vector<Val*> inputs_;
std::vector<Val*> outputs_;

kir::Predicate* predicate_ = nullptr;

// Only used for reduction-related expressions
kir::Predicate* write_predicate_ = nullptr;
};

class TORCH_CUDA_CU_API StructuredExpr : public Expr {
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_all_nodes.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/expr.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_internal_nodes.h>

Expand Down
139 changes: 0 additions & 139 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include <unordered_map>
#include <vector>

// TODO: Add more types (int32, int64)
// TODO: sameAs should have better logic to check against any type and return
// gracefully

Expand Down Expand Up @@ -397,144 +396,6 @@ class TORCH_CUDA_CU_API Val : public Statement {
int evaluator_index_ = -1;
};

//! A Expr represents a "computation." These are functions that takes inputs
//! and produce outputs, inputs and outputs all being Vals. There are
//! specializations of BinaryOp which takes 2 inputs and produces 1 output, and
//! UnaryOp which takes 1 input and produces 1 output. Exprs are unique and
//! immutable. Conceptually, Exprs could always be manipulated using unique
//! pointers, and we could add this later. However, for now Exprs can be
//! replaced in a fusion, but they cannot be modified in place.
//!
//! The IR is static single assignment (SSA). Values can only be defined as an
//! output of an Expr once. If they are re-defined the original definition is
//! deleted from the program, as opposed to an ordered redefinition of the
//! value in the program.
//!
//! Note: Registering an Expr with a Fusion is actually 2 parts, one part is
//! done in the Expr constructor, so that should be called on anything that
//! inherits Expr. The issue with having registration in Expr's constructor, is
//! that the constructor of an Expr will set ouputs and inputs. This
//! information is important for registration with Fuser, so it can track the
//! dependency chain.
//!
//! Adding an Expr:
//! Right now adding an Expr is quite involved. Expr's can be defined in ir.h
//! or in their own header file. The following is what is currently needed for
//! Expr definitions:
//!
//! 1) Definition inheriting from Expr.
//! - Members must be private or protected
//! - Accessor functions for members
//! - Constructors need to register with the Fusion after inputs/outputs
//! are defined
//! - Implementation of bool sameAs(...)
//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
//! 3) Default mutator function should be added to mutator.h/.cpp
//! 4) Printing functions should be added to ir_iostream.h/.cpp
//! 5) Lower case convenience functions should be added to arith.h/.cpp (If
//! user facing)
//! 6) An enum value must be added to ExprType in type.h
//! 7) A string entry must be added in expr_type_string_map
//! 8) Entry added to ir_graphviz .cpp/.h
//!
class TORCH_CUDA_CU_API Expr : public Statement {
public:
explicit Expr(IrBuilderPasskey, ExprType type);

Expr(const Expr* src, IrCloner* ir_cloner);

// Creates a new instance of the expression with all its field copied.
// Note that unlike IrCloner, this function only do a shallow copy
virtual Expr* shallowCopy() const = 0;

c10::optional<ExprType> getExprType() const override {
return etype_;
}

ExprType etype() const {
return etype_;
}

bool sameAs(const Statement* other) const override;

// Input/output accessors
const auto& inputs() const {
return inputs_;
}

const auto& outputs() const {
return outputs_;
}

auto input(size_t index) const {
return inputs_[index];
}

auto output(size_t index) const {
return outputs_[index];
}

// Dispatch functions, definitions in dispatch.cpp
template <typename T>
static void dispatch(T handler, Expr*);

template <typename T>
static void constDispatch(T handler, const Expr* const);

template <typename T>
static void mutatorDispatch(T mutator, Expr*);

// TODO: Protect based on being in kernel container
kir::Predicate* predicate() const;

// Creates a shallow copy the expression with the given predicate attached.
// TODO: Protect based on being in kernel container
Expr* withPredicate(kir::Predicate* predicate);

// TODO: Protect based on being in kernel container
kir::Predicate* writePredicate() const;

// Creates a shallow copy the expression with the given write-predicate
// attached.
// TODO: Protect based on being in kernel container
Expr* withWritePredicate(kir::Predicate* write_predicate);

protected:
// TODO: Protect based on being in kernel container
void setPredicate(kir::Predicate* predicate);

// TODO: Protect based on being in kernel container
void setWritePredicate(kir::Predicate* write_predicate);

void copyPredicatesFrom(const Expr* expr);

// TODO: Add Fusion passkey
void addInput(Val* input) {
TORCH_INTERNAL_ASSERT(input != nullptr);
inputs_.push_back(input);
}

// TODO: Add Fusion passkey
void addOutput(Val* output) {
TORCH_INTERNAL_ASSERT(output != nullptr);
outputs_.push_back(output);
}

ExprPasskey exprPasskey() {
return ExprPasskey();
}

private:
ExprType etype_ = ExprType::Invalid;
std::vector<Val*> inputs_;
std::vector<Val*> outputs_;

kir::Predicate* predicate_ = nullptr;

// Only used for reduction-related expressions
kir::Predicate* write_predicate_ = nullptr;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
Expand Down

0 comments on commit c0a5864

Please sign in to comment.