Merged
167 changes: 0 additions & 167 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -559,171 +559,6 @@ void Val::mutatorDispatch(T mutator, Val* val) {
TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
}

template <typename T>
void Expr::mutatorDispatch(T mutator, Expr* expr) {
if (expr->isStrictlyA<FullOp>()) {
ptr(mutator)->mutate(expr->as<FullOp>());
return;
}
if (expr->isStrictlyA<ARangeOp>()) {
ptr(mutator)->mutate(expr->as<ARangeOp>());
return;
}
if (expr->isStrictlyA<EyeOp>()) {
ptr(mutator)->mutate(expr->as<EyeOp>());
return;
}
if (expr->isStrictlyA<UnaryOp>()) {
ptr(mutator)->mutate(expr->as<UnaryOp>());
return;
}
if (expr->isStrictlyA<BinaryOp>()) {
ptr(mutator)->mutate(expr->as<BinaryOp>());
return;
}
if (expr->isStrictlyA<TernaryOp>()) {
ptr(mutator)->mutate(expr->as<TernaryOp>());
return;
}
if (expr->isStrictlyA<SelectOp>()) {
ptr(mutator)->mutate(expr->as<SelectOp>());
return;
}
if (expr->isStrictlyA<RNGOp>()) {
ptr(mutator)->mutate(expr->as<RNGOp>());
return;
}
if (expr->isStrictlyA<ReductionOp>()) {
ptr(mutator)->mutate(expr->as<ReductionOp>());
return;
}
if (expr->isStrictlyA<GroupedReductionOp>()) {
ptr(mutator)->mutate(expr->as<GroupedReductionOp>());
return;
}
if (expr->isStrictlyA<WelfordOp>()) {
ptr(mutator)->mutate(expr->as<WelfordOp>());
return;
}
if (expr->isStrictlyA<GroupedWelfordOp>()) {
ptr(mutator)->mutate(expr->as<GroupedWelfordOp>());
return;
}
if (expr->isStrictlyA<LoadStoreOp>()) {
ptr(mutator)->mutate(expr->as<LoadStoreOp>());
return;
}
if (expr->isStrictlyA<MmaOp>()) {
ptr(mutator)->mutate(expr->as<MmaOp>());
return;
}
if (expr->isStrictlyA<BroadcastOp>()) {
ptr(mutator)->mutate(expr->as<BroadcastOp>());
return;
}
if (expr->isStrictlyA<SqueezeOp>()) {
ptr(mutator)->mutate(expr->as<SqueezeOp>());
return;
}
if (expr->isStrictlyA<Split>()) {
ptr(mutator)->mutate(expr->as<Split>());
return;
}
if (expr->isStrictlyA<Merge>()) {
ptr(mutator)->mutate(expr->as<Merge>());
return;
}
if (expr->isStrictlyA<Swizzle2D>()) {
ptr(mutator)->mutate(expr->as<Swizzle2D>());
return;
}
if (expr->isStrictlyA<TransposeOp>()) {
ptr(mutator)->mutate(expr->as<TransposeOp>());
return;
}
if (expr->isStrictlyA<ExpandOp>()) {
ptr(mutator)->mutate(expr->as<ExpandOp>());
return;
}
if (expr->isStrictlyA<ShiftOp>()) {
ptr(mutator)->mutate(expr->as<ShiftOp>());
return;
}
if (expr->isStrictlyA<GatherOp>()) {
ptr(mutator)->mutate(expr->as<GatherOp>());
return;
}
if (expr->isStrictlyA<ViewAsScalar>()) {
ptr(mutator)->mutate(expr->as<ViewAsScalar>());
return;
}
if (expr->isStrictlyA<ViewOp>()) {
ptr(mutator)->mutate(expr->as<ViewOp>());
return;
}
if (expr->isStrictlyA<kir::Allocate>()) {
ptr(mutator)->mutate(expr->as<kir::Allocate>());
return;
}
if (expr->isStrictlyA<kir::BlockSync>()) {
ptr(mutator)->mutate(expr->as<kir::BlockSync>());
return;
}
if (expr->isStrictlyA<kir::GridSync>()) {
ptr(mutator)->mutate(expr->as<kir::GridSync>());
return;
}
if (expr->isStrictlyA<kir::CpAsyncWait>()) {
ptr(mutator)->mutate(expr->as<kir::CpAsyncWait>());
return;
}
if (expr->isStrictlyA<kir::CpAsyncCommit>()) {
ptr(mutator)->mutate(expr->as<kir::CpAsyncCommit>());
return;
}
if (expr->isStrictlyA<kir::InitMagicZero>()) {
ptr(mutator)->mutate(expr->as<kir::InitMagicZero>());
return;
}
if (expr->isStrictlyA<kir::UpdateMagicZero>()) {
ptr(mutator)->mutate(expr->as<kir::UpdateMagicZero>());
return;
}
if (expr->isStrictlyA<kir::ForLoop>()) {
ptr(mutator)->mutate(expr->as<kir::ForLoop>());
return;
}
if (expr->isStrictlyA<kir::IfThenElse>()) {
ptr(mutator)->mutate(expr->as<kir::IfThenElse>());
return;
}
if (expr->isStrictlyA<kir::GridReduction>()) {
ptr(mutator)->mutate(expr->as<kir::GridReduction>());
return;
}
if (expr->isStrictlyA<kir::GroupedGridReduction>()) {
ptr(mutator)->mutate(expr->as<kir::GroupedGridReduction>());
return;
}
if (expr->isStrictlyA<kir::GridBroadcast>()) {
ptr(mutator)->mutate(expr->as<kir::GridBroadcast>());
return;
}
if (expr->isStrictlyA<kir::GridWelford>()) {
ptr(mutator)->mutate(expr->as<kir::GridWelford>());
return;
}
if (expr->isStrictlyA<kir::GroupedGridWelford>()) {
ptr(mutator)->mutate(expr->as<kir::GroupedGridWelford>());
return;
}
if (expr->isStrictlyA<kir::AllocateFusedReduction>()) {
ptr(mutator)->mutate(expr->as<kir::AllocateFusedReduction>());
return;
}
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
}

template <typename T>
void Statement::mutatorDispatch(T mutator, Statement* stmt) {
if (stmt->isVal()) {
@@ -774,8 +609,6 @@ template void Statement::mutatorDispatch(OptOutMutator&, Statement*);
template void Statement::mutatorDispatch(OptOutMutator*, Statement*);
template void Val::mutatorDispatch(OptOutMutator&, Val*);
template void Val::mutatorDispatch(OptOutMutator*, Val*);
template void Expr::mutatorDispatch(OptOutMutator&, Expr*);
template void Expr::mutatorDispatch(OptOutMutator*, Expr*);

void OptOutDispatch::handle(Statement* s) {
Statement::dispatch(this, s);
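Why can the whole per-type chain above be deleted? Once newObject is reachable through a static function pointer (the newObjectFunc() interface added in ir_base_nodes.h below), one generic mutate can rebuild any expression type. The following is only a sketch of that idea, not code from this PR; maybeMutated() is a stand-in for the mutator's registered-replacement lookup and is assumed here, not shown in this diff.

// Sketch only: a single generic mutate replacing the per-type if-chain.
void OptOutMutator::mutate(Expr* op) {
  // Substitute any registered replacements into the inputs and outputs.
  std::vector<Val*> inputs;
  for (auto* in : op->inputs()) {
    inputs.push_back(maybeMutated(in));  // assumed replacement lookup
  }
  std::vector<Val*> outputs;
  for (auto* out : op->outputs()) {
    outputs.push_back(maybeMutated(out));
  }
  // (A real pass would return early here if nothing changed.)

  // Capture everything needed *before* removing the node: once
  // removeExpr(op) runs, op's vtable is gone, so any virtual call on it,
  // including op->newObject(...), would be a use-after-free.
  auto* create_fn = op->newObjectFunc();
  std::vector<Statement*> attributes = op->attributes();
  auto* container = op->container();
  container->removeExpr(op);

  // Plain function-pointer call; no vtable access on the removed node.
  create_fn(container, inputs, outputs, attributes);
}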
44 changes: 0 additions & 44 deletions torch/csrc/jit/codegen/cuda/dispatch.h
@@ -326,50 +326,6 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
virtual void mutate(kir::Predicate*);
virtual void mutate(kir::TensorIndex*);

// Exprs
virtual void mutate(FullOp*);
virtual void mutate(ARangeOp*);
virtual void mutate(EyeOp*);
virtual void mutate(UnaryOp*);
virtual void mutate(BinaryOp*);
virtual void mutate(TernaryOp*);
virtual void mutate(SelectOp*);
virtual void mutate(RNGOp*);
virtual void mutate(ReductionOp*);
virtual void mutate(GroupedReductionOp*);
virtual void mutate(WelfordOp*);
virtual void mutate(GroupedWelfordOp*);
virtual void mutate(LoadStoreOp*);
virtual void mutate(MmaOp*);
virtual void mutate(BroadcastOp*);
virtual void mutate(SqueezeOp*);

virtual void mutate(Split*);
virtual void mutate(Merge*);
virtual void mutate(Swizzle2D*);
virtual void mutate(TransposeOp*);
virtual void mutate(ExpandOp*);
virtual void mutate(ShiftOp*);
virtual void mutate(GatherOp*);
virtual void mutate(ViewAsScalar*);
virtual void mutate(ViewOp*);

virtual void mutate(kir::Allocate*);
virtual void mutate(kir::BlockSync*);
virtual void mutate(kir::GridSync*);
virtual void mutate(kir::CpAsyncWait*);
virtual void mutate(kir::CpAsyncCommit*);
virtual void mutate(kir::InitMagicZero*);
virtual void mutate(kir::UpdateMagicZero*);
virtual void mutate(kir::ForLoop*);
virtual void mutate(kir::IfThenElse*);
virtual void mutate(kir::GridReduction*);
virtual void mutate(kir::GroupedGridReduction*);
virtual void mutate(kir::GridBroadcast*);
virtual void mutate(kir::GridWelford*);
virtual void mutate(kir::GroupedGridWelford*);
virtual void mutate(kir::AllocateFusedReduction*);

protected:
void removeExpr(IrContainer*, Expr*);
};
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -337,7 +337,8 @@ Expr::Expr(
attributes_(std::move(attributes)) {}

Expr* Expr::shallowCopy() const {
- auto result = newObject(inputs(), outputs(), attributes());
+ auto result =
+     newObjectFunc()(ir_container_, inputs(), outputs(), attributes());
if (container()->isA<kir::Kernel>()) {
result->predicate_ = predicate_;
result->write_predicate_ = write_predicate_;
48 changes: 26 additions & 22 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -429,6 +429,12 @@ class TORCH_CUDA_CU_API Attribute : public Val {
}
};

+ using newObjectFuncType = Expr*(
+     IrContainer*,
+     std::vector<Val*>,
+     std::vector<Val*>,
+     std::vector<Statement*>);

//! An Expr represents a "computation." These are functions that take inputs
//! and produce outputs, inputs and outputs all being Vals. There are
//! specializations of BinaryOp which takes 2 inputs and produces 1 output, and
@@ -480,19 +486,14 @@ class TORCH_CUDA_CU_API Expr : public Statement {
std::vector<Val*> outputs,
std::vector<Statement*> attributes);

+ virtual newObjectFuncType* newObjectFunc() const = 0;
Collaborator:

Why does this "newObject" interface now return a function pointer? What's wrong with the previous version?

@zasdfgbnm (Collaborator, Author) Nov 19, 2022:

Because of the container->removeExpr(op); in the mutate method. When an IR node is removed from a container, its vtable no longer exists, and op->newObject would cause a segfault.

Collaborator:

Can we just do op->newObject() before removing it?

@zasdfgbnm (Collaborator, Author) Nov 19, 2022:

We also can not swap the order of op->newObject and container->removeExpr, because that causes a consistency check failure (an IterDomain used in multiple exprs).

@zasdfgbnm (Collaborator, Author):

C++ does not allow me to do the following either :(

auto fn = op->newObject;

Collaborator:

OK, makes sense.
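To see the lifetime issue in isolation, here is a self-contained toy (all names invented; this is not nvfuser code). The virtual call happens while the object is alive, but the returned pointer to a static function stays callable after the object is destroyed, which is also the workaround for auto fn = op->newObject; being ill-formed for a non-static member function.

#include <iostream>

// Toy illustration of the pattern discussed above.
struct Node;
using NewObjectFn = Node* (*)(int payload);

struct Node {
  int payload = 0;
  virtual ~Node() = default;
  // Must be called while the object is still alive (it is a virtual call),
  // but the *returned* pointer stays valid after the object is destroyed.
  virtual NewObjectFn newObjectFunc() const = 0;
};

struct AddNode : Node {
  static Node* newObject(int payload) {
    auto* n = new AddNode();
    n->payload = payload;
    return n;
  }
  NewObjectFn newObjectFunc() const override {
    // A static member function decays to a plain function pointer, which
    // a non-static member function like op->newObject cannot do.
    return newObject;
  }
};

int main() {
  Node* op = new AddNode();
  NewObjectFn create = op->newObjectFunc();  // grab pointer while alive
  delete op;  // analogous to container->removeExpr(op): vtable is gone
  Node* replacement = create(42);  // safe: no virtual dispatch on dead object
  std::cout << replacement->payload << "\n";  // prints 42
  delete replacement;
  return 0;
}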


// Creates a new instance of the expression with all its fields copied.
// Note that unlike IrCloner, this function only does a shallow copy
Expr* shallowCopy() const;

bool sameAs(const Statement* other) const override;

- // Creates a new instance of the same expression type with the given inputs,
- // outputs, and attributes.
- virtual Expr* newObject(
-     std::vector<Val*> inputs,
-     std::vector<Val*> outputs,
-     std::vector<Statement*> attributes) const = 0;

// Input/output accessors
const auto& inputs() const {
return inputs_;
@@ -529,9 +530,6 @@ class TORCH_CUDA_CU_API Expr : public Statement {
template <typename T>
static void constDispatch(T handler, const Expr* const);

- template <typename T>
- static void mutatorDispatch(T mutator, Expr*);

// TODO: Protect based on being in kernel container
kir::Predicate* predicate() const;

@@ -599,20 +597,26 @@ bool Val::isDefinitionType() const {

#define NVFUSER_DECLARE_CLONE_AND_CREATE \
virtual Statement* clone(IrCloner* ir_cloner) const override; \
- virtual Expr* newObject( \
+ static Expr* newObject( \
+     IrContainer* container, \
      std::vector<Val*> inputs, \
      std::vector<Val*> outputs, \
-     std::vector<Statement*> attributes) const override;
+     std::vector<Statement*> attributes); \
+ virtual newObjectFuncType* newObjectFunc() const override { \
+   return newObject; \
+ }

-#define NVFUSER_DEFINE_CLONE_AND_CREATE(ClassName) \
- Statement* ClassName::clone(IrCloner* ir_cloner) const { \
-   return IrBuilder::clone(this, ir_cloner); \
- } \
- Expr* ClassName::newObject( \
-     std::vector<Val*> inputs, \
-     std::vector<Val*> outputs, \
-     std::vector<Statement*> attributes) const { \
-   return IrBuilder::create<ClassName>(inputs, outputs, attributes); \
- }
+#define NVFUSER_DEFINE_CLONE_AND_CREATE(ClassName) \
+ Statement* ClassName::clone(IrCloner* ir_cloner) const { \
+   return IrBuilder::clone(this, ir_cloner); \
+ } \
+ Expr* ClassName::newObject( \
+     IrContainer* container, \
+     std::vector<Val*> inputs, \
+     std::vector<Val*> outputs, \
+     std::vector<Statement*> attributes) { \
+   return IrBuilder::create<ClassName>( \
+       container, inputs, outputs, attributes); \
+ }

} // namespace cuda
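To make the macro change concrete, a hypothetical illustration of how a concrete expression type consumes the two macros after this change (SomeOp is an invented name, not a class from this diff):

// some_op.h
class SomeOp : public Expr {
 public:
  using Expr::Expr;  // real ops define their own constructors; elided here
  NVFUSER_DECLARE_CLONE_AND_CREATE
};

// some_op.cpp
NVFUSER_DEFINE_CLONE_AND_CREATE(SomeOp)

With this in place, Expr::shallowCopy() can rebuild any node via newObjectFunc()(ir_container_, inputs(), outputs(), attributes()), and callers that must remove the node first can hold the function pointer across the removal.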