Initialization and predicate generation #288

Merged: 4 commits, Aug 16, 2020
Changes from 1 commit
15 changes: 15 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel_ir.cpp
@@ -1,5 +1,6 @@

#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/type.h>

// TODO(kir): remove
@@ -362,6 +363,13 @@ ForLoop::ForLoop(const ForLoop* src, IrCloner* ir_cloner)
body_(&src->body_, ir_cloner),
parent_scope_(ir_cloner->clone(src->parent_scope_)) {}

void ForLoop::setParentScope(Expr* scope) {
TORCH_INTERNAL_ASSERT(
!scope_utils::exprInScope(parentScope(), this),
"Cannot change parent scope if not already removed from previous parent.");
parent_scope_ = scope;
}

IfThenElse::IfThenElse(
Bool* cond,
const std::vector<Expr*>& if_body,
@@ -384,6 +392,13 @@ IfThenElse::IfThenElse(const IfThenElse* src, IrCloner* ir_cloner)
else_body_(&src->else_body_, ir_cloner),
parent_scope_(ir_cloner->clone(src->parent_scope_)) {}

void IfThenElse::setParentScope(Expr* scope) {
TORCH_INTERNAL_ASSERT(
!scope_utils::exprInScope(parentScope(), this),
"Cannot change parent scope if not already removed from previous parent.");
parent_scope_ = scope;
}

Val* TensorIndex::index(int i) const {
TORCH_INTERNAL_ASSERT(
nDims() > 0, "Tried to get an index of a 0-dim TensorIndex");
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel_ir.h
@@ -622,6 +622,8 @@ class TORCH_CUDA_API ForLoop : public Expr {
return parent_scope_;
}

void setParentScope(Expr* scope);

private:
Val* const index_ = nullptr;
IterDomain* const iter_domain_;
@@ -671,6 +673,8 @@ class TORCH_CUDA_API IfThenElse : public Expr {
return parent_scope_;
}

void setParentScope(Expr* scope);

private:
Bool* const cond_ = nullptr;
Scope body_;
23 changes: 14 additions & 9 deletions torch/csrc/jit/codegen/cuda/lower_unroll.cpp
@@ -37,6 +37,7 @@ void UnrollPass::handle(Expr* expr) {

// If we need a predicate, put expr inside an if then else
if (!(pred->isConst()) || !(pred->isConst() && pred->value().value())) {
non_trivial_pred_found = true;
kir::IfThenElse* inline_ite =
new kir::IfThenElse(pred, {expr}, {}, for_loops.back());
for_loops.back()->body().insert_before(expr, inline_ite);
@@ -71,8 +72,12 @@ void UnrollPass::handle(kir::ForLoop* fl) {

auto unroll_pred = UnrollPredicate::get(for_loops, fl, p2c_root_map);

TORCH_INTERNAL_ASSERT(
!(unroll_pred->isConst() && unroll_pred->value().value()));
kir::ForLoop* parent_scope = for_loops.empty() ? nullptr : for_loops.back();

kir::IfThenElse* unroll_ite =
new kir::IfThenElse(unroll_pred, {}, {}, for_loops.back());
new kir::IfThenElse(unroll_pred, {}, {}, parent_scope);

// Get the loop nest for the unrolled path
kir::ForLoop* unrolled_loop_nest = scope_utils::cloneLoopNest(fl, unroll_ite);
@@ -84,16 +89,16 @@

// Add inline predicates for inlined loop nest
look_for_unroll = false;
non_trivial_pred_found = false;
handle(inlined_loop);
look_for_unroll = true;

unroll_ite->elseBody().push_back(inlined_loop);

// Inner most inlined loop
Expr* inner_most_inlined_loop =
scope_utils::firstInnerMostScope(inlined_loop);

loop_replacement_map.insert({fl, unroll_ite});
if (!non_trivial_pred_found) {
inlined_loop->setParentScope(parent_scope);
loop_replacement_map.insert({fl, inlined_loop});
} else {
unroll_ite->elseBody().push_back(inlined_loop);
loop_replacement_map.insert({fl, unroll_ite});
}
}

// Generate the loop nest structure and place it in lowered_exprs
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_unroll.h
@@ -74,6 +74,10 @@ class TORCH_CUDA_API UnrollPass : public OptOutDispatch {
// keep track if we're within an unrolled loop
bool look_for_unroll = true;

// As we generate inline predicates, check whether we actually generated a
// non-trivial one.
bool non_trivial_pred_found = false;

// Custom dispatch for Expr, want to find out if it's a TV op
void handle(Expr*) final;

106 changes: 39 additions & 67 deletions torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -40,32 +40,6 @@ class Loops : private OptInDispatch {
}
};

class forLoopCount : private OptInDispatch {
private:
unsigned int count_ = 0;

void handle(kir::ForLoop* fl) final {
count_++;
}

void handle(kir::IfThenElse* ite) final {}

void handle(Expr* expr) final {
OptInDispatch::handle(expr);
}

public:
static unsigned int get(Expr* scope) {
forLoopCount flc;
Expr* it = scope;
while (it != nullptr) {
flc.handle(it);
it = scope_utils::getParent(it);
}
return flc.count_;
}
};

class scopePushBack : private OptInDispatch {
private:
Expr* expr_;
@@ -121,50 +95,61 @@ class scopeInsertBefore : private OptInDispatch {
}
};

class parentScope : private OptInDispatch {
class ExprInScope : private OptInDispatch {
private:
Expr* parent_ = nullptr;
Expr* expr_;
bool contains_ = false;

void handle(kir::ForLoop* fl) final {
parent_ = fl->parentScope();
if (fl->body().contains(expr_)) {
contains_ = true;
}
}

void handle(kir::IfThenElse* ite) final {
parent_ = ite->parentScope();
if (ite->body().contains(expr_)) {
contains_ = true;
}
}

void handle(Expr* expr) final {
OptInDispatch::handle(expr);
}

ExprInScope(Expr* expr) : expr_(expr) {}

public:
static Expr* get(Expr* scope) {
parentScope sp;
sp.handle(scope);
return sp.parent_;
static bool find(Expr* scope, Expr* expr) {
ExprInScope eis(expr);
TORCH_INTERNAL_ASSERT(
expr != nullptr && scope != nullptr,
"Cannot push back, scope or expr is a nullptr.");
eis.handle(scope);
return eis.contains_;
}
};

class scopeClearExprs : private OptInDispatch {
class parentScope : private OptInDispatch {
private:
Expr* parent_ = nullptr;

void handle(kir::ForLoop* fl) final {
fl->body().clear();
parent_ = fl->parentScope();
}

void handle(kir::IfThenElse* ite) final {
ite->body().clear();
parent_ = ite->parentScope();
}

void handle(Expr* expr) final {
OptInDispatch::handle(expr);
}

public:
static void clear(Expr* scope) {
scopeClearExprs sce;
TORCH_INTERNAL_ASSERT(
scope != nullptr, "Cannot clear scope, scope is a nullptr.");
sce.handle(scope);
static Expr* get(Expr* scope) {
parentScope sp;
sp.handle(scope);
return sp.parent_;
}
};

@@ -310,14 +295,6 @@ std::vector<kir::ForLoop*> getLoops(Expr* scope) {
return Loops::getLoops(scope);
}

// Track how far our for loop scope is
unsigned int computeForDepth(Expr* scope) {
if (scope == nullptr)
return 0;
assertScope(scope);
return forLoopCount::get(scope);
}

// Push back an expr to scope
void pushBack(Expr* scope, Expr* expr) {
TORCH_INTERNAL_ASSERT(
@@ -331,6 +308,10 @@ void insertBefore(Expr* scope, Expr* ref, Expr* expr) {
scopeInsertBefore::insert(scope, ref, expr);
}

bool exprInScope(Expr* scope, Expr* expr) {
return ExprInScope::find(scope, expr);
}

// Return the parent of the active scope
Expr* getParent(Expr* scope) {
TORCH_INTERNAL_ASSERT(
@@ -357,22 +338,6 @@ kir::ForLoop* openFor(Expr* scope, IterDomain* id) {
return new_scope;
}

// Close the inner most for loop
Expr* closeScope(Expr* scope) {
TORCH_INTERNAL_ASSERT(
scope != nullptr, "Tried to close a scope but got a nullptr.");
return getParent(scope);
}

// Clear all expressions from the scope
Expr* clearScope(Expr* scope) {
TORCH_INTERNAL_ASSERT(
scope != nullptr, "Tried to clear a scope but got a nullptr.");
assertScope(scope);
scopeClearExprs::clear(scope);
return scope;
}

kir::ForLoop* cloneLoopNest(kir::ForLoop* to_clone, Expr* parent_scope) {
return CloneLoopNest::getClone(to_clone, parent_scope);
}
@@ -676,9 +641,16 @@ std::pair<kir::ForLoop*, int64_t> getAllocPoint(
loop->iter_domain()->getParallelType() == ParallelType::Unroll;
});

if (loops_it == loops.end()) {
for (auto loop : loops) {
std::cout << loop->iter_domain() << " ";
}
std::cout << std::endl;
}
TORCH_INTERNAL_ASSERT(
loops_it != loops.end(),
"Could not find all required axes for indexing.");
"Could not find all required axes for indexing when trying to index into ",
tv);

if (kir_ca_id->getParallelType() == ParallelType::Unroll) {
return {alloc_loop, tv_i};
9 changes: 3 additions & 6 deletions torch/csrc/jit/codegen/cuda/lower_utils.h
@@ -29,18 +29,15 @@ void pushBack(Expr* scope, Expr* expr);
// Insert expr in scope before ref
void insertBefore(Expr* scope, Expr* ref, Expr* expr);

// Returns true if expr is in scope; does not check nested scopes
bool exprInScope(Expr* scope, Expr* expr);

// Return the parent of the active scope
Expr* getParent(Expr* scope);

// Open a new inner most for loop
kir::ForLoop* openFor(Expr* scope, IterDomain*);

// Close the inner most for loop
Expr* closeScope(Expr* scope);

// Clear all expressions from the scope
Expr* clearScope(Expr* scope);

// Provides a new for loop matching the one provided; sets parent_scope as its
// parent scope, but does not insert it into the parent scope.
kir::ForLoop* cloneLoopNest(kir::ForLoop* to_clone, Expr* parent_scope);
22 changes: 19 additions & 3 deletions torch/csrc/jit/codegen/cuda/predicate_compute.cpp
@@ -3,6 +3,8 @@
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/index_compute.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>

namespace torch {
@@ -103,10 +105,24 @@ kir::Bool* PredicateCompute::getInlinePredicate(
auto pred_inds =
Index::getConsumerRootPredIndices(out_tv, loops, pred_contiguity);
auto root_indices = pred_inds.first;
bool use_rfactor = pred_inds.second;
bool use_maybe_rfactor = pred_inds.second;

if (out_tv->getMemoryType() != MemoryType::Global && out_tv->hasReduction() &&
Review comment (Collaborator): I think we should not skip predicates for shared memory buffers as well. No predicate would mean shared memory buffers would be zero-initialized by all threads redundantly, right? That would cause race conditions, then.

Reply (Owner, author): I didn't think about it, yeah that seems fair. Could you please commit that change and then go ahead and merge this if you approve otherwise?

Reply (Collaborator): Thanks for the commit.

(A hedged sketch of the narrowed guard discussed here appears after this file's diff.)

!use_maybe_rfactor) {
auto tv_filter_inp_view =
ir_utils::filterByType<TensorView>(expr->inputs());
auto has_tv_inputs = tv_filter_inp_view.begin() != tv_filter_inp_view.end();
// If the predicate doesn't need maybe_rfactor, the output has reduction axes,
// and the expr has no tensor inputs, we're pretty confident we're initializing
// a reduction buffer. If we're initializing a reduction buffer, don't generate
// an inline predicate.
if (!has_tv_inputs) {
return new kir::Bool(true);
}
}

auto all_preds =
PredicateCompute::computePredicates(out_tv, root_indices, use_rfactor);
auto all_preds = PredicateCompute::computePredicates(
out_tv, root_indices, use_maybe_rfactor);

// If we have thread predicates, add those
if (thread_pred != nullptr) {
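
Below is a minimal sketch of the guard discussed in the review thread on predicate_compute.cpp, narrowed so that only register-local reduction-buffer initializations skip the inline predicate while shared memory buffers stay predicated. It reuses the names from the hunk above and assumes a MemoryType::Local enumerator; the actual follow-up commit is not shown in this diff, so treat this as an illustration rather than the committed change.

  // Hypothetical narrowing of the guard (illustrative only): skip the inline
  // predicate only when initializing a register-local reduction buffer, so
  // shared memory buffers keep their predicate and are not zero-initialized
  // redundantly by every thread.
  if (out_tv->getMemoryType() == MemoryType::Local && out_tv->hasReduction() &&
      !use_maybe_rfactor) {
    auto tv_filter_inp_view =
        ir_utils::filterByType<TensorView>(expr->inputs());
    const bool has_tv_inputs =
        tv_filter_inp_view.begin() != tv_filter_inp_view.end();
    if (!has_tv_inputs) {
      // No tensor inputs: this expr is initializing the reduction buffer.
      return new kir::Bool(true);
    }
  }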