Remove some welford specific logic. #1864
Changes from 2 commits
```diff
@@ -23,29 +23,46 @@ bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {
   TORCH_INTERNAL_ASSERT(
       root_id->definition() == nullptr, "Not root IterDomain: ", root_id);

-  if (tv->definition() == nullptr) {
+  auto def = tv->definition();
+
+  if (def == nullptr) {
     // This is an input tensor, so no rfactor tensor to traverse.
     return false;
   }

-  const auto& inputs = tv->definition()->inputs();
-
-  // Check the reduction expression that produces tv
-  if (inputs.size() != 1 || !inputs[0]->isA<TensorView>() ||
-      (tv->definition()->getExprType() != ExprType::ReductionOp &&
-       tv->definition()->getExprType() != ExprType::WelfordOp)) {
-    // No rfactor producer found
+  if (!ir_utils::isReductionOp(def)) {
     return false;
   }

-  auto producer = inputs[0]->as<TensorView>();
+  // Find the corresponding input TV. Note that the reduction expr may
+  // have multiple inputs.
+  auto producer = def->inputs().at(std::distance(
+      def->outputs().begin(),
+      std::find(def->outputs().begin(), def->outputs().end(), tv)));
+
+  auto producer_tv = dynamic_cast<TensorView*>(producer);
+
+  // WelfordOp may have an Int input. Traverse to the avg input
+  if (def->isA<WelfordOp>() && producer_tv == nullptr) {
+    TORCH_INTERNAL_ASSERT(
+        producer == def->as<WelfordOp>()->inVar() ||
+            producer == def->as<WelfordOp>()->inN(),
+        "Invalid expr: ",
+        def->toString(),
+        ", out TV: ",
+        tv->toString());
+    producer_tv = def->as<WelfordOp>()->inAvg()->as<TensorView>();
+  }
+
+  TORCH_INTERNAL_ASSERT(producer_tv != nullptr);

-  if (!producer->hasRFactor()) {
+  if (!producer_tv->hasRFactor()) {
     return false;
   }

-  auto c2p = PairwiseRootDomainMap(producer, tv)
-                 .mapConsumerToProducer(tv->domain(), producer->domain());
+  auto c2p = PairwiseRootDomainMap(producer_tv, tv)
+                 .mapConsumerToProducer(tv->domain(), producer_tv->domain());

   auto producer_id_it = c2p.find(root_id);
   if (producer_id_it == c2p.end()) {
@@ -55,7 +72,7 @@ bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {

   auto producer_root_id = producer_id_it->second;

-  return analyzeIfDerivedFromTrivialReduction(producer, producer_root_id);
+  return analyzeIfDerivedFromTrivialReduction(producer_tv, producer_root_id);
 }

 bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) {
```

Review discussion on the producer-selection logic:

- Why do we need to grab the "right" producer? Can't we just take the first TV input? They should have to be aligned to be siblings.
- That should be fine with
- How would some reductions have rfactor and others not with grouped reduction? I assume you'd have to have some interesting view op in the dag?
- Will revisit this again.
Review discussion on the Welford-specific predicate logic:

- Is this necessary for WelfordOp? The maps of ThreadPredicateMap have mappings for all outputs: https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp#L281-L285
- Commented out this part, and nothing seems to fail.
- Why are your comments not showing up inline in the files page? Strange.
- Comes from: https://github.com/csarofeen/pytorch/pull/561/files#diff-48ec14efa321f9f6f479de4d2c9e377c847067825513a7231d94200d8ea60efaR141-R149 It doesn't seem to be necessarily related to correctness, but just wanting one predicate for all outputs. It's just moving from something like WelfordResult::var_sum to WelfordResult::avg so that tv_inp is consistent when you hit: If tv_inp is the result of a multi-output expression, the same pred_info comes up for all those siblings.
- I'm going to update this logic, but once we clean up predicate handling based on the ID graph we can remove this type of logic.
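The normalization described in the thread (redirecting any sibling output to a single representative so every sibling yields the same predicate entry) can be sketched as follows. The `WelfordResult` and `canonicalSibling` names here are hypothetical illustrations of the idea, not the actual nvFuser API:

```cpp
#include <string>

// Hypothetical multi-output result: all three tensors are siblings
// produced by one WelfordOp.
struct WelfordResult {
  std::string avg;
  std::string var_sum;
  std::string n;
};

// Redirect any sibling to one representative (avg), so that a
// per-tensor predicate lookup returns the same entry for the whole
// sibling group, mirroring the var_sum -> avg move described above.
std::string canonicalSibling(const WelfordResult& w, const std::string& tv) {
  if (tv == w.avg || tv == w.var_sum || tv == w.n) {
    return w.avg;  // one predicate key for all outputs
  }
  return tv;  // not part of this Welford group; leave unchanged
}
```

Once predicate handling is keyed on an ID graph that already groups sibling outputs, this kind of manual redirection becomes unnecessary, which matches the author's plan to remove it.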