Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove some welford specific logic. #1864

Merged
merged 4 commits into from
Jul 25, 2022
Merged

Remove some welford specific logic. #1864

merged 4 commits into from
Jul 25, 2022

Conversation

csarofeen
Copy link
Owner

Just trying to clean up some logic for general grouped grid support.

@csarofeen csarofeen requested a review from naoyam July 23, 2022 19:56
@@ -213,6 +213,7 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
// Change for welford Op, we want the users of all outputs of welfordOp
// to use a single predicate name.
if (auto tv_def = tv_inp->definition()) {
// TODO: Do we need to do anything for grouped reduction here?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary for WelfordOp? The maps of ThreadPredicateMap have mappings for all outputs: https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp#L281-L285

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commented out this part, and nothing seems to fail.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are your comments not showing up inline in the files page? Strange.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comes from:
https://github.com/csarofeen/pytorch/pull/561/files#diff-48ec14efa321f9f6f479de4d2c9e377c847067825513a7231d94200d8ea60efaR141-R149

It doesn't seem to be necessarily related to correctness, but just wanting one predicate for all outputs. It's just moving from something like WelfordResult::var_sum to be WelfordResult::avg so that tv_inp is consistent when you hit:

const auto& pred_info = at(tv_inp);

If tv_inp is the result of a multi output expression, the same pred_info comes up for all those siblings.

Copy link
Owner Author

@csarofeen csarofeen Jul 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to update this logic, but once we cleanup predicate handling based on ID graph we can remove this type of logic.

@@ -30,6 +30,7 @@ bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {

const auto& inputs = tv->definition()->inputs();

// Do we need to add trivial reduction support for grouped reductions?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think yes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the logic seems strange to me. We are checking inputs.size() != 1, so it should never accept WelfordOp. It seems this logic is broken for WelfordOp.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pushed a commit to fix it

auto producer_tv = dynamic_cast<TensorView*>(producer);

// WelfordOp may have an Int input. Traverse to the avg input
if (def->isA<WelfordOp>() && producer_tv == nullptr) {
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to grab the "right" producer? Can't we just take the first TV input? They should have to be aligned to be siblings.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That should be fine with WelfordOp, but in GroupedReductionOp, in theory, the input tensors just need to have the same shape. It should be fine for some of them to have rfactor domains, although the current validation may not be flexible enough to accept such a case. So, picking the right input could be important.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would some reductions have rfactor and others not with grouped reduction? I assume you'd have to have some interesting view op in the dag?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will revisit this again.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

groupReductions can group an arbitrary set of ReductionOp exprs as long as they have the same input shape. I don't know if it could ever happen in practice, but it is in theory possible to group a reduction of a post-view tensor and another reduction of a tensor that has the same shape as the post-view tensor.

@csarofeen csarofeen requested a review from naoyam July 25, 2022 00:51
Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@naoyam naoyam merged commit 1013eda into devel Jul 25, 2022
@zasdfgbnm zasdfgbnm mentioned this pull request Jul 26, 2022
4 tasks
jjsjann123 added a commit that referenced this pull request Aug 29, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. removes un-necessary sync from redundant thread compute analysis
  2. symmetric API for BestEffortReplay
  3. support merge on trivial reductions
  4. Ampere async copy improvements
- bug fixes:
  1. vectorization bug fixes
  2. type inference patch : fixes upstream pytorch#81725
  3. segmenter bug fix with deterministic iteration ordering
- parser update
  1. added leaky_relu
- scheduler
  1. normalization scheduler clean up.
  2. simplifies matmul scheduling with new transform propagator
  3. merge all dimensions in PW scheduler
  4. various gemm related improvements
- debuggability
  1. nsight compute support
  2. debug dump for InlinePropagator
  3. Add `UnaryOpType::Print`

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
dfe02f3 Merge remote-tracking branch 'csarofeen/devel' into HEAD
1617373 Add `TensorViewBuilder::shape(std::vector<Val*> shape)` (#1884)
7cfb779 Merge pull request #1887 from csarofeen/upstream_merge_0803
3399f6d Merge remote-tracking branch 'origin/viable/strict' into HEAD
01208f5 Add `UnaryOpType::Print` which can be helpful for debugging (#1878)
0646522 Remove redundant TORCH_INTERNAL_ASSERT in lower_magic_zero.cpp (#1881)
7bc76aa Fix most inlined propagator for mismatched dims (#1875)
501f4aa Nonaffine swizzle formulation ep.2: Loop swizzle variant. (#1826)
d863d69 Ampere async copy ep.2: circular buffering extension to support pipelined matmul operand load (#1827)
e0ae11a Larger sized mma instructions to support full vectorization (#1824)
9bb4cf7 fragment iteration to support fully unrolled mma ops (#1823)
a48270a Merge all dims in pointwise scheduler (#1872)
172fb36 Make MostInlined and BestEffort inline propagation no longer assert replayed (#1868)
a64462a Allow trivial reduction to be merged (#1871)
440102b Symmetric API for BestEffortReplay (#1870)
d1caf33 Some misc cleanups/refactor split out from #1854 (#1867)
1013eda Remove some welford specific logic. (#1864)
51589d3 Some cleanups on tests and heuristics params (#1866)
a6b3e70 Segmenter bug fix, and deterministic iteration ordering.  (#1865)
1b665b9 Add nullptr checks to IrBuilder (#1861)
1cd9451 Simplify matmul scheduling with the new transform propagator.  (#1817)
bbc1fb9 Add leaky_relu operation (#1852)
e842a9b Minor cleanup in pointwise scheduler (#1858)
9ee850c Fix stringstream usage (#1857)
20a36c1 Improve nsight compute support (#1855)
4059103 Remove debugging `true ||` from getPointwiseHeuristics (#1822)
01117bf Misc cleanup (#1853)
5cc6494 Apply the magic-zero protection to each indexed domain individually for predicate indexing (#1846)
92e6f02 Cleanup normalization scheduler (#1845)
db89c65 Type inference patch (#1848)
102fe93 Add debug dump for InlinePropagator (#1847)
b7a4d93 Redundant thread compute analysis to avoid un-necessary sync insertion (#1687)
942be5b Upstream ci build fixes (#1842)
0b83645 Fix vectorization bug introduced in #1831 (#1840)
63630f1 Move MaxProducerPosUpdater into InlinePropagator::tearDown (#1825)
9135a96 Fix transpose benchmark dtype (#1839)
2c9a6c0 Add extra configurability to `parallelizeAllLike` (#1831)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38543000](https://our.internmc.facebook.com/intern/diff/D38543000)
Pull Request resolved: pytorch#83067
Approved by: https://github.com/davidberard98
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants