Commit
Use withPredicate to replace setPredicate to maintain Exprs immutable (
zasdfgbnm authored Oct 5, 2022
1 parent 197221b commit e040676
Showing 15 changed files with 616 additions and 115 deletions.
19 changes: 19 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -341,6 +341,12 @@ void Expr::setPredicate(kir::Predicate* predicate) {
predicate_ = predicate;
}

Expr* Expr::withPredicate(kir::Predicate* predicate) {
auto result = shallowCopy();
result->setPredicate(predicate);
return result;
}

kir::Predicate* Expr::writePredicate() const {
TORCH_INTERNAL_ASSERT(
container()->isA<kir::Kernel>(), "Function invalid for fusion.");
@@ -353,6 +359,19 @@ void Expr::setWritePredicate(kir::Predicate* write_predicate) {
write_predicate_ = write_predicate;
}

Expr* Expr::withWritePredicate(kir::Predicate* predicate) {
auto result = shallowCopy();
result->setWritePredicate(predicate);
return result;
}

void Expr::copyPredicatesFrom(const Expr* expr) {
if (container()->isA<kir::Kernel>()) {
predicate_ = expr->predicate_;
write_predicate_ = expr->write_predicate_;
}
}

} // namespace cuda
} // namespace fuser
} // namespace jit
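The substance of this first file is the new immutable-update API: withPredicate() and withWritePredicate() hand back a shallow copy of the expression carrying the new predicate instead of mutating the node in place, and copyPredicatesFrom() transfers both predicates onto such a copy when inside a kir::Kernel. Below is a minimal, self-contained sketch of the idiom, using simplified stand-in types rather than the real nvfuser classes.

// Simplified, standalone illustration of the "with*() returns a shallow copy"
// idiom introduced by this commit. Predicate, Expr, and UnaryOp here are
// hypothetical stand-ins, not the actual nvfuser types.
#include <cassert>

struct Predicate {};

class Expr {
 public:
  virtual ~Expr() = default;

  // Every concrete Expr knows how to make a member-wise (shallow) copy of itself.
  virtual Expr* shallowCopy() const = 0;

  Predicate* predicate() const { return predicate_; }

  // Returns a shallow copy carrying the new predicate; the receiver is untouched.
  Expr* withPredicate(Predicate* predicate) {
    Expr* result = shallowCopy();
    result->predicate_ = predicate;
    return result;
  }

 protected:
  Predicate* predicate_ = nullptr;
};

class UnaryOp : public Expr {
 public:
  Expr* shallowCopy() const override {
    return new UnaryOp(*this);  // copy constructor == member-wise (shallow) copy
  }
};

int main() {
  Predicate pred;
  UnaryOp op;
  Expr* predicated = op.withPredicate(&pred);
  assert(op.predicate() == nullptr);         // original left unmodified
  assert(predicated->predicate() == &pred);  // copy carries the predicate
  delete predicated;
  return 0;
}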
19 changes: 17 additions & 2 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -426,6 +426,10 @@ class TORCH_CUDA_CU_API Expr : public Statement {

Expr(const Expr* src, IrCloner* ir_cloner);

// Creates a new instance of the expression with all its fields copied.
// Note that unlike IrCloner, this function only does a shallow copy.
virtual Expr* shallowCopy() const = 0;

c10::optional<ExprType> getExprType() const override {
return etype_;
}
@@ -466,16 +470,27 @@ class TORCH_CUDA_CU_API Expr : public Statement {
// TODO: Protect based on being in kernel container
kir::Predicate* predicate() const;

// Creates a shallow copy of the expression with the given predicate attached.
// TODO: Protect based on being in kernel container
void setPredicate(kir::Predicate* predicate);
Expr* withPredicate(kir::Predicate* predicate);

// TODO: Protect based on being in kernel container
kir::Predicate* writePredicate() const;

// Creates a shallow copy of the expression with the given write-predicate
// attached.
// TODO: Protect based on being in kernel container
void setWritePredicate(kir::Predicate* write_predicate);
Expr* withWritePredicate(kir::Predicate* write_predicate);

protected:
// TODO: Protect based on being in kernel container
void setPredicate(kir::Predicate* predicate);

// TODO: Protect based on being in kernel container
void setWritePredicate(kir::Predicate* write_predicate);

void copyPredicatesFrom(const Expr* expr);

// TODO: Add Fusion passkey
void addInput(Val* input) {
TORCH_INTERNAL_ASSERT(input != nullptr);
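With the setters now protected, callers can no longer predicate an expression in place; they obtain a copy from withPredicate()/withWritePredicate() and use it instead of the original. A hedged sketch of the resulting call-site pattern follows; the variable names and surrounding pass are illustrative, not taken from this diff.

// Illustrative call-site pattern only; names are hypothetical.
// Before this commit, a lowering pass could mutate the node directly:
//   expr->setPredicate(thread_pred);
// After this commit, the original node stays untouched and the copy is used:
Expr* predicated = expr->withPredicate(thread_pred);
if (needs_write_predicate) {
  predicated = predicated->withWritePredicate(write_pred);
}
// ...then register or emit `predicated` wherever `expr` was used before.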
46 changes: 46 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -36,6 +36,8 @@ class TORCH_CUDA_CU_API FullOp : public Expr {

FullOp(const FullOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

bool sameAs(const Statement* other) const override;

DataType dtype() const {
@@ -64,6 +66,8 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr {

ARangeOp(const ARangeOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

bool sameAs(const Statement* other) const override;

DataType dtype() const {
@@ -127,6 +131,8 @@ class TORCH_CUDA_CU_API EyeOp : public Expr {

EyeOp(const EyeOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

bool sameAs(const Statement* other) const override;

DataType dtype() const {
@@ -172,6 +178,8 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr {

UnaryOp(const UnaryOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
@@ -201,6 +209,8 @@ class TORCH_CUDA_CU_API BinaryOp : public Expr {

BinaryOp(const BinaryOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
@@ -239,6 +249,8 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {

RNGOp(const RNGOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

RNGOpType getRNGOpType() const {
return rng_op_type_;
}
@@ -298,6 +310,8 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr {

BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
@@ -346,6 +360,8 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr {

ReductionOp(const ReductionOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
@@ -394,6 +410,8 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr {

GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

//! Number of expressions grouped horizontally. It does not reflect
//! iteration grouping.
size_t numExprs() const {
@@ -580,6 +598,8 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr {

WelfordOp(const WelfordOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return output().avg();
}
@@ -675,6 +695,8 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr {

GroupedWelfordOp(const GroupedWelfordOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

//! Number of expressions grouped horizontally. It does not reflect
//! iteration grouping. As horizontal grouping is not supported,
//! this always returns 1.
@@ -798,6 +820,8 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {

MmaOp(const MmaOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
@@ -856,6 +880,8 @@ class TORCH_CUDA_CU_API TransposeOp : public Expr {

TransposeOp(const TransposeOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

TensorView* out() const {
return out_;
}
@@ -886,6 +912,8 @@ class TORCH_CUDA_CU_API ExpandOp : public Expr {

ExpandOp(const ExpandOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

TensorView* out() const {
return out_;
}
@@ -916,6 +944,8 @@ class TORCH_CUDA_CU_API TernaryOp : public Expr {

TernaryOp(const TernaryOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
@@ -959,6 +989,8 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr {

ShiftOp(const ShiftOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
@@ -1008,6 +1040,8 @@ class TORCH_CUDA_CU_API GatherOp : public Expr {

GatherOp(const GatherOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
@@ -1054,6 +1088,8 @@ class TORCH_CUDA_CU_API ViewAsScalar : public Expr {

ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
@@ -1087,6 +1123,8 @@ class TORCH_CUDA_CU_API ViewOp : public Expr {

ViewOp(const ViewOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

TensorView* out() const {
return out_;
}
@@ -1112,6 +1150,8 @@ class TORCH_CUDA_CU_API LoadStoreOp : public Expr {

LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
@@ -1691,6 +1731,8 @@ class TORCH_CUDA_CU_API Split : public Expr {

Split(const Split* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

IterDomain* outer() const {
return outer_;
}
@@ -1751,6 +1793,8 @@ class TORCH_CUDA_CU_API Merge : public Expr {

Merge(const Merge* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

IterDomain* out() const {
return out_;
}
@@ -1783,6 +1827,8 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {

Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

IterDomain* outX() const {
return out_x_;
}
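Each node type above only declares the shallowCopy() override here; the definitions live in ir_nodes.cpp and the kernel IR files, which are not part of this excerpt. A plausible sketch of what one such override looks like, assuming the IrBuilder::create factory and the UnaryOp accessors used elsewhere in this directory:

// Hypothetical sketch; the actual definition in ir_nodes.cpp may differ.
Expr* UnaryOp::shallowCopy() const {
  // Rebuild a node with the same operands (a shallow, not recursive, copy)...
  auto result = IrBuilder::create<UnaryOp>(getUnaryOpType(), out(), in());
  // ...and carry over any predicates when inside a kir::Kernel container.
  result->copyPredicatesFrom(this);
  return result;
}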
