Remove some welford specific logic. #1864

Merged (4 commits) · Jul 25, 2022
Changes from 2 commits
6 changes: 0 additions & 6 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -1382,12 +1382,6 @@ WelfordResult::WelfordResult(
TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(n->definition()));
}

WelfordResult WelfordResult::rFactor(const std::vector<int>& axes) {
auto o_tv = avg->definition()->as<WelfordOp>()->out()->as<TensorView>();
auto rf_tvs = o_tv->rFactor(axes, std::vector<TensorView*>{avg, var_sum, n});
return WelfordResult{rf_tvs.at(0), rf_tvs.at(1), rf_tvs.at(2)};
}

// COMPOUND OPERATIONS

// add_alpha
2 changes: 0 additions & 2 deletions torch/csrc/jit/codegen/cuda/arith.h
@@ -107,8 +107,6 @@ class TORCH_CUDA_CU_API WelfordResult {
TensorView* in_avg,
TensorView* in_var_sum,
TensorView* in_n);

WelfordResult rFactor(const std::vector<int>& axes);
};

//! Welford operator on specified axes. This is currently the only scan op with
9 changes: 0 additions & 9 deletions torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -230,15 +230,6 @@ void Fusion::addOutput(Val* output) {
all_tv_uses_valid_ = false;
}

void Fusion::addOutput(WelfordResult& wr) {
// Want to always make sure the avg gets added last
// since avg will be the out() value of welfordOp,
// and want to make it the top of the computeAt chain
addOutput(wr.var_sum);
addOutput(wr.n);
addOutput(wr.avg);
}

void Fusion::removeInput(Val* input) {
auto find_input = std::find(inputs_.begin(), inputs_.end(), input);
if (find_input != inputs_.end()) {
3 changes: 0 additions & 3 deletions torch/csrc/jit/codegen/cuda/fusion.h
@@ -110,9 +110,6 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer {
//! Register output as an output of the fusion
void addOutput(Val* output);

//! Register output as an output of the fusion
void addOutput(WelfordResult& output);

//! Deregister input as an input of the fusion
void removeInput(Val* input);

41 changes: 26 additions & 15 deletions torch/csrc/jit/codegen/cuda/ir_utils.cpp
@@ -3,6 +3,7 @@
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>

#include <set>

@@ -473,25 +474,23 @@ TensorView* rfactorHelper(
TensorView* reduction_tv,
const std::vector<int>& axes) {
TORCH_INTERNAL_ASSERT(reduction_tv->definition() != nullptr);
const bool is_welford = reduction_tv->definition()->isA<WelfordOp>();
if (!is_welford) {
const bool has_multiple_tvs = reduction_tv->definition()->inputs().size() > 1;
if (!has_multiple_tvs) {
return reduction_tv->rFactor(axes);
}
auto welford = reduction_tv->definition()->as<WelfordOp>();
auto w_avg = welford->outAvg()->as<TensorView>();
auto w_var = welford->outVar()->as<TensorView>();
auto w_n = welford->outN()->as<TensorView>();

auto rtvs =
reduction_tv->rFactor(axes, std::vector<TensorView*>{w_avg, w_var, w_n});
std::vector<TensorView*> out_tvs;
std::transform(
reduction_tv->definition()->outputs().begin(),
reduction_tv->definition()->outputs().end(),
std::back_inserter(out_tvs),
[](Val* val) { return val->as<TensorView>(); });

if (reduction_tv == w_n) {
return rtvs.at(2);
} else if (reduction_tv == w_var) {
return rtvs.at(1);
} else {
return rtvs.at(0);
}
auto rf_tvs = reduction_tv->rFactor(axes, out_tvs);

return rf_tvs.at(std::distance(
out_tvs.begin(),
std::find(out_tvs.begin(), out_tvs.end(), reduction_tv)));
}

namespace {
@@ -809,6 +808,18 @@ Val* getReductionInitValOf(TensorView* tv) {
return init;
}

// TODO: Should mma be in here? Should we return true if it's a trivial
// reduction?
bool isReductionOp(const Expr* expr) {
// Note that GridReduction inherits ReductionOp
return expr->isA<ReductionOp>() || expr->isA<GroupedReductionOp>() ||
expr->isA<WelfordOp>() || expr->isA<kir::GridWelford>();
}

bool isReductionTvOp(const Expr* expr) {
return ir_utils::isTvOp(expr) && isReductionOp(expr);
}

namespace {

struct ReplaceValInIndexVal : public OptInDispatch {
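A minimal standalone sketch of the lookup pattern used by the generalized rfactorHelper above: the rfactored tensors come back positionally aligned with the definition's outputs, so the index of reduction_tv among the outputs is also the index of its rfactor counterpart. The Tensor struct here is a hypothetical stand-in for nvFuser's TensorView, purely for illustration.

// Sketch only: Tensor stands in for TensorView; not nvFuser code.
#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

struct Tensor {
  int id;
};

// Return the rfactored tensor corresponding to a given output of a
// multi-output reduction, relying on positional alignment of the two lists.
Tensor* matchingRFactorOutput(
    const std::vector<Tensor*>& def_outputs, // e.g. {avg, var_sum, n}
    const std::vector<Tensor*>& rf_outputs,  // rfactor results, same order
    Tensor* reduction_tv) {
  auto it = std::find(def_outputs.begin(), def_outputs.end(), reduction_tv);
  assert(it != def_outputs.end() && "reduction_tv must be one of the outputs");
  return rf_outputs.at(std::distance(def_outputs.begin(), it));
}

int main() {
  Tensor avg{0}, var_sum{1}, n{2};
  Tensor rf_avg{10}, rf_var{11}, rf_n{12};
  std::vector<Tensor*> outs{&avg, &var_sum, &n};
  std::vector<Tensor*> rfs{&rf_avg, &rf_var, &rf_n};
  // Asking for var_sum yields the rfactored var_sum, mirroring the new
  // std::distance/std::find logic in rfactorHelper.
  assert(matchingRFactorOutput(outs, rfs, &var_sum) == &rf_var);
  return 0;
}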
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_utils.h
@@ -307,6 +307,12 @@ TORCH_CUDA_CU_API std::vector<Expr*> getReductionOps(
// Returns the initialization value of tv or nullptr if not initialized.
TORCH_CUDA_CU_API Val* getReductionInitValOf(TensorView* tv);

// Returns if Expr is a reduction op
TORCH_CUDA_CU_API bool isReductionOp(const Expr*);

// Returns if Expr is a reduction op with TensorView or TensorIndex
TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);

template <typename T>
std::string toString(const T& nodes) {
std::stringstream ss;
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
@@ -213,6 +213,7 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
// Change for welford Op, we want the users of all outputs of welfordOp
// to use a single predicate name.
if (auto tv_def = tv_inp->definition()) {
// TODO: Do we need to do anything for grouped reduction here?
Collaborator: Is this necessary for WelfordOp? The maps of ThreadPredicateMap have mappings for all outputs: https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp#L281-L285

Collaborator: Commented out this part, and nothing seems to fail.

Owner Author: Why are your comments not showing up inline in the files page? Strange.

Owner Author: Comes from:
https://github.com/csarofeen/pytorch/pull/561/files#diff-48ec14efa321f9f6f479de4d2c9e377c847067825513a7231d94200d8ea60efaR141-R149

It doesn't seem to be related to correctness; the intent is just to use one predicate for all outputs. It moves tv_inp from something like WelfordResult::var_sum to WelfordResult::avg so that tv_inp is consistent when you hit:

const auto& pred_info = at(tv_inp);

If tv_inp is the result of a multi-output expression, the same pred_info comes up for all those siblings.

Owner Author (@csarofeen, Jul 24, 2022): I'm going to update this logic, but once we clean up predicate handling based on the ID graph we can remove this type of logic.

if (auto wop = dynamic_cast<WelfordOp*>(tv_def)) {
tv_inp = wop->out()->as<TensorView>();
}
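A rough, self-contained sketch (not nvFuser code; all names here are hypothetical) of the behavior discussed in the thread above: if every sibling output of a multi-output expression is normalized to one representative value before the predicate-map lookup, all siblings read the same predicate entry, which is what redirecting tv_inp to the op's out() achieves.

// Conceptual sketch only; Val/Expr are hypothetical placeholders.
#include <cassert>
#include <string>
#include <unordered_map>

struct Expr;

struct Val {
  Expr* definition = nullptr;
};

struct Expr {
  Val* canonical_out = nullptr; // e.g. the avg output of a Welford-like op
};

// Normalize any sibling output to the representative output of its definition.
const Val* representative(const Val* v) {
  if (v->definition != nullptr && v->definition->canonical_out != nullptr) {
    return v->definition->canonical_out;
  }
  return v;
}

int main() {
  Expr welford;
  Val avg, var_sum, n;
  avg.definition = var_sum.definition = n.definition = &welford;
  welford.canonical_out = &avg;

  std::unordered_map<const Val*, std::string> pred_map;
  pred_map[representative(&avg)] = "one predicate shared by all siblings";

  // Looking up any sibling resolves to the same map entry.
  assert(pred_map.at(representative(&var_sum)) ==
         pred_map.at(representative(&n)));
  return 0;
}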
41 changes: 29 additions & 12 deletions torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp
@@ -23,29 +23,46 @@ bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {
TORCH_INTERNAL_ASSERT(
root_id->definition() == nullptr, "Not root IterDomain: ", root_id);

if (tv->definition() == nullptr) {
auto def = tv->definition();

if (def == nullptr) {
// This is an input tensor, so no rfactor tensor to traverse.
return false;
}

const auto& inputs = tv->definition()->inputs();

// Check the reduction expression that produces tv
if (inputs.size() != 1 || !inputs[0]->isA<TensorView>() ||
(tv->definition()->getExprType() != ExprType::ReductionOp &&
tv->definition()->getExprType() != ExprType::WelfordOp)) {
// No rfactor producer found
if (!ir_utils::isReductionOp(def)) {
return false;
}

auto producer = inputs[0]->as<TensorView>();
// Find the corresponding input TV. Note that the reduction expr may
// have multiple inputs.
auto producer = def->inputs().at(std::distance(
def->outputs().begin(),
std::find(def->outputs().begin(), def->outputs().end(), tv)));

auto producer_tv = dynamic_cast<TensorView*>(producer);

// WelfordOp may have an Int input. Traverse to the avg input
if (def->isA<WelfordOp>() && producer_tv == nullptr) {
Owner Author: Why do we need to grab the "right" producer? Can't we just take the first TV input? They should have to be aligned to be siblings.

Collaborator: That should be fine with WelfordOp, but in GroupedReductionOp, in theory, the input tensors just need to have the same shape. It should be fine for some of them to have rfactor domains, although the current validation may not be flexible enough to accept such a case. So, picking the right input could be important.

Owner Author: How would some reductions have rfactor and others not with grouped reduction? I assume you'd have to have some interesting view op in the DAG?

Owner Author: Will revisit this again.

Collaborator: groupReductions can group an arbitrary set of ReductionOp exprs as long as they have the same input shape. I don't know if it could ever happen in practice, but it is in theory possible to group a reduction of a post-view tensor and another reduction of a tensor that has the same shape as the post-view tensor.

TORCH_INTERNAL_ASSERT(
producer == def->as<WelfordOp>()->inVar() ||
producer == def->as<WelfordOp>()->inN(),
"Invalid expr: ",
def->toString(),
", out TV: ",
tv->toString());
producer_tv = def->as<WelfordOp>()->inAvg()->as<TensorView>();
}

TORCH_INTERNAL_ASSERT(producer_tv != nullptr);

if (!producer->hasRFactor()) {
if (!producer_tv->hasRFactor()) {
return false;
}

auto c2p = PairwiseRootDomainMap(producer, tv)
.mapConsumerToProducer(tv->domain(), producer->domain());
auto c2p = PairwiseRootDomainMap(producer_tv, tv)
.mapConsumerToProducer(tv->domain(), producer_tv->domain());

auto producer_id_it = c2p.find(root_id);
if (producer_id_it == c2p.end()) {
@@ -55,7 +72,7 @@ bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {

auto producer_root_id = producer_id_it->second;

return analyzeIfDerivedFromTrivialReduction(producer, producer_root_id);
return analyzeIfDerivedFromTrivialReduction(producer_tv, producer_root_id);
}

bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) {
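A simplified sketch of the producer selection added to traverseToRFactorTensor above: the producer input is taken at the same position as the consumer output, and a Welford-like op whose var/N inputs may be scalars falls back to traversing through the avg input. ReductionExpr and Input are made-up placeholder types, not nvFuser IR.

// Sketch only: placeholder types, not nvFuser IR classes.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

struct TensorView {};

struct Input {
  TensorView* tv = nullptr; // non-null if this input is a TensorView
};

struct ReductionExpr {
  std::vector<TensorView*> outputs; // e.g. {out_avg, out_var, out_N}
  std::vector<Input> inputs;        // positionally aligned with outputs
  bool is_welford = false;
  std::size_t avg_input_index = 0;
};

TensorView* producerTvOf(const ReductionExpr& def, TensorView* out_tv) {
  // Pick the input sitting at the same position as the output we came from.
  auto pos = std::distance(
      def.outputs.begin(),
      std::find(def.outputs.begin(), def.outputs.end(), out_tv));
  TensorView* producer_tv = def.inputs.at(pos).tv;
  // The Welford var/N inputs may be scalars; traverse through avg instead.
  if (def.is_welford && producer_tv == nullptr) {
    producer_tv = def.inputs.at(def.avg_input_index).tv;
  }
  assert(producer_tv != nullptr);
  return producer_tv;
}

int main() {
  TensorView in_avg, out_avg, out_var, out_n;
  ReductionExpr welford;
  welford.outputs = {&out_avg, &out_var, &out_n};
  welford.inputs = {{&in_avg}, {nullptr}, {nullptr}}; // var/N inputs are scalars
  welford.is_welford = true;
  // Traversal from the N output still lands on the avg input tensor.
  assert(producerTvOf(welford, &out_n) == &in_avg);
  return 0;
}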
10 changes: 0 additions & 10 deletions torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -183,16 +183,6 @@ TensorView* getTvOutput(const Expr* expr) {
return nullptr;
}

bool isReductionOp(const Expr* expr) {
// Note that GridReduction inherits ReductionOp
return expr->isA<ReductionOp>() || expr->isA<GroupedReductionOp>() ||
expr->isA<WelfordOp>() || expr->isA<kir::GridWelford>();
}

bool isReductionTvOp(const Expr* expr) {
return isTvOp(expr) && isReductionOp(expr);
}

bool isScalarOp(const Expr* expr) {
for (auto out : expr->outputs())
if (!out->isScalar())
6 changes: 0 additions & 6 deletions torch/csrc/jit/codegen/cuda/lower_utils.h
@@ -79,12 +79,6 @@ TORCH_CUDA_CU_API bool isTvOp(const Expr*);
// Returns the first output of Expr that is a TensorView
TORCH_CUDA_CU_API TensorView* getTvOutput(const Expr*);

// Returns if Expr is a reduction op
TORCH_CUDA_CU_API bool isReductionOp(const Expr*);

// Returns if Expr is a reduction op with TensorView or TensorIndex
TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);

bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map);

//! Returns the iterdomain that maps to the thread dimension grouped
3 changes: 1 addition & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp
@@ -822,8 +822,7 @@ TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(

TORCH_INTERNAL_ASSERT(
red_expr->getExprType() != c10::nullopt &&
(red_expr->getExprType().value() == ExprType::ReductionOp ||
red_expr->getExprType().value() == ExprType::WelfordOp),
ir_utils::isReductionOp(red_expr),
"TensorView doesn't have a reduction.");

auto tv_inps = ir_utils::filterByType<TensorView>(fusion->inputs());
3 changes: 1 addition & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp
@@ -908,8 +908,7 @@ TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(

TORCH_INTERNAL_ASSERT(
red_expr->getExprType() != c10::nullopt &&
(red_expr->getExprType().value() == ExprType::ReductionOp ||
red_expr->getExprType().value() == ExprType::WelfordOp),
ir_utils::isReductionOp(red_expr),
"TensorView doesn't have a reduction.");

auto properties =
29 changes: 9 additions & 20 deletions torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
@@ -239,8 +239,8 @@ class SchedulerTopologyChecker {
static bool hasPostReductionBCast(Fusion* fusion) {
auto all_vals = fusion->usedMathVals();
for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
// Welford can have 2 outputs, so do this on all found reduction tensor
// views
// Reductions can have multiple outputs, so do this on all found reduction
// tensor views
if (tv->hasReduction() && !tv->isFusionInput()) {
auto tv_chains = tvChains(DependencyCheck::getAllUseChains(tv));
// Propagate forward from reduction through all uses of the reduction
@@ -301,18 +301,17 @@

// When checking post reduction vals, we need to make sure
// we are really checking paths starting from all outputs
// of multi-output reductions, i.e. welford. The reduction_tv
// vector is assumed to only have one of them.
// of multi-output reductions, i.e. welford/grouped reduction. The
// reduction_tv vector is assumed to only have one of them.
std::unordered_set<Val*> reduction_tv_set(
reduction_tvs.begin(), reduction_tvs.end());

for (auto red : reduction_tvs) {
if (red->definition()) {
if (auto wop = dynamic_cast<WelfordOp*>(red->definition())) {
for (auto wop_output : wop->outputs()) {
if (wop_output->isA<TensorView>()) {
reduction_tv_set.insert(wop_output);
}
if (ir_utils::isReductionOp(red->definition())) {
auto outs = red->definition()->outputs();
for (auto out_tv : ir_utils::filterByType<TensorView>(outs)) {
reduction_tv_set.insert(out_tv);
}
}
}
@@ -1000,9 +999,8 @@ class PointWiseScheduler : public SchedulerEntry {

auto reduction_ops =
ir_utils::getReductionOps(fusion, true /* ignore_trivial */);
auto welford_ops = ir_utils::filterByType<WelfordOp>(reduction_ops);

if (!reduction_ops.empty() || !welford_ops.empty()) {
if (!reduction_ops.empty()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::PointWise, "no support for reduction ops");
return false;
@@ -1065,15 +1063,6 @@ class PersistentKernelScheduler : public SchedulerEntry {

auto reduction_ops =
ir_utils::getReductionOps(fusion, false /* ignore_trivial */);
auto welford_ops = ir_utils::filterByType<WelfordOp>(reduction_ops);
// For persistent schedule we want welford translated to average and
// standard deviation reductions.
if (welford_ops.begin() != welford_ops.end()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Persistent,
"no support for un-translated welford");
return false;
}

auto view_tvs = scheduler_utils::getViewTVs(fusion);
if (view_tvs.size() > 0) {
1 change: 0 additions & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/utils.h
@@ -194,7 +194,6 @@ std::pair<bool, bool> canonicalDimReduction(

// Return a list of tensor views that are outputs of reduction operations. If
// multiple outputs of an expression are found, only include one in the list
// (WelfordOp)
TORCH_CUDA_CU_API std::vector<TensorView*> getReductionTvs(
Fusion* fusion,
bool ignore_trivial = true);
20 changes: 9 additions & 11 deletions torch/csrc/jit/codegen/cuda/tensor_view.cpp
@@ -735,11 +735,10 @@ TensorView* TensorView::multiOutputRfactorHelper(
!container()->isA<kir::Kernel>(),
"Function invalid for kernel container.");
// Hack:
// Semantically we should always keep the outputs of welfordOp scheduled
// the same but the user end cannot guarantee that.
// In order to guarantee that the rFactor is defined meaningfully the
// scheduling of the output TV that got the rfactor call is force replayed
// towards the other two
// Semantically we should always keep the outputs of multi reduction ops
// scheduled the same but the user end cannot guarantee that. In order to
// guarantee that the rFactor is defined meaningfully the scheduling of the
// output TV that got the rfactor call is force replayed towards the other two

if (!sameAs(tv)) {
auto root = tv->getRootDomain();
@@ -758,7 +757,7 @@
std::vector<IterDomain*> new_id;
for (auto id : domain()->domain()) {
TORCH_INTERNAL_ASSERT(
replay.getReplay().count(id), "Welford Replay Failed");
replay.getReplay().count(id), "Multi-output reduction replay failed");
new_id.push_back(replay.getReplay().at(id));
}

@@ -795,12 +794,11 @@ std::vector<TensorView*> TensorView::rFactor(
TORCH_CHECK(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
FusionGuard fg(fusion());
TORCH_CHECK(
definition() != nullptr &&
(definition()->getExprType() == ExprType::GroupedReductionOp ||
definition()->getExprType() == ExprType::WelfordOp),
"Error rfactoring welford ",
definition() != nullptr && ir_utils::isReductionOp(definition()),
"Error rfactoring multi-output reduction op ",
this,
" its definition is either a nullptr or not a GroupedReductionOp or a WelfordOp.");
" its definition is either a nullptr or not a GroupedReductionOp or a multi-output reduction op.");

TORCH_CHECK(
!domain()->hasRFactor(), "Cannot call rfactor on the same view twice.");
