Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use withPredicate to replace setPredicate to maintain Exprs immutable #2025

Merged
merged 16 commits into from
Oct 5, 2022
18 changes: 18 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,15 @@ void Expr::setPredicate(kir::Predicate* predicate) {
predicate_ = predicate;
}

Expr* Expr::withPredicate(kir::Predicate* predicate) {
auto result = shallowCopy();
if (predicate != nullptr) {
predicate = predicate->maybeWithReplacedExpr(this, result);
}
result->setPredicate(predicate);
return result;
}

kir::Predicate* Expr::writePredicate() const {
TORCH_INTERNAL_ASSERT(
container()->isA<kir::Kernel>(), "Function invalid for fusion.");
Expand All @@ -353,6 +362,15 @@ void Expr::setWritePredicate(kir::Predicate* write_predicate) {
write_predicate_ = write_predicate;
}

Expr* Expr::withWritePredicate(kir::Predicate* predicate) {
auto result = shallowCopy();
if (predicate != nullptr) {
predicate = predicate->maybeWithReplacedExpr(this, result);
}
result->setWritePredicate(predicate);
return result;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
Expand Down
12 changes: 10 additions & 2 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ class TORCH_CUDA_CU_API Expr : public Statement {

Expr(const Expr* src, IrCloner* ir_cloner);

virtual Expr* shallowCopy() const = 0;
zasdfgbnm marked this conversation as resolved.
Show resolved Hide resolved

c10::optional<ExprType> getExprType() const override {
return etype_;
}
Expand Down Expand Up @@ -467,15 +469,21 @@ class TORCH_CUDA_CU_API Expr : public Statement {
kir::Predicate* predicate() const;

// TODO: Protect based on being in kernel container
void setPredicate(kir::Predicate* predicate);
Expr* withPredicate(kir::Predicate* predicate);
zasdfgbnm marked this conversation as resolved.
Show resolved Hide resolved

// TODO: Protect based on being in kernel container
kir::Predicate* writePredicate() const;

// TODO: Protect based on being in kernel container
void setWritePredicate(kir::Predicate* write_predicate);
Expr* withWritePredicate(kir::Predicate* write_predicate);

protected:
// TODO: Protect based on being in kernel container
void setPredicate(kir::Predicate* predicate);

// TODO: Protect based on being in kernel container
void setWritePredicate(kir::Predicate* write_predicate);

// TODO: Add Fusion passkey
void addInput(Val* input) {
TORCH_INTERNAL_ASSERT(input != nullptr);
Expand Down
46 changes: 46 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class TORCH_CUDA_CU_API FullOp : public Expr {

FullOp(const FullOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

bool sameAs(const Statement* other) const override;

DataType dtype() const {
Expand Down Expand Up @@ -64,6 +66,8 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr {

ARangeOp(const ARangeOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

bool sameAs(const Statement* other) const override;

DataType dtype() const {
Expand Down Expand Up @@ -127,6 +131,8 @@ class TORCH_CUDA_CU_API EyeOp : public Expr {

EyeOp(const EyeOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

bool sameAs(const Statement* other) const override;

DataType dtype() const {
Expand Down Expand Up @@ -172,6 +178,8 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr {

UnaryOp(const UnaryOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
Expand Down Expand Up @@ -201,6 +209,8 @@ class TORCH_CUDA_CU_API BinaryOp : public Expr {

BinaryOp(const BinaryOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
Expand Down Expand Up @@ -239,6 +249,8 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {

RNGOp(const RNGOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

RNGOpType getRNGOpType() const {
return rng_op_type_;
}
Expand Down Expand Up @@ -298,6 +310,8 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr {

BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
Expand Down Expand Up @@ -346,6 +360,8 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr {

ReductionOp(const ReductionOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
Expand Down Expand Up @@ -394,6 +410,8 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr {

GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

//! Number of expressions grouped horizontally. It does not reflect
//! iteration grouping.
size_t numExprs() const {
Expand Down Expand Up @@ -580,6 +598,8 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr {

WelfordOp(const WelfordOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return output().avg();
}
Expand Down Expand Up @@ -675,6 +695,8 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr {

GroupedWelfordOp(const GroupedWelfordOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

//! Number of expressions grouped horizontally. It does not reflect
//! iteration grouping. As horizontal grouping is not supported,
//! this always returns 1.
Expand Down Expand Up @@ -798,6 +820,8 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {

MmaOp(const MmaOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
Expand Down Expand Up @@ -856,6 +880,8 @@ class TORCH_CUDA_CU_API TransposeOp : public Expr {

TransposeOp(const TransposeOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

TensorView* out() const {
return out_;
}
Expand Down Expand Up @@ -886,6 +912,8 @@ class TORCH_CUDA_CU_API ExpandOp : public Expr {

ExpandOp(const ExpandOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

TensorView* out() const {
return out_;
}
Expand Down Expand Up @@ -916,6 +944,8 @@ class TORCH_CUDA_CU_API TernaryOp : public Expr {

TernaryOp(const TernaryOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
Expand Down Expand Up @@ -959,6 +989,8 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr {

ShiftOp(const ShiftOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
Expand Down Expand Up @@ -1008,6 +1040,8 @@ class TORCH_CUDA_CU_API GatherOp : public Expr {

GatherOp(const GatherOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
Expand Down Expand Up @@ -1054,6 +1088,8 @@ class TORCH_CUDA_CU_API ViewAsScalar : public Expr {

ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
Expand Down Expand Up @@ -1087,6 +1123,8 @@ class TORCH_CUDA_CU_API ViewOp : public Expr {

ViewOp(const ViewOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

TensorView* out() const {
return out_;
}
Expand All @@ -1112,6 +1150,8 @@ class TORCH_CUDA_CU_API LoadStoreOp : public Expr {

LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

Val* out() const {
return out_;
}
Expand Down Expand Up @@ -1691,6 +1731,8 @@ class TORCH_CUDA_CU_API Split : public Expr {

Split(const Split* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

IterDomain* outer() const {
return outer_;
}
Expand Down Expand Up @@ -1751,6 +1793,8 @@ class TORCH_CUDA_CU_API Merge : public Expr {

Merge(const Merge* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

IterDomain* out() const {
return out_;
}
Expand Down Expand Up @@ -1783,6 +1827,8 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {

Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner);

Expr* shallowCopy() const override;

IterDomain* outX() const {
return out_x_;
}
Expand Down
Loading