Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Transform replay refactor #45

Closed
wants to merge 8 commits into from
363 changes: 186 additions & 177 deletions test/cpp/jit/test_gpu.cpp

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions torch/csrc/jit/codegen/cuda/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ TORCH_CUDA_API Val* andOp(Val* v1, Val* v2) {
namespace {
// TODO: How do we adjust this so we can reduce to a single scalar value?
TensorView* newForReduction(TensorView* tv, std::vector<unsigned int> axes) {
auto orig_domain = tv->getRootDomain()->noReductions();
auto orig_domain = TensorDomain::noReductions(tv->getRootDomain());
std::set<unsigned int> axes_set(axes.begin(), axes.end());

std::vector<IterDomain*> new_domain;
Expand Down Expand Up @@ -281,8 +281,8 @@ Val* reductionOp(
TensorView* tv = static_cast<TensorView*>(v1);

TORCH_CHECK(
tv->getRootDomain() == tv->domain(),
"Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/reorder/computeAt.");
TensorDomain::sameAs(tv->getRootDomain(), tv->domain()->domain()),
"Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");

std::vector<unsigned int> uint_axes;
for (int axis : axes) {
Expand Down
8 changes: 0 additions & 8 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::Merge:
ptr(handler)->handle(static_cast<Merge*>(expr));
return;
case ExprType::Reorder:
ptr(handler)->handle(static_cast<Reorder*>(expr));
return;
case ExprType::UnaryOp:
ptr(handler)->handle(static_cast<UnaryOp*>(expr));
return;
Expand Down Expand Up @@ -183,9 +180,6 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::Merge:
ptr(handler)->handle(static_cast<const Merge*>(expr));
return;
case ExprType::Reorder:
ptr(handler)->handle(static_cast<const Reorder*>(expr));
return;
case ExprType::UnaryOp:
ptr(handler)->handle(static_cast<const UnaryOp*>(expr));
return;
Expand Down Expand Up @@ -276,8 +270,6 @@ Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
return ptr(mutator)->mutate(static_cast<Split*>(expr));
case ExprType::Merge:
return ptr(mutator)->mutate(static_cast<Merge*>(expr));
case ExprType::Reorder:
return ptr(mutator)->mutate(static_cast<Reorder*>(expr));
case ExprType::UnaryOp:
return ptr(mutator)->mutate(static_cast<UnaryOp*>(expr));
case ExprType::BinaryOp:
Expand Down
13 changes: 0 additions & 13 deletions torch/csrc/jit/codegen/cuda/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ struct NamedScalar;
// Exprs
struct Split;
struct Merge;
struct Reorder;
struct UnaryOp;
struct BinaryOp;
struct TernaryOp;
Expand Down Expand Up @@ -112,7 +111,6 @@ struct TORCH_CUDA_API OptOutConstDispatch {
// Exprs
virtual void handle(const Split* const) {}
virtual void handle(const Merge* const) {}
virtual void handle(const Reorder* const) {}
virtual void handle(const UnaryOp* const) {}
virtual void handle(const BinaryOp* const) {}
virtual void handle(const TernaryOp* const) {}
Expand Down Expand Up @@ -152,7 +150,6 @@ struct TORCH_CUDA_API OptOutDispatch {
// Exprs
virtual void handle(Split*) {}
virtual void handle(Merge*) {}
virtual void handle(Reorder*) {}
virtual void handle(UnaryOp*) {}
virtual void handle(BinaryOp*) {}
virtual void handle(TernaryOp*) {}
Expand Down Expand Up @@ -214,9 +211,6 @@ struct TORCH_CUDA_API OptInConstDispatch {
virtual void handle(const Merge* const) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Merge.");
}
virtual void handle(const Reorder* const) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Reorder.");
}
virtual void handle(const UnaryOp* const) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for UnaryOp.");
}
Expand Down Expand Up @@ -294,9 +288,6 @@ struct TORCH_CUDA_API OptInDispatch {
virtual void handle(Merge*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Merge.");
}
virtual void handle(Reorder*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Reorder.");
}
virtual void handle(UnaryOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for UnaryOp.");
}
Expand Down Expand Up @@ -376,7 +367,6 @@ struct TORCH_CUDA_API OptOutMutator {
// Exprs
virtual Statement* mutate(Split*);
virtual Statement* mutate(Merge*);
virtual Statement* mutate(Reorder*);
virtual Statement* mutate(UnaryOp*);
virtual Statement* mutate(BinaryOp*);
virtual Statement* mutate(TernaryOp*);
Expand Down Expand Up @@ -445,9 +435,6 @@ struct TORCH_CUDA_API OptInMutator {
virtual Statement* mutate(Merge*) {
TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Merge.");
}
virtual Statement* mutate(Reorder*) {
TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Reorder.");
}
virtual Statement* mutate(UnaryOp*) {
TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for UnaryOp.");
}
Expand Down
26 changes: 13 additions & 13 deletions torch/csrc/jit/codegen/cuda/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void InputsOf::handle(Val* v) {
inputs.emplace(v);
}

std::set<Val*> InputsOf::output(Fusion* fusion, Val* output_) {
std::unordered_set<Val*> InputsOf::output(Fusion* fusion, Val* output_) {
TORCH_CHECK(
fusion->hasOutput(output_),
"Asked for the inputs of ",
Expand Down Expand Up @@ -139,10 +139,10 @@ void Fusion::addOutput(Val* const output) {
assertInFusion(output, "Cannot register output ");
if (output->getValType().value() == ValType::TensorView) {
TensorView* tv = static_cast<TensorView* const>(output);
if (tv->getRootDomain()
->hasBroadcast()) // Go to the root as we can merge bcast and
// non-bcast dims, making a non-bcast dim.
TORCH_CHECK(
if (TensorDomain::hasBroadcast(tv->getRootDomain()))
// Go to the root as we can merge bcast and
// non-bcast dims, making a non-bcast dim.
TORCH_CHECK( // Should we warn instead?
false,
output,
" cannot be registered as an output as it has a broadcast axis.");
Expand Down Expand Up @@ -177,12 +177,12 @@ std::vector<Expr*> Fusion::exprs(bool from_outputs_only, bool breadth_first) {
return ExprSort::getExprs(this, from_outputs_only, breadth_first);
}

std::set<Val*> Fusion::inputsOf(Val* val) {
std::unordered_set<Val*> Fusion::inputsOf(Val* val) {
return InputsOf::output(this, val);
}

void Fusion::validateInputs() {
std::set<Val*> all_inputs;
std::unordered_set<Val*> all_inputs;
for (Val* out : outputs()) {
auto outs_inputs = inputsOf(out);
std::set_union(
Expand Down Expand Up @@ -249,7 +249,7 @@ StmtNameType Fusion::registerExpr(Expr* expr) {
}

for (Val* input : expr->inputs()) {
registerVal(input);
assertInFusion(input, "Input to expr is invalid, ");
if (uses_.find(input) == uses_.end()) {
uses_[input] = {expr};
} else {
Expand All @@ -258,7 +258,7 @@ StmtNameType Fusion::registerExpr(Expr* expr) {
}

for (Val* output : expr->outputs()) {
registerVal(output);
assertInFusion(output, "Output to expr is invalid, ");
auto it = origin_.find(output);
if (it != origin_.end()) {
removeExpr(it->second); // will also remove origin entry
Expand Down Expand Up @@ -293,25 +293,25 @@ bool Fusion::used(Val* val) const {
(uses_.find(val)->second.size() > 0);
}

const std::set<Val*>& Fusion::vals() const noexcept {
const std::unordered_set<Val*>& Fusion::vals() const noexcept {
return val_set_;
}

const std::deque<Val*>& Fusion::deterministic_vals() const noexcept {
return val_deque_;
}

const std::set<Expr*>& Fusion::unordered_exprs() const noexcept {
const std::unordered_set<Expr*>& Fusion::unordered_exprs() const noexcept {
return expr_set_;
}

std::set<Expr*> Fusion::uses(Val* val) const {
std::unordered_set<Expr*> Fusion::uses(Val* val) const {
assertInFusion(val, "Cannot detect where val was used, ");
if (uses_.find(val) != uses_.end()) {
auto ret = uses_.find(val)->second;
return ret;
}
return std::set<Expr*>();
return std::unordered_set<Expr*>();
}

Expr* Fusion::origin(Val* val) const {
Expand Down
20 changes: 10 additions & 10 deletions torch/csrc/jit/codegen/cuda/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>

#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace torch {
Expand Down Expand Up @@ -82,12 +82,12 @@ struct ExprSort : public IterVisitor {

struct InputsOf : public IterVisitor {
private:
std::set<Val*> inputs;
std::unordered_set<Val*> inputs;

void handle(Val* v) final;

public:
static std::set<Val*> output(Fusion* fusion, Val* output_);
static std::unordered_set<Val*> output(Fusion* fusion, Val* output_);
};

/*
Expand Down Expand Up @@ -150,7 +150,7 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput {
bool from_outputs_only = false,
bool breadth_first = false);

std::set<Val*> inputsOf(Val* val);
std::unordered_set<Val*> inputsOf(Val* val);

// Assert that all leaves found from outputs are registered as an input.
void validateInputs();
Expand Down Expand Up @@ -179,15 +179,15 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput {
bool used(Val* val) const;

// Return the set of Vals registered with this fusion
const std::set<Val*>& vals() const noexcept;
const std::unordered_set<Val*>& vals() const noexcept;
// Return in insertion order
const std::deque<Val*>& deterministic_vals() const noexcept;

// Return the set of Exprs registered with this fusion
const std::set<Expr*>& unordered_exprs() const noexcept;
const std::unordered_set<Expr*>& unordered_exprs() const noexcept;

// Return all Exprs that use val
std::set<Expr*> uses(Val* val) const;
std::unordered_set<Expr*> uses(Val* val) const;

// Return the Expr that produces val
Expr* origin(Val* val) const;
Expand All @@ -203,9 +203,9 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput {

private:
// Sets of all Vals/Exprs registered with this fusion
std::set<Val*> val_set_;
std::unordered_set<Val*> val_set_;
std::deque<Val*> val_deque_;
std::set<Expr*> expr_set_;
std::unordered_set<Expr*> expr_set_;

// Return an int that monotonically increases for each val/expr, some are
// explicitly incremented by type.
Expand All @@ -225,7 +225,7 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput {

// Dependency tracking for Vals. Where did it come from? Where is it used?
std::unordered_map<Val*, Expr*> origin_;
std::unordered_map<Val*, std::set<Expr*>> uses_;
std::unordered_map<Val*, std::unordered_set<Expr*>> uses_;
};

} // namespace fuser
Expand Down
Loading