Remove some welford specific logic. (#1864)
* Remove some welford specific logic.

* Multi-reduction fix

* Some more minor cleanup.

* Add a note on multi-input reductions

Co-authored-by: Naoya Maruyama <nmaruyama@nvidia.com>
csarofeen and naoyam authored Jul 25, 2022
1 parent 51589d3 commit 1013eda
Showing 18 changed files with 107 additions and 124 deletions.
6 changes: 0 additions & 6 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -1382,12 +1382,6 @@ WelfordResult::WelfordResult(
TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(n->definition()));
}

WelfordResult WelfordResult::rFactor(const std::vector<int>& axes) {
auto o_tv = avg->definition()->as<WelfordOp>()->out()->as<TensorView>();
auto rf_tvs = o_tv->rFactor(axes, std::vector<TensorView*>{avg, var_sum, n});
return WelfordResult{rf_tvs.at(0), rf_tvs.at(1), rf_tvs.at(2)};
}

// COMPOUND OPERATIONS

// add_alpha
2 changes: 0 additions & 2 deletions torch/csrc/jit/codegen/cuda/arith.h
@@ -107,8 +107,6 @@ class TORCH_CUDA_CU_API WelfordResult {
TensorView* in_avg,
TensorView* in_var_sum,
TensorView* in_n);

WelfordResult rFactor(const std::vector<int>& axes);
};

//! Welford operator on specified axes. This is currently the only scan op with
9 changes: 0 additions & 9 deletions torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -230,15 +230,6 @@ void Fusion::addOutput(Val* output) {
all_tv_uses_valid_ = false;
}

void Fusion::addOutput(WelfordResult& wr) {
// Want to always make sure the avg gets added last
// since avg will be the out() value of welfordOp,
// and want to make it the top of the computeAt chain
addOutput(wr.var_sum);
addOutput(wr.n);
addOutput(wr.avg);
}

void Fusion::removeInput(Val* input) {
auto find_input = std::find(inputs_.begin(), inputs_.end(), input);
if (find_input != inputs_.end()) {
3 changes: 0 additions & 3 deletions torch/csrc/jit/codegen/cuda/fusion.h
@@ -110,9 +110,6 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer {
//! Register output as an output of the fusion
void addOutput(Val* output);

//! Register output as an output of the fusion
void addOutput(WelfordResult& output);

//! Deregister input as an input of the fusion
void removeInput(Val* input);

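
For reference, a minimal sketch of the calling pattern after these removals: each Welford output is registered individually (the removed Fusion::addOutput overload added avg last so it topped the computeAt chain), and all three siblings are passed straight to TensorView::rFactor. This is illustrative only; makeSymbolicTensor, the Welford(tv, axes) overload, and the split factor are assumed from the usual nvFuser C++ test setup, not taken from this diff.

// Illustrative sketch, not part of this commit.
// Relevant headers: torch/csrc/jit/codegen/cuda/arith.h, fusion.h, ir_interface_nodes.h.
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = makeSymbolicTensor(2); // assumed test helper
fusion.addInput(tv0);

// Welford over axis 1 produces three sibling outputs: avg, var_sum, n.
WelfordResult wr = Welford(tv0, {1});

// Previously Fusion::addOutput(WelfordResult&) did this, adding avg last.
fusion.addOutput(wr.var_sum);
fusion.addOutput(wr.n);
fusion.addOutput(wr.avg);

// Previously WelfordResult::rFactor wrapped this call; the siblings are now
// handed to TensorView::rFactor directly.
wr.avg->split(1, 128); // schedule detail, for illustration only
auto rf_tvs = wr.avg->rFactor({1}, {wr.avg, wr.var_sum, wr.n});
// rf_tvs.at(0), rf_tvs.at(1), rf_tvs.at(2): rfactored avg, var_sum, n.
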
41 changes: 26 additions & 15 deletions torch/csrc/jit/codegen/cuda/ir_utils.cpp
@@ -3,6 +3,7 @@
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>

#include <set>

@@ -473,25 +474,23 @@ TensorView* rfactorHelper(
TensorView* reduction_tv,
const std::vector<int>& axes) {
TORCH_INTERNAL_ASSERT(reduction_tv->definition() != nullptr);
const bool is_welford = reduction_tv->definition()->isA<WelfordOp>();
if (!is_welford) {
const bool has_multiple_tvs = reduction_tv->definition()->inputs().size() > 1;
if (!has_multiple_tvs) {
return reduction_tv->rFactor(axes);
}
auto welford = reduction_tv->definition()->as<WelfordOp>();
auto w_avg = welford->outAvg()->as<TensorView>();
auto w_var = welford->outVar()->as<TensorView>();
auto w_n = welford->outN()->as<TensorView>();

auto rtvs =
reduction_tv->rFactor(axes, std::vector<TensorView*>{w_avg, w_var, w_n});
std::vector<TensorView*> out_tvs;
std::transform(
reduction_tv->definition()->outputs().begin(),
reduction_tv->definition()->outputs().end(),
std::back_inserter(out_tvs),
[](Val* val) { return val->as<TensorView>(); });

if (reduction_tv == w_n) {
return rtvs.at(2);
} else if (reduction_tv == w_var) {
return rtvs.at(1);
} else {
return rtvs.at(0);
}
auto rf_tvs = reduction_tv->rFactor(axes, out_tvs);

return rf_tvs.at(std::distance(
out_tvs.begin(),
std::find(out_tvs.begin(), out_tvs.end(), reduction_tv)));
}

namespace {
@@ -809,6 +808,18 @@ Val* getReductionInitValOf(TensorView* tv) {
return init;
}

// TODO: Should mma be in here? Should we return true if it's a trivial
// reduction?
bool isReductionOp(const Expr* expr) {
// Note that GridReduction inherits ReductionOp
return expr->isA<ReductionOp>() || expr->isA<GroupedReductionOp>() ||
expr->isA<WelfordOp>() || expr->isA<kir::GridWelford>();
}

bool isReductionTvOp(const Expr* expr) {
return ir_utils::isTvOp(expr) && isReductionOp(expr);
}

namespace {

struct ReplaceValInIndexVal : public OptInDispatch {
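
A short usage sketch of the rewritten helper, under the same assumptions as the sketch above: rfactorHelper now handles any multi-output reduction and returns the rfactored tensor corresponding to whichever sibling it is given, with no WelfordOp special case.

// Sketch only (fresh fusion, same setup as before): schedulers call
// rfactorHelper on whichever reduction output they happen to hold.
WelfordResult wr = Welford(tv0, {1});
wr.var_sum->split(1, 128);
// Returns the rfactored sibling matching var_sum; passing wr.avg or wr.n
// would return the rfactored avg or n instead.
TensorView* var_sum_rf = ir_utils::rfactorHelper(wr.var_sum, {1});

// Single-output reductions take the plain rFactor path.
TensorView* sum_tv = sum(tv0, {1});
sum_tv->split(1, 128);
TensorView* sum_rf = ir_utils::rfactorHelper(sum_tv, {1});
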
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_utils.h
@@ -307,6 +307,12 @@ TORCH_CUDA_CU_API std::vector<Expr*> getReductionOps(
// Returns the initialization value of tv or nullptr if not initialized.
TORCH_CUDA_CU_API Val* getReductionInitValOf(TensorView* tv);

// Returns if Expr is a reduction op
TORCH_CUDA_CU_API bool isReductionOp(const Expr*);

// Returns if Expr is a reduction op with TensorView or TensorIndex
TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);

template <typename T>
std::string toString(const T& nodes) {
std::stringstream ss;
8 changes: 4 additions & 4 deletions torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
@@ -210,11 +210,11 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {

auto tv_inp = inp->as<TensorView>();

// Change for welford Op, we want the users of all outputs of welfordOp
// to use a single predicate name.
// If tv_inp was an output of a multi-output expression, just change it to a
// consistent sibling to use a single predicate name.
if (auto tv_def = tv_inp->definition()) {
if (auto wop = dynamic_cast<WelfordOp*>(tv_def)) {
tv_inp = wop->out()->as<TensorView>();
if (tv_def->outputs().size() > 1) {
tv_inp = ir_utils::getTvOutput(tv_def);
}
}

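
The pattern here is simply "canonicalize to the first TensorView output of the defining expression", so every sibling of a multi-output op shares one predicate entry. A generic illustration; canonicalSibling is a hypothetical name, not part of the commit.

// Hypothetical helper showing the canonicalization updateBitSet now performs.
TensorView* canonicalSibling(TensorView* tv) {
  Expr* def = tv->definition();
  if (def != nullptr && def->outputs().size() > 1) {
    // The first TensorView output stands in for all siblings
    // (avg in the case of WelfordOp).
    return ir_utils::getTvOutput(def);
  }
  return tv;
}
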
40 changes: 23 additions & 17 deletions torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp
@@ -5,6 +5,7 @@
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>

#include <unordered_set>
@@ -23,29 +24,39 @@ bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {
TORCH_INTERNAL_ASSERT(
root_id->definition() == nullptr, "Not root IterDomain: ", root_id);

if (tv->definition() == nullptr) {
auto def = tv->definition();

if (def == nullptr) {
// This is an input tensor, so no rfactor tensor to traverse.
return false;
}

const auto& inputs = tv->definition()->inputs();

// Check the reduction expression that produces tv
if (inputs.size() != 1 || !inputs[0]->isA<TensorView>() ||
(tv->definition()->getExprType() != ExprType::ReductionOp &&
tv->definition()->getExprType() != ExprType::WelfordOp)) {
// No rfactor producer found
if (!ir_utils::isReductionOp(def) || def->isA<MmaOp>()) {
return false;
}

auto producer = inputs[0]->as<TensorView>();
TORCH_INTERNAL_ASSERT(
def->inputs().size() == def->outputs().size(),
"This logic block assumes number of inputs is the same as number of outputs of reduction ops.");

// Reduction expr may have multiple inputs, just grab any TV
// input. Note that in theory it is possible that a
// GroupedReductionOp has rfactor inputs as well as non-rfactor
// inputs, so grabbing the one that actually corresponds to tv can
// be important. In reality, though, such a GroupedReductionOp
// should not happen as we do not group reductions of rfactor and
// non-rfactor tensor.
auto producer_tv = ir_utils::getTvInput(def);

if (!producer->hasRFactor()) {
TORCH_INTERNAL_ASSERT(producer_tv != nullptr);

if (!producer_tv->hasRFactor()) {
return false;
}

auto c2p = PairwiseRootDomainMap(producer, tv)
.mapConsumerToProducer(tv->domain(), producer->domain());
auto c2p = PairwiseRootDomainMap(producer_tv, tv)
.mapConsumerToProducer(tv->domain(), producer_tv->domain());

auto producer_id_it = c2p.find(root_id);
if (producer_id_it == c2p.end()) {
@@ -55,7 +66,7 @@ bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {

auto producer_root_id = producer_id_it->second;

return analyzeIfDerivedFromTrivialReduction(producer, producer_root_id);
return analyzeIfDerivedFromTrivialReduction(producer_tv, producer_root_id);
}

bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) {
@@ -109,11 +120,6 @@ bool TrivialReductionInfo::isDerived(IterDomain* id) const {
return domains_.find(id) != domains_.end();
}

bool TrivialReductionInfo::isDerivedFromRoot(IterDomain* id) const {
return domains_derived_from_root_.find(id) !=
domains_derived_from_root_.end();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
3 changes: 0 additions & 3 deletions torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h
@@ -21,9 +21,6 @@ class TORCH_CUDA_CU_API TrivialReductionInfo {

bool isDerived(IterDomain* id) const;

// TODO: Not used, cleanup
bool isDerivedFromRoot(IterDomain* id) const;

private:
//! IterDomains that are derived only from trivial
//! reductons. Included domains are not limited to reduction axes as
15 changes: 7 additions & 8 deletions torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -183,14 +183,13 @@ TensorView* getTvOutput(const Expr* expr) {
return nullptr;
}

bool isReductionOp(const Expr* expr) {
// Note that GridReduction inherits ReductionOp
return expr->isA<ReductionOp>() || expr->isA<GroupedReductionOp>() ||
expr->isA<WelfordOp>() || expr->isA<kir::GridWelford>();
}

bool isReductionTvOp(const Expr* expr) {
return isTvOp(expr) && isReductionOp(expr);
TensorView* getTvInput(const Expr* expr) {
for (auto inp : expr->inputs()) {
if (auto tv = getTv(inp)) {
return tv;
}
}
return nullptr;
}

bool isScalarOp(const Expr* expr) {
7 changes: 2 additions & 5 deletions torch/csrc/jit/codegen/cuda/lower_utils.h
@@ -79,11 +79,8 @@ TORCH_CUDA_CU_API bool isTvOp(const Expr*);
// Returns the first output of Expr that is a TensorView
TORCH_CUDA_CU_API TensorView* getTvOutput(const Expr*);

// Returns if Expr is a reduction op
TORCH_CUDA_CU_API bool isReductionOp(const Expr*);

// Returns if Expr is a reduction op with TensorView or TensorIndex
TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);
// Returns the first input of Expr that is a TensorView
TORCH_CUDA_CU_API TensorView* getTvInput(const Expr*);

bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map);

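
getTvInput mirrors getTvOutput on the input side, so lowering passes no longer need to assume a reduction has exactly one producer. A small sketch of the intended use; hasRFactorProducer is a hypothetical name.

// Sketch: check whether a reduction expression is fed by an rfactor tensor,
// without assuming it has a single input (welford, grouped reductions, ...).
bool hasRFactorProducer(const Expr* expr) {
  if (!ir_utils::isReductionTvOp(expr)) {
    return false;
  }
  TensorView* producer = ir_utils::getTvInput(expr); // first TV input, if any
  return producer != nullptr && producer->hasRFactor();
}
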
3 changes: 1 addition & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp
@@ -823,8 +823,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(

TORCH_INTERNAL_ASSERT(
red_expr->getExprType() != c10::nullopt &&
(red_expr->getExprType().value() == ExprType::ReductionOp ||
red_expr->getExprType().value() == ExprType::WelfordOp),
ir_utils::isReductionOp(red_expr),
"TensorView doesn't have a reduction.");

auto tv_inps = ir_utils::filterByType<TensorView>(fusion->inputs());
3 changes: 1 addition & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp
@@ -908,8 +908,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(

TORCH_INTERNAL_ASSERT(
red_expr->getExprType() != c10::nullopt &&
(red_expr->getExprType().value() == ExprType::ReductionOp ||
red_expr->getExprType().value() == ExprType::WelfordOp),
ir_utils::isReductionOp(red_expr),
"TensorView doesn't have a reduction.");

auto properties =
29 changes: 9 additions & 20 deletions torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
@@ -239,8 +239,8 @@ class SchedulerTopologyChecker {
static bool hasPostReductionBCast(Fusion* fusion) {
auto all_vals = fusion->usedMathVals();
for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
// Welford can have 2 outputs, so do this on all found reduction tensor
// views
// Reductions can have multiple outputs, so do this on all found reduction
// tensor views
if (tv->hasReduction() && !tv->isFusionInput()) {
auto tv_chains = tvChains(DependencyCheck::getAllUseChains(tv));
// Propagate forward from reduction through all uses of the reduction
@@ -301,18 +301,17 @@

// When checking post reduction vals, we need to make sure
// we are really checking paths starting from all outputs
// of multi-output reductions, i.e. welford. The reduction_tv
// vector is assumed to only have one of them.
// of multi-output reductions, i.e. welford/grouped reduction. The
// reduction_tv vector is assumed to only have one of them.
std::unordered_set<Val*> reduction_tv_set(
reduction_tvs.begin(), reduction_tvs.end());

for (auto red : reduction_tvs) {
if (red->definition()) {
if (auto wop = dynamic_cast<WelfordOp*>(red->definition())) {
for (auto wop_output : wop->outputs()) {
if (wop_output->isA<TensorView>()) {
reduction_tv_set.insert(wop_output);
}
if (ir_utils::isReductionOp(red->definition())) {
auto outs = red->definition()->outputs();
for (auto out_tv : ir_utils::filterByType<TensorView>(outs)) {
reduction_tv_set.insert(out_tv);
}
}
}
@@ -988,9 +987,8 @@ class PointWiseScheduler : public SchedulerEntry {

auto reduction_ops =
ir_utils::getReductionOps(fusion, true /* ignore_trivial */);
auto welford_ops = ir_utils::filterByType<WelfordOp>(reduction_ops);

if (!reduction_ops.empty() || !welford_ops.empty()) {
if (!reduction_ops.empty()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::PointWise, "no support for reduction ops");
return false;
@@ -1052,15 +1050,6 @@ class PersistentKernelScheduler : public SchedulerEntry {

auto reduction_ops =
ir_utils::getReductionOps(fusion, false /* ignore_trivial */);
auto welford_ops = ir_utils::filterByType<WelfordOp>(reduction_ops);
// For persistent schedule we want welford translated to average and
// standard deviation reductions.
if (welford_ops.begin() != welford_ops.end()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Persistent,
"no support for un-translated welford");
return false;
}

auto view_tvs = scheduler_utils::getViewTVs(fusion);
if (view_tvs.size() > 0) {
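
Written out on its own, the sibling-expansion pattern used above looks like the following; it assumes reduction_tvs holds one representative TensorView per reduction, as elsewhere in the schedulers.

// Sketch: post-reduction checks should start from every sibling output of a
// multi-output reduction, not just the representative that was collected
// (e.g. only avg of a WelfordOp). Requires <unordered_set>.
std::unordered_set<Val*> reduction_tv_set(
    reduction_tvs.begin(), reduction_tvs.end());
for (TensorView* red : reduction_tvs) {
  auto def = red->definition();
  if (def != nullptr && ir_utils::isReductionOp(def)) {
    auto outs = def->outputs();
    for (auto out_tv : ir_utils::filterByType<TensorView>(outs)) {
      reduction_tv_set.insert(out_tv);
    }
  }
}
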
1 change: 0 additions & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/utils.h
@@ -196,7 +196,6 @@ std::pair<bool, bool> canonicalDimReduction(

// Return a list of tensor views that are outputs of reduction operations. If
// multiple outputs of an expression are found, only include one in the list
// (WelfordOp)
TORCH_CUDA_CU_API std::vector<TensorView*> getReductionTvs(
Fusion* fusion,
bool ignore_trivial = true);