Reduce the work to add a new expr: unify the structure of exprs #2190

Merged, Nov 18, 2022 (52 commits)

Commits (52)
c0a5864  save structured expr (zasdfgbnm, Nov 15, 2022)
ee0f43a  save ExprType change (zasdfgbnm, Nov 15, 2022)
b7aedbd  revert (zasdfgbnm, Nov 15, 2022)
6149751  Merge branch 'devel' of github.com:csarofeen/pytorch into no-expr-type (zasdfgbnm, Nov 15, 2022)
77be1f4  more (zasdfgbnm, Nov 15, 2022)
cd8817c  one of (zasdfgbnm, Nov 15, 2022)
0ae70ab  std::type_index (zasdfgbnm, Nov 15, 2022)
1f9cb3b  more (zasdfgbnm, Nov 15, 2022)
cdee6d0  more (zasdfgbnm, Nov 15, 2022)
8304b84  more (zasdfgbnm, Nov 15, 2022)
1e66082  fix (zasdfgbnm, Nov 15, 2022)
31e61e3  fix reduction (zasdfgbnm, Nov 15, 2022)
7d5b4f7  isStrictlyOneOf (zasdfgbnm, Nov 15, 2022)
895b6f0  fix (zasdfgbnm, Nov 15, 2022)
79017b8  fix (zasdfgbnm, Nov 15, 2022)
20b5394  fix (zasdfgbnm, Nov 15, 2022)
de37d06  Merge branch 'devel' of github.com:csarofeen/pytorch into structured-… (zasdfgbnm, Nov 15, 2022)
e4e235a  Revert "save structured expr" (zasdfgbnm, Nov 15, 2022)
f8fccd0  Merge branch 'no-expr-type' of github.com:csarofeen/pytorch into stru… (zasdfgbnm, Nov 15, 2022)
31fc54b  FullOp compiles (zasdfgbnm, Nov 16, 2022)
0da63b0  ARangeOp (zasdfgbnm, Nov 16, 2022)
9e57950  TernaryOpType (zasdfgbnm, Nov 16, 2022)
3fb5d26  RNGOp (zasdfgbnm, Nov 16, 2022)
73db1ea  BroadcastOp SqueezeOp (zasdfgbnm, Nov 16, 2022)
1b2bc68  ReductionOp (zasdfgbnm, Nov 16, 2022)
f0e4e75  grouped reduction, welford, grouped welford (zasdfgbnm, Nov 16, 2022)
d7eb2de  runtime time (zasdfgbnm, Nov 16, 2022)
4afea44  mma transpose expand (zasdfgbnm, Nov 16, 2022)
00cc231  shift (zasdfgbnm, Nov 16, 2022)
77f3edf  fixes (zasdfgbnm, Nov 16, 2022)
2d14b4d  fix (zasdfgbnm, Nov 16, 2022)
586b29c  add clone (zasdfgbnm, Nov 16, 2022)
56f2249  Merge branch 'devel' of github.com:csarofeen/pytorch into structured-… (zasdfgbnm, Nov 16, 2022)
2baafe3  fix (zasdfgbnm, Nov 16, 2022)
64bd0d8  gather, view as scalar, view, load store (zasdfgbnm, Nov 16, 2022)
242228e  DEFINE_CLONE (zasdfgbnm, Nov 16, 2022)
393fa52  Allocate (zasdfgbnm, Nov 17, 2022)
e3f27a2  ForLoop (zasdfgbnm, Nov 17, 2022)
f72d6b6  fix (zasdfgbnm, Nov 17, 2022)
4338807  GroupedGridReduction (zasdfgbnm, Nov 17, 2022)
412420b  GridWelford (zasdfgbnm, Nov 17, 2022)
c7b860e  all expr done (zasdfgbnm, Nov 17, 2022)
bfdb742  fix (zasdfgbnm, Nov 17, 2022)
36ada0e  cleanup (zasdfgbnm, Nov 17, 2022)
56b03ee  rewrite IrCloner (zasdfgbnm, Nov 17, 2022)
c80587c  rename (zasdfgbnm, Nov 18, 2022)
522bfdf  Merge branch 'devel' of github.com:csarofeen/pytorch into structured-… (zasdfgbnm, Nov 18, 2022)
9bd3bc4  save (zasdfgbnm, Nov 18, 2022)
0987844  fix (zasdfgbnm, Nov 18, 2022)
c23ea6c  save (zasdfgbnm, Nov 18, 2022)
0971dfe  save (zasdfgbnm, Nov 18, 2022)
6fd6504  Statement attributes_ (zasdfgbnm, Nov 18, 2022)
10 changes: 7 additions & 3 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -2003,10 +2003,14 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     ArgumentBuilder read_preds;
     ArgumentBuilder write_preds;

+    auto output_vals = grouped_gwop->outputVals();
+    auto input_vals = grouped_gwop->inputVals();
+    auto init_vals = grouped_gwop->initVals();
+
     for (const auto expr_index : c10::irange(grouped_gwop->numExprs())) {
-      const auto& output = grouped_gwop->outputVals().at(expr_index);
-      const auto& input = grouped_gwop->inputVals().at(expr_index);
-      const auto& init = grouped_gwop->initVals().at(expr_index);
+      const auto& output = output_vals.at(expr_index);
+      const auto& input = input_vals.at(expr_index);
+      const auto& init = init_vals.at(expr_index);

       for (const auto& group_index :
            c10::irange(index_replacement_maps.size())) {
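A note on the hunk above: it hoists the outputVals()/inputVals()/initVals() calls out of the loop. With the unified expr structure these accessors presumably assemble their result vectors on the fly and return them by value, so the old code bound const references into temporaries that die at the end of each statement. A minimal self-contained C++ sketch of that hazard (vals() is hypothetical, not part of this PR):

    #include <vector>

    std::vector<int> vals() {  // returns a fresh vector by value
      return {1, 2, 3};
    }

    int main() {
      // Dangling: the temporary returned by vals() is destroyed at the end
      // of the full expression, so `bad` would refer to freed storage.
      // const auto& bad = vals().at(0);

      // Safe (the pattern used in the hunk): name the vector once, outside
      // the loop, so references point into a living object.
      auto cached = vals();
      const auto& good = cached.at(0);
      return good;
    }

Hoisting also avoids rebuilding the three vectors on every iteration of the expr loop.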
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -313,6 +313,10 @@ void Val::constDispatch(T handler, const Val* val) {
     case ValType::TensorIndex:
       ptr(handler)->handle(val->as<kir::TensorIndex>());
       return;
+    case ValType::Attribute:
+      // Attribute Val is just a wrapper for non-IR data, so there is nothing
+      // to handle
+      return;
     default:
       break;
   }
41 changes: 32 additions & 9 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -34,6 +34,8 @@ Statement::Statement(const Statement* src, IrCloner* ir_cloner) {
   ir_container_ = ir_cloner->container();
 }

+NVFUSER_DEFINE_CLONE(Statement)
+
 void Statement::setName(IrContainerPasskey, StmtNameType name) {
   name_ = name;
 }
@@ -98,6 +100,8 @@ Val::Val(IrBuilderPasskey passkey, ValType _vtype, DataType _dtype)
 Val::Val(const Val* src, IrCloner* ir_cloner)
     : Statement(src, ir_cloner), vtype_(src->vtype_), dtype_(src->dtype_) {}

+NVFUSER_DEFINE_CLONE(Val)
+
 const std::vector<Expr*>& Val::uses() const {
   if (vtype_ == ValType::TensorView) {
     if (!fusion()->isTVUseInfoValid() && !fusion()->isUpdatingTVUseInfo()) {
@@ -319,7 +323,27 @@ Expr::Expr(IrBuilderPasskey passkey) : Statement(passkey) {}

 Expr::Expr(const Expr* src, IrCloner* ir_cloner)
     : Statement(src, ir_cloner),
       inputs_(ir_cloner->clone(src->inputs_)),
-      outputs_(ir_cloner->clone(src->outputs_)) {}
+      outputs_(ir_cloner->clone(src->outputs_)),
+      attributes_(ir_cloner->clone(src->attributes_)) {}
+
+Expr::Expr(
+    IrBuilderPasskey passkey,
+    std::vector<Val*> inputs,
+    std::vector<Val*> outputs,
+    std::vector<Statement*> attributes)
+    : Statement(passkey),
+      inputs_(std::move(inputs)),
+      outputs_(std::move(outputs)),
+      attributes_(std::move(attributes)) {}
+
+Expr* Expr::shallowCopy() const {
+  auto result = newObject(inputs(), outputs(), attributes());
+  if (container()->isA<kir::Kernel>()) {
+    result->predicate_ = predicate_;
+    result->write_predicate_ = write_predicate_;
+  }
+  return result;
+}

 bool Expr::sameAs(const Statement* other) const {
   if (this == other) {
@@ -333,14 +357,20 @@ bool Expr::sameAs(const Statement* other) const {
     return false;
   }
   if (inputs().size() != other_expr->inputs().size() ||
-      outputs().size() != other_expr->outputs().size()) {
+      outputs().size() != other_expr->outputs().size() ||
+      attributes().size() != other_expr->attributes().size()) {
     return false;
   }
   for (const auto i : c10::irange(inputs().size())) {
     if (!input(i)->sameAs(other_expr->input(i))) {
       return false;
     }
   }
+  for (const auto i : c10::irange(attributes().size())) {
+    if (!attribute(i)->sameAs(other_expr->attribute(i))) {
+      return false;
+    }
+  }
   return true;
 }

@@ -380,13 +410,6 @@ Expr* Expr::withWritePredicate(kir::Predicate* predicate) {
   return result;
 }

-void Expr::copyPredicatesFrom(const Expr* expr) {
-  if (container()->isA<kir::Kernel>()) {
-    predicate_ = expr->predicate_;
-    write_predicate_ = expr->write_predicate_;
-  }
-}
-
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
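A design note on this file: shallowCopy() is no longer virtual (see the review comment in ir_base_nodes.h below). It delegates the type-specific work to the virtual newObject() factory, so kernel-IR predicate copying lives in exactly one place and each expr subclass only supplies construction. A hedged sketch of the kind of generic rewrite this enables; replaceInput() is hypothetical, not part of this PR:

    // Copy an expr of any concrete type, swapping one input for another.
    // newObject() recreates the same subclass, so the caller never needs
    // to know which op it is holding.
    Expr* replaceInput(Expr* e, Val* old_in, Val* new_in) {
      std::vector<Val*> inputs = e->inputs();  // copy of the input list
      for (auto& in : inputs) {
        if (in == old_in) {
          in = new_in;  // retarget just this operand
        }
      }
      return e->newObject(inputs, e->outputs(), e->attributes());
    }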
95 changes: 90 additions & 5 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -5,6 +5,7 @@
 #include <c10/util/Exception.h>
 #include <c10/util/Optional.h>

+#include <torch/csrc/jit/codegen/cuda/ir_builder_passkey.h>
 #include <torch/csrc/jit/codegen/cuda/type.h>
 #include <torch/csrc/jit/codegen/cuda/utils.h>
@@ -70,6 +71,14 @@ class ExprPasskey {

 TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept;

+#define NVFUSER_DECLARE_CLONE \
+  virtual Statement* clone(IrCloner* ir_cloner) const override;
+
+#define NVFUSER_DEFINE_CLONE(ClassName)                    \
+  Statement* ClassName::clone(IrCloner* ir_cloner) const { \
+    return IrBuilder::clone(this, ir_cloner);              \
+  }
+
 //! Statement is the highest level node representation. Everything that is
 //! considered "IR" will be derived from this class at some point. Both Values
 //! and Expr's are a Statement. If there will ever be any more fundamental
@@ -159,6 +168,8 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase {
   std::string toString() const;
   std::string toInlineString() const;

+  virtual Statement* clone(IrCloner* ir_cloner) const;
+
  protected:
   Statement(IrBuilderPasskey);

@@ -353,6 +364,8 @@ class TORCH_CUDA_CU_API Val : public Statement {

   void resolveIndexDtype();

+  NVFUSER_DECLARE_CLONE
+
  protected:
   friend Fusion;

@@ -391,6 +404,31 @@ class TORCH_CUDA_CU_API Val : public Statement {
   int evaluator_index_ = -1;
 };

+//! A Val object that stores a plain data. Note that this class is only intended
+//! to hold non-IR data, such as DataType, std::vector<int>, etc. Please don't
+//! use this class to hold IR nodes or their pointers.
+template <typename T>
+class TORCH_CUDA_CU_API Attribute : public Val {
+ public:
+  T value;
+  Attribute(IrBuilderPasskey passkey, const T& value)
+      : Val(passkey, ValType::Attribute), value(value) {}
+  Attribute(const Attribute* src, IrCloner* ir_cloner)
+      : Val(src, ir_cloner), value(src->value) {}
+  template <typename... Args>
+  Attribute(IrBuilderPasskey passkey, Args... args)
+      : Val(passkey, ValType::Attribute), value(std::forward<Args>(args)...) {}
+
+  NVFUSER_DECLARE_CLONE
+
+  bool sameAs(const Statement* other) const override {
+    if (auto pv = dynamic_cast<const Attribute*>(other)) {
+      return pv->value == value;
+    }
+    return false;
+  }
+};
+
 //! A Expr represents a "computation." These are functions that takes inputs
 //! and produce outputs, inputs and outputs all being Vals. There are
 //! specializations of BinaryOp which takes 2 inputs and produces 1 output, and
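The Attribute template above is what lets every op keep its non-Val state (enums, flags, shape data) inside the uniform attributes_ list. A hypothetical sketch of how a concrete op might store and read one flag; SomeFlagOp is illustrative only, and the IrBuilder::create call assumes it forwards constructor arguments the way the NVFUSER_DEFINE_CLONE_AND_CREATE macro at the end of this header suggests:

    class SomeFlagOp : public Expr {
     public:
      SomeFlagOp(IrBuilderPasskey passkey, Val* out, Val* in, bool flag)
          : Expr(passkey, {in}, {out}, {}) {
        // Wrap the plain bool in an Attribute Val so it is cloned,
        // compared by sameAs(), and shallow-copied with no per-op code.
        addAttribute(IrBuilder::create<Attribute<bool>>(flag));
      }

      bool flag() const {
        // attribute(0) is a Statement*; downcast to the concrete wrapper.
        return attribute(0)->as<Attribute<bool>>()->value;
      }

      NVFUSER_DECLARE_CLONE_AND_CREATE
    };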
@@ -436,12 +474,25 @@ class TORCH_CUDA_CU_API Expr : public Statement {

   Expr(const Expr* src, IrCloner* ir_cloner);

+  Expr(
+      IrBuilderPasskey,
+      std::vector<Val*> inputs,
+      std::vector<Val*> outputs,
+      std::vector<Statement*> attributes);
+
   // Creates a new instance of the expression with all its field copied.
   // Note that unlike IrCloner, this function only do a shallow copy
-  virtual Expr* shallowCopy() const = 0;
+  Expr* shallowCopy() const;

   [inline review comment from zasdfgbnm (author): This method is now not even virtual.]

   bool sameAs(const Statement* other) const override;

+  // Creates a new instance of the same expression type with the given inputs,
+  // outputs, and attributes.
+  virtual Expr* newObject(
+      std::vector<Val*> inputs,
+      std::vector<Val*> outputs,
+      std::vector<Statement*> attributes) const = 0;
+
   // Input/output accessors
   const auto& inputs() const {
     return inputs_;
@@ -451,12 +502,24 @@
     return outputs_;
   }

+  const auto& attributes() const {
+    return attributes_;
+  }
+
   auto input(size_t index) const {
-    return inputs_[index];
+    return inputs_.at(index);
   }

   auto output(size_t index) const {
-    return outputs_[index];
+    return outputs_.at(index);
   }

+  auto attribute(size_t index) const {
+    return attributes_.at(index);
+  }
+
+  auto attributeVal(size_t index) const {
+    return dynamic_cast<Val*>(attributes_.at(index));
+  }
+
   // Dispatch functions, definitions in dispatch.cpp
@@ -494,8 +557,6 @@ class TORCH_CUDA_CU_API Expr : public Statement {
   // TODO: Protect based on being in kernel container
   void setWritePredicate(kir::Predicate* write_predicate);

-  void copyPredicatesFrom(const Expr* expr);
-
   // TODO: Add Fusion passkey
   void addInput(Val* input) {
     TORCH_INTERNAL_ASSERT(input != nullptr);
@@ -508,13 +569,19 @@ class TORCH_CUDA_CU_API Expr : public Statement {
     outputs_.push_back(output);
   }

+  // TODO: Add Fusion passkey
+  void addAttribute(Statement* attr) {
+    attributes_.push_back(attr);
+  }
+
   ExprPasskey exprPasskey() {
     return ExprPasskey();
   }

  private:
   std::vector<Val*> inputs_;
   std::vector<Val*> outputs_;
+  std::vector<Statement*> attributes_;

   kir::Predicate* predicate_ = nullptr;
@@ -530,6 +597,24 @@ bool Val::isDefinitionType() const {
   return false;
 }

+#define NVFUSER_DECLARE_CLONE_AND_CREATE                         \
+  virtual Statement* clone(IrCloner* ir_cloner) const override;  \
+  virtual Expr* newObject(                                       \
+      std::vector<Val*> inputs,                                  \
+      std::vector<Val*> outputs,                                 \
+      std::vector<Statement*> attributes) const override;
+
+#define NVFUSER_DEFINE_CLONE_AND_CREATE(ClassName)                   \
+  Statement* ClassName::clone(IrCloner* ir_cloner) const {           \
+    return IrBuilder::clone(this, ir_cloner);                        \
+  }                                                                  \
+  Expr* ClassName::newObject(                                        \
+      std::vector<Val*> inputs,                                      \
+      std::vector<Val*> outputs,                                     \
+      std::vector<Statement*> attributes) const {                    \
+    return IrBuilder::create<ClassName>(inputs, outputs, attributes); \
+  }
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
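These two macros are the payoff behind the PR title: adding a new expr now means declaring the pair inside the class and expanding one line in the .cpp, and clone() plus newObject() are generated. A sketch with a hypothetical NewOp:

    // new_op.h
    class NewOp : public Expr {
     public:
      using Expr::Expr;  // inherit the (passkey, inputs, outputs, attributes) ctor
      NVFUSER_DECLARE_CLONE_AND_CREATE
    };

    // new_op.cpp
    NVFUSER_DEFINE_CLONE_AND_CREATE(NewOp)

Note that the generated newObject() calls IrBuilder::create<ClassName>(inputs, outputs, attributes), so every subclass must keep a constructor those three argument lists can be forwarded to.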
69 changes: 0 additions & 69 deletions torch/csrc/jit/codegen/cuda/ir_builder.cpp
@@ -11,75 +11,6 @@ namespace jit {
 namespace fuser {
 namespace cuda {

-//! Clone an IR node, forwarding the arguments to the IrCloner constructor.

   [inline review comment from zasdfgbnm (author): moved to header file]

-template <class T>
-T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) {
-  TORCH_INTERNAL_ASSERT(
-      ir_cloner != nullptr,
-      "Cannot use create when a cloner object is set. Use clone.");
-
-  TORCH_INTERNAL_ASSERT(
-      ir_cloner->container() != nullptr,
-      "Cloner doesn't have a valid container to store cloned object.");
-
-  T* dest = new T(src, ir_cloner);
-  const Statement* src_stmt = dynamic_cast<const Statement*>(src);
-  Statement* dest_stmt = dynamic_cast<Statement*>(dest);
-
-  auto dest_container = ir_cloner->container();
-  auto src_container = src_stmt->container();
-
-  dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt);
-
-  if (src_container != dest_container) {
-    dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name());
-  }
-
-  ir_cloner->registerClone(src_stmt, dest_stmt);
-
-  return dest;
-}
-
-#define IR_BUILDER_INSTANTIATE(T) \
-  template T* IrBuilder::clone(const T* src, IrCloner* ir_cloner);
-
-// Vals
-IR_BUILDER_INSTANTIATE(IterDomain)
-IR_BUILDER_INSTANTIATE(TensorDomain)
-IR_BUILDER_INSTANTIATE(TensorView)
-IR_BUILDER_INSTANTIATE(Bool)
-IR_BUILDER_INSTANTIATE(Float)
-IR_BUILDER_INSTANTIATE(Double)
-IR_BUILDER_INSTANTIATE(Int)
-IR_BUILDER_INSTANTIATE(ComplexDouble)
-IR_BUILDER_INSTANTIATE(NamedScalar)
-
-// Exprs
-IR_BUILDER_INSTANTIATE(Split)
-IR_BUILDER_INSTANTIATE(Merge)
-IR_BUILDER_INSTANTIATE(Swizzle2D)
-IR_BUILDER_INSTANTIATE(TransposeOp)
-IR_BUILDER_INSTANTIATE(ExpandOp)
-IR_BUILDER_INSTANTIATE(ShiftOp)
-IR_BUILDER_INSTANTIATE(GatherOp)
-IR_BUILDER_INSTANTIATE(ViewAsScalar)
-IR_BUILDER_INSTANTIATE(ViewOp)
-IR_BUILDER_INSTANTIATE(FullOp)
-IR_BUILDER_INSTANTIATE(ARangeOp)
-IR_BUILDER_INSTANTIATE(EyeOp)
-IR_BUILDER_INSTANTIATE(UnaryOp)
-IR_BUILDER_INSTANTIATE(BinaryOp)
-IR_BUILDER_INSTANTIATE(TernaryOp)
-IR_BUILDER_INSTANTIATE(SelectOp)
-IR_BUILDER_INSTANTIATE(RNGOp)
-IR_BUILDER_INSTANTIATE(ReductionOp)
-IR_BUILDER_INSTANTIATE(GroupedReductionOp)
-IR_BUILDER_INSTANTIATE(WelfordOp)
-IR_BUILDER_INSTANTIATE(LoadStoreOp)
-IR_BUILDER_INSTANTIATE(MmaOp)
-IR_BUILDER_INSTANTIATE(BroadcastOp)
-IR_BUILDER_INSTANTIATE(SqueezeOp)
-
 Val* IrBuilder::newResult(DataType dtype) {
   switch (dtype) {
     case DataType::Bool:
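This wholesale deletion is safe because the IrBuilder::clone template moved into the header (per the review comment above). Each NVFUSER_DEFINE_CLONE or NVFUSER_DEFINE_CLONE_AND_CREATE expansion now instantiates the template implicitly in its own translation unit, so the hand-maintained IR_BUILDER_INSTANTIATE list, which existed only to provide explicit instantiations for the out-of-line template, no longer serves a purpose.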
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_builder.h
@@ -111,6 +111,9 @@ class TORCH_CUDA_CU_API SimplifyingIrBuilder : public IrBuilder {
   static Val* minExpr(Val* lhs, Val* rhs);
 };

+template <DataType DT>
+NVFUSER_DEFINE_CLONE(FloatingPoint<DT>)
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit