Commit 051f9a3

setPredicate->withPredicate
zasdfgbnm committed Oct 3, 2022
1 parent fb6e6f3 commit 051f9a3
Showing 8 changed files with 78 additions and 47 deletions.
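
The pattern applied throughout this commit: in-place predicate setters (setPredicate, setWritePredicate, setThreadPredicate) become with* methods that return a predicated shallow copy, leaving the original node untouched, so lowering passes replace expressions rather than mutate them. Before the file-by-file diff, a minimal, self-contained sketch of that refactor may help; the Expr, Predicate, and shallowCopy names echo the real code, but the bodies below are simplified stand-ins, not the nvFuser sources.

#include <cassert>
#include <cstdio>

struct Predicate {
  const char* cond;
};

struct Expr {
  Predicate* predicate = nullptr;

  // Old style: mutate in place. Every holder of this pointer observes the
  // change, which makes pass ordering fragile.
  void setPredicate(Predicate* p) {
    predicate = p;
  }

  // New style: leave the original untouched and hand back a modified copy.
  // The caller must capture the result.
  Expr* withPredicate(Predicate* p) const {
    Expr* copy = new Expr(*this); // stand-in for shallowCopy()
    copy->predicate = p;
    return copy;
  }
};

int main() {
  Predicate tid{"threadIdx.x == 0"};
  Expr* e = new Expr;
  Expr* predicated = e->withPredicate(&tid);
  assert(e->predicate == nullptr);       // original is unchanged
  assert(predicated->predicate == &tid); // the copy carries the predicate
  std::printf("predicate: %s\n", predicated->predicate->cond);
  delete predicated;
  delete e;
  return 0;
}
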
torch/csrc/jit/codegen/cuda/ir_base_nodes.h (4 changes: 2 additions & 2 deletions)

@@ -477,14 +477,14 @@ class TORCH_CUDA_CU_API Expr : public Statement {
   // TODO: Protect based on being in kernel container
   Expr* withWritePredicate(kir::Predicate* write_predicate);
 
+ protected:
+
   // TODO: Protect based on being in kernel container
   void setPredicate(kir::Predicate* predicate);
 
   // TODO: Protect based on being in kernel container
   void setWritePredicate(kir::Predicate* write_predicate);
 
- protected:
-
   // TODO: Add Fusion passkey
   void addInput(Val* input) {
     TORCH_INTERNAL_ASSERT(input != nullptr);
torch/csrc/jit/codegen/cuda/kernel_ir.h (24 changes: 16 additions & 8 deletions)

@@ -599,8 +599,10 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp {
     return thread_predicate_;
   }
 
-  void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
-    thread_predicate_ = thread_predicate;
+  GridReduction* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
+    auto result = shallowCopy()->as<GridReduction>();
+    result->thread_predicate_ = thread_predicate;
+    return result;
   }
 
  private:
@@ -659,8 +661,10 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp {
     return thread_predicate_;
   }
 
-  void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
-    thread_predicate_ = thread_predicate;
+  GroupedGridReduction* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
+    auto result = shallowCopy()->as<GroupedGridReduction>();
+    result->thread_predicate_ = thread_predicate;
+    return result;
   }
 
  private:
@@ -768,8 +772,10 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr {
     return thread_predicate_;
   }
 
-  void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
-    thread_predicate_ = thread_predicate;
+  GridWelford* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
+    auto result = shallowCopy()->as<GridWelford>();
+    result->thread_predicate_ = thread_predicate;
+    return result;
   }
 
  private:
@@ -829,8 +835,10 @@ class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp {
     return thread_predicate_;
   }
 
-  void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
-    thread_predicate_ = thread_predicate;
+  GroupedGridWelford* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
+    auto result = shallowCopy()->as<GroupedGridWelford>();
+    result->thread_predicate_ = thread_predicate;
+    return result;
   }
 
  private:
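
One wrinkle visible in the call sites below: withThreadPredicate can return the concrete type (GridReduction*, GridWelford*, and so on) because each class declares its own overload, while withPredicate and withWritePredicate live on the base Expr (see ir_base_nodes.h above) and return Expr*, forcing callers to recover the subtype with ->as<T>(). A hedged sketch of that downcast idiom, using invented minimal types rather than the real Statement hierarchy:

#include <cassert>

struct Statement {
  virtual ~Statement() = default;

  // Checked downcast, loosely modeled on Statement::as<T>() in the real code.
  template <typename T>
  T* as() {
    T* down = dynamic_cast<T*>(this);
    assert(down != nullptr);
    return down;
  }
};

struct Expr : Statement {
  int predicate = 0;

  virtual Expr* shallowCopy() const = 0;

  // Declared on the base, so the static return type is Expr* even when the
  // dynamic type is a subclass.
  Expr* withPredicate(int p) const {
    Expr* copy = shallowCopy();
    copy->predicate = p;
    return copy;
  }
};

struct GridReduction : Expr {
  Expr* shallowCopy() const override {
    return new GridReduction(*this);
  }
};

int main() {
  GridReduction* gr = new GridReduction;
  // Without ->as<GridReduction>() the result could only be held as Expr*.
  gr = gr->withPredicate(42)->as<GridReduction>();
  assert(gr->predicate == 42);
  return 0;
}
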
torch/csrc/jit/codegen/cuda/lower_index.cpp (63 changes: 41 additions & 22 deletions)

@@ -411,10 +411,12 @@ void IndexLowering::handleBlockReduction(
   ReductionOp* indexed_rop = IrBuilder::create<ReductionOp>(
       rop->getReductionOpType(), rop->init(), out, in, rop->isAllreduce());
   if (rop->predicate()) {
-    indexed_rop->setPredicate(rop->predicate());
+    indexed_rop =
+        indexed_rop->withPredicate(rop->predicate())->as<ReductionOp>();
   }
   if (rop->writePredicate()) {
-    indexed_rop->setWritePredicate(rop->writePredicate());
+    indexed_rop = indexed_rop->withWritePredicate(rop->writePredicate())
+                      ->as<ReductionOp>();
   }
 
   pushBack(indexed_rop);
@@ -493,13 +495,15 @@ void IndexLowering::handleGridReduction(
       n_entrances,
       rop->isAllreduce());
 
-  grid_reduction->setThreadPredicate(thread_pred);
+  grid_reduction = grid_reduction->withThreadPredicate(thread_pred);
 
   if (rop->predicate()) {
-    grid_reduction->setPredicate(rop->predicate());
+    grid_reduction = grid_reduction->withPredicate(rop->predicate())
+                         ->as<kir::GridReduction>();
   }
   if (rop->writePredicate()) {
-    grid_reduction->setWritePredicate(rop->writePredicate());
+    grid_reduction = grid_reduction->withWritePredicate(rop->writePredicate())
+                         ->as<kir::GridReduction>();
   }
 
   pushBack(grid_reduction);
@@ -556,10 +560,12 @@ void IndexLowering::handleBlockReduction(
       inputs,
       grouped_rop->isAllreduce());
   if (grouped_rop->predicate()) {
-    indexed_rop->setPredicate(grouped_rop->predicate());
+    indexed_rop = indexed_rop->withPredicate(grouped_rop->predicate())
+                      ->as<GroupedReductionOp>();
   }
   if (grouped_rop->writePredicate()) {
-    indexed_rop->setWritePredicate(grouped_rop->writePredicate());
+    indexed_rop = indexed_rop->withWritePredicate(grouped_rop->writePredicate())
+                      ->as<GroupedReductionOp>();
   }
 
   pushBack(indexed_rop);
@@ -638,13 +644,16 @@ void IndexLowering::handleGridReduction(
       work_buf_size_info.buffer_stride,
       grouped_rop->isAllreduce());
 
-  grid_reduction->setThreadPredicate(thread_pred);
+  grid_reduction = grid_reduction->withThreadPredicate(thread_pred);
 
   if (grouped_rop->predicate()) {
-    grid_reduction->setPredicate(grouped_rop->predicate());
+    grid_reduction = grid_reduction->withPredicate(grouped_rop->predicate())
+                         ->as<kir::GroupedGridReduction>();
   }
   if (grouped_rop->writePredicate()) {
-    grid_reduction->setWritePredicate(grouped_rop->writePredicate());
+    grid_reduction =
+        grid_reduction->withWritePredicate(grouped_rop->writePredicate())
+            ->as<kir::GroupedGridReduction>();
   }
 
   pushBack(grid_reduction);
@@ -706,10 +715,11 @@ void IndexLowering::handle(const WelfordOp* wop) {
       wop->isAllreduce());
 
   if (wop->predicate()) {
-    indexed_wop->setPredicate(wop->predicate());
+    indexed_wop = indexed_wop->withPredicate(wop->predicate())->as<WelfordOp>();
   }
   if (wop->writePredicate()) {
-    indexed_wop->setWritePredicate(wop->writePredicate());
+    indexed_wop =
+        indexed_wop->withWritePredicate(wop->writePredicate())->as<WelfordOp>();
   }
 
   // Serial welford
@@ -785,22 +795,27 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
       entrance_ind,
       n_entrances);
 
-  grid_welford->setThreadPredicate(thread_pred);
+  grid_welford = grid_welford->withThreadPredicate(thread_pred);
 
   const bool block_reduce_separated =
       out_domain->hasBlockReduction() && !indexed_wop->isAllreduce();
 
   if (indexed_wop->predicate()) {
     if (block_reduce_separated) {
-      grid_welford->setPredicate(IrBuilder::create<kir::Predicate>(
-          GpuLower::current()->kernel()->trueVal()));
+      grid_welford = grid_welford
+                         ->withPredicate(IrBuilder::create<kir::Predicate>(
+                             GpuLower::current()->kernel()->trueVal()))
+                         ->as<kir::GridWelford>();
     } else {
-      grid_welford->setPredicate(indexed_wop->predicate());
+      grid_welford = grid_welford->withPredicate(indexed_wop->predicate())
+                         ->as<kir::GridWelford>();
     }
   }
 
   if (indexed_wop->writePredicate()) {
-    grid_welford->setWritePredicate(indexed_wop->writePredicate());
+    grid_welford =
+        grid_welford->withWritePredicate(indexed_wop->writePredicate())
+            ->as<kir::GridWelford>();
   }
 
   if (block_reduce_separated) {
@@ -945,13 +960,15 @@ void IndexLowering::handleGroupedGridWelford(
       work_buf_size_info.buffer_stride,
       op->isAllreduce());
 
-  indexed_op->setThreadPredicate(thread_pred);
+  indexed_op = indexed_op->withThreadPredicate(thread_pred);
 
   if (op->predicate()) {
-    indexed_op->setPredicate(op->predicate());
+    indexed_op = indexed_op->withPredicate(op->predicate())
+                     ->as<kir::GroupedGridWelford>();
   }
   if (op->writePredicate()) {
-    indexed_op->setWritePredicate(op->writePredicate());
+    indexed_op = indexed_op->withWritePredicate(op->writePredicate())
+                     ->as<kir::GroupedGridWelford>();
   }
 
   pushBack(indexed_op);
@@ -997,7 +1014,8 @@ void IndexLowering::handle(const BroadcastOp* bop) {
   const bool block_z = parallel_bitmap.get(ParallelType::BIDz);
 
   if (bop->predicate()) {
-    indexed_expr->setPredicate(bop->predicate());
+    indexed_expr =
+        indexed_expr->withPredicate(bop->predicate())->as<BroadcastOp>();
   }
 
   const bool grid_broadcast_needed = block_x || block_y || block_z;
@@ -1024,7 +1042,8 @@ void IndexLowering::handle(const BroadcastOp* bop) {
       indexed_expr, work_buffer, sync_buffer);
 
   if (bop->predicate()) {
-    grid_broadcast->setPredicate(bop->predicate());
+    grid_broadcast = grid_broadcast->withPredicate(bop->predicate())
+                         ->as<kir::GridBroadcast>();
   }
 
   pushBack(grid_broadcast);
torch/csrc/jit/codegen/cuda/lower_predicate.cpp (8 changes: 4 additions & 4 deletions)

@@ -20,7 +20,7 @@ namespace cuda {
 
 namespace {
 
-class ConditionalFromPredicateModifier : public kir::IrVisitor {
+class ConditionalFromPredicateModifier : public kir::ExprMutator {
  public:
   ConditionalFromPredicateModifier() = delete;
 
@@ -33,10 +33,10 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor {
   ConditionalFromPredicateModifier(const std::vector<Expr*>& exprs) {
     FUSER_PERF_SCOPE(
         "GpuLower::Lower::ConditionalFromPredicateModifier::process");
-    kir::IrVisitor::handle(exprs);
+    kir::ExprMutator::handle(exprs);
   }
 
-  using kir::IrVisitor::handle;
+  using kir::ExprMutator::handle;
 
   void handle(Expr* expr) final {
     if (expr != nullptr && expr->predicate() != nullptr) {
@@ -131,7 +131,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor {
     } else {
       // If generateConditional returns null, it means no specific
       // predicate needs to be used.
-      expr->setWritePredicate(nullptr);
+      registerReplace(expr, expr->withWritePredicate(nullptr));
     }
   }
 }
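
The switch from kir::IrVisitor to kir::ExprMutator follows directly from the new copy-returning API: a pass can no longer clear a predicate on the node it is visiting, so it must build a replacement and swap it into the enclosing scope, which ExprMutator provides through registerReplace. A rough stand-in for that mechanic, assuming a flat expression list instead of the real nested scopes:

#include <utility>
#include <vector>

struct Expr {
  int write_predicate = 1;

  Expr* withWritePredicate(int p) const {
    Expr* copy = new Expr(*this);
    copy->write_predicate = p;
    return copy;
  }
};

class ExprMutator {
 public:
  virtual ~ExprMutator() = default;

  void mutate(std::vector<Expr*>& exprs) {
    for (Expr* e : exprs) {
      handle(e);
    }
    // Apply recorded swaps after the walk, so the traversal never sees a
    // half-rewritten expression list.
    for (auto& [old_e, new_e] : replacements_) {
      for (Expr*& slot : exprs) {
        if (slot == old_e) {
          slot = new_e;
        }
      }
    }
  }

 protected:
  // Record an (old, new) pair; the real registerReplace also tracks the
  // scope that owns the expression.
  void registerReplace(Expr* old_expr, Expr* new_expr) {
    replacements_.emplace_back(old_expr, new_expr);
  }

  virtual void handle(Expr* e) {
    if (e->write_predicate != 0) {
      registerReplace(e, e->withWritePredicate(0));
    }
  }

 private:
  std::vector<std::pair<Expr*, Expr*>> replacements_;
};

int main() {
  std::vector<Expr*> exprs = {new Expr, new Expr};
  ExprMutator pass;
  pass.mutate(exprs);
  return 0;
}
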
torch/csrc/jit/codegen/cuda/lower_shift.cpp (11 changes: 6 additions & 5 deletions)

@@ -17,7 +17,7 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-void ShiftPredicateInserter::insert(
+Expr* ShiftPredicateInserter::insert(
     Expr* expr,
     const std::vector<kir::ForLoop*>& loops,
     Bool* thread_pred,
@@ -30,7 +30,7 @@ void ShiftPredicateInserter::insert(
   const bool needs_shift_predicate =
       gpu_lower->haloInfo()->needsShiftPredicate(out_tv->definition());
   if (!needs_shift_predicate) {
-    return;
+    return expr;
   }
 
   // The conditional branches to create:
@@ -57,8 +57,7 @@ void ShiftPredicateInserter::insert(
   // the expr with shift_pred. Since the expr is not shift, the
   // padding is safe to omit.
   if (lower_utils::hasBlockSync(expr, gpu_lower->threadPredMap())) {
-    expr->setPredicate(shift_pred);
-    return;
+    return expr->withPredicate(shift_pred);
   }
 
   auto shift_ite = IrBuilder::create<kir::IfThenElse>(shift_pred);
@@ -76,7 +75,7 @@ void ShiftPredicateInserter::insert(
 
   // No padding condition is required if this is within unswitch.
   if (within_unswitch) {
-    return;
+    return expr;
   }
 
   // Padding by zero
@@ -89,6 +88,8 @@ void ShiftPredicateInserter::insert(
   bounds_ite->thenBody().push_back(pad_expr);
   // Insert the else block
   shift_ite->elseBody().push_back(bounds_ite);
+
+  return expr;
 }
 
 int AxisHaloInfo::width() const {
torch/csrc/jit/codegen/cuda/lower_shift.h (2 changes: 1 addition & 1 deletion)

@@ -225,7 +225,7 @@ class ShiftPredicateInserter {
   //! the generated predicate. The branch structure is different from
   //! the usual predicated expression, so the insertion is also done
   //! here.
-  static void insert(
+  static Expr* insert(
       Expr* expr,
       const std::vector<kir::ForLoop*>& loops,
       Bool* thread_pred,
torch/csrc/jit/codegen/cuda/lower_unroll.cpp (9 changes: 6 additions & 3 deletions)

@@ -82,8 +82,11 @@ void UnrollPass::handle(Expr* expr) {
   // When a predicate needs to account for ShiftOp, it is currently
   // taken care by its own function.
   if (GpuLower::current()->haloInfo()->needsShiftPredicate(expr)) {
-    ShiftPredicateInserter::insert(
+    auto expr_with_predicate = ShiftPredicateInserter::insert(
         expr, for_loops_, thread_pred, unswitched_loop_);
+    if (expr_with_predicate != expr) {
+      registerReplace(expr, expr_with_predicate);
+    }
     return;
   }
 
@@ -93,7 +96,7 @@ void UnrollPass::handle(Expr* expr) {
         ? thread_pred_expr
         : IrBuilder::create<kir::Predicate>(
               PredicateType::ReductionWrite, expr, thread_pred);
-    expr->setWritePredicate(write_pred);
+    registerReplace(expr, expr->withWritePredicate(write_pred));
   }
 
   // For expr calling a device func with block sync, don't create
@@ -103,7 +106,7 @@ void UnrollPass::handle(Expr* expr) {
         ? thread_pred_expr
         : IrBuilder::create<kir::Predicate>(
               PredicateType::Inline, expr, thread_pred);
-    expr->setPredicate(pred);
+    registerReplace(expr, expr->withPredicate(pred));
     return;
   }
 
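
Since ShiftPredicateInserter::insert now returns an Expr*, UnrollPass can use pointer identity to tell whether anything was actually predicated: the helper returns its argument unchanged when no shift predicate is needed, and a fresh node otherwise. A small sketch of that caller-side contract, with toy types rather than the real ones:

#include <cstdio>

struct Expr {
  int predicate = 0;
};

// Returns `expr` itself when there is nothing to do, or a predicated copy,
// mirroring the new Expr* return type of ShiftPredicateInserter::insert.
Expr* insertPredicate(Expr* expr, bool needs_predicate) {
  if (!needs_predicate) {
    return expr;
  }
  Expr* copy = new Expr(*expr);
  copy->predicate = 1;
  return copy;
}

int main() {
  Expr e;
  Expr* result = insertPredicate(&e, true);
  if (result != &e) {
    // This is where UnrollPass calls registerReplace(expr, result).
    std::printf("replacement registered\n");
  }
  return 0;
}
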
torch/csrc/jit/codegen/cuda/lower_utils.cpp (4 changes: 2 additions & 2 deletions)

@@ -446,8 +446,8 @@ class ReplaceExprInput : private kir::ExprMutator {
 
   // Copy predicates and register expression replacement
   void registerReplaceWithPredicate(Expr* old_expr, Expr* new_expr) {
-    new_expr->setPredicate(old_expr->predicate());
-    new_expr->setWritePredicate(old_expr->writePredicate());
+    new_expr = new_expr->withPredicate(old_expr->predicate())
+                   ->withWritePredicate(old_expr->writePredicate());
     registerReplace(old_expr, new_expr);
   }
 
