Apply the magic-zero protection to each indexed domain individually for predicate indexing #1846

Merged
naoyam merged 16 commits into devel from magic_zero_predicates
Jul 21, 2022

@naoyam naoyam (Collaborator) commented Jul 19, 2022

In non-predicate indexing, the protection is applied once per index
expression, because the expressions for the indexed domains are summed
into a single index. Only the loop index of the innermost loop is
protected by the magic zero.

In predicate indexing, by contrast, each indexed domain has its own
expression, so each indexed domain must be protected individually.

This means that we can't simply modify the initial loop index map, as is
done for non-predicate indexing: the map is shared by all indexed
domains, and only some of them may need to protect a given loop index.

Example: FusionInsertMagicZero1

Before:

    #pragma unroll
    for(nvfuser_index_t i29 = 0; i29 < 32; ++i29) {
      #pragma unroll
      for(nvfuser_index_t i30 = 0; i30 < 2; ++i30) {
        int64_t i75;
        i75 = ((i28 % (ceilDiv(T0.size[1], 2))) * 2) + (i30 + nvfuser_zero);
        if ((((((i28 / (ceilDiv(T0.size[1], 2))) * 32) + i29) < T0.size[0]) && (i75 < T0.size[1]))) {
          T2[((((i28 / (ceilDiv(T0.size[1], 2))) * 32) + i29) * T0.size[1]) + i75]
             = T1[(i29 * 2) + i30];
        }

After:

    #pragma unroll
    for(nvfuser_index_t i29 = 0; i29 < 32; ++i29) {
      #pragma unroll
      for(nvfuser_index_t i30 = 0; i30 < 2; ++i30) {
        int64_t i75;
        i75 = ((i28 % (ceilDiv(T0.size[1], 2))) * 2) + (i30 + nvfuser_zero);
        if ((((((i28 / (ceilDiv(T0.size[1], 2))) * 32) + (i29 + nvfuser_zero)) < T0.size[0]) && (i75 < T0.size[1]))) {
          T2[((((i28 / (ceilDiv(T0.size[1], 2))) * 32) + i29) * T0.size[1]) + i75]
             = T1[(i29 * 2) + i30];
        }

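Notice that, after the change, both conditions of the T2 predicate are protected by the magic zero: the first through `i29 + nvfuser_zero` and the second through `i75`, which already included `i30 + nvfuser_zero`.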
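The constraint described above (one shared loop index map, but protection decided per indexed domain) can be sketched in a few lines. This is only an illustration under stated assumptions: `LoopIndexMap`, `IndexedDomain`, and `protectForDomain` are hypothetical names, not nvFuser APIs, index expressions are modeled as plain strings, and every listed loop is assumed to be unrolled and to need protection.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins, not nvFuser types: a loop is identified by its
// index name, and the map gives the expression used for each loop index.
using LoopIndexMap = std::map<std::string, std::string>;

// One predicate condition per indexed domain. `loops` lists, outermost
// first, the loops whose indices this domain's expression depends on.
struct IndexedDomain {
  std::string name;
  std::vector<std::string> loops;
};

// Return a private copy of the shared map in which only the innermost loop
// used by `domain` is protected. The shared map itself is never modified,
// so other indexed domains are unaffected.
LoopIndexMap protectForDomain(
    const LoopIndexMap& shared,
    const IndexedDomain& domain) {
  LoopIndexMap local = shared;
  if (!domain.loops.empty()) {
    const std::string& innermost = domain.loops.back();
    assert(shared.count(innermost) == 1);
    local[innermost] = "(" + shared.at(innermost) + " + nvfuser_zero)";
  }
  return local;
}
```

With a shared map holding `i29` and `i30`, a domain that depends only on `i29` gets a copy in which `i29` is protected, a domain that depends only on `i30` gets a copy in which `i30` is protected, and the shared map stays unchanged for both.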

@naoyam naoyam changed the title [WIP] Apply the magic-zero protection to each indexed domain individually for predicate indexing Apply the magic-zero protection to each indexed domain individually for predicate indexing Jul 19, 2022
@naoyam naoyam requested a review from csarofeen July 19, 2022 15:29
@csarofeen csarofeen requested a review from shmsong July 19, 2022 20:28
@shmsong shmsong left a comment

Overall looks good.

Thanks for extending the insertion logic and cleaning up some of the existing lines too.

Just wanted to discuss some minor points before stamping.

}

// Clone expression after recursively replacing inputs
void handle(UnaryOp* uop) override {

Just curious whether ir_utils::replaceValInExpr could work here. I guess the main difference is whether we could just substitute the uses of the replaced vals instead of replicating the whole dependent graph.

I'm slightly nervous about having multiple dispatch passes for value replacement, but I'd be happy to help maintain it if it is necessary.

Collaborator Author

The new one effectively clones a given val while applying replacements, whereas replaceValInExpr just applies replacements. Maybe the new one should be renamed to indicate a given val is cloned. Cloning is required as we don't want to affect other predicate expressions.

Owner

Maybe leave a comment or issue as it would be nice to try and unify them (have ir_utils::replaceValInExpr optionally clone the DAG you're replacing).
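The distinction drawn in this thread, cloning a value while applying replacements versus replacing uses in place, can be illustrated with a minimal sketch. Everything here is hypothetical: `Node`, `leaf`, `binary`, `toString`, and `cloneWithReplacement` are a toy expression DAG, not nvFuser's IR, `ir_utils::replaceValInExpr`, or the pass added in this PR.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical miniature IR: a node is either a named leaf or a binary op
// over two shared sub-expressions.
struct Node;
using NodePtr = std::shared_ptr<const Node>;

struct Node {
  std::string op;   // empty for a leaf
  std::string name; // used only by leaves
  NodePtr lhs;
  NodePtr rhs;
};

NodePtr leaf(const std::string& name) {
  return std::make_shared<const Node>(Node{"", name, nullptr, nullptr});
}

NodePtr binary(const std::string& op, NodePtr lhs, NodePtr rhs) {
  return std::make_shared<const Node>(
      Node{op, "", std::move(lhs), std::move(rhs)});
}

std::string toString(const NodePtr& n) {
  if (n->op.empty()) {
    return n->name;
  }
  return "(" + toString(n->lhs) + " " + n->op + " " + toString(n->rhs) + ")";
}

// Rebuild `n` bottom-up, substituting leaves named in `replacements`.
// Operator nodes are always re-created and unreplaced leaves are reused.
// Nodes are immutable, so any other expression that shares sub-expressions
// with `n` keeps seeing the original, unprotected form.
NodePtr cloneWithReplacement(
    const NodePtr& n,
    const std::unordered_map<std::string, NodePtr>& replacements) {
  if (n->op.empty()) {
    auto it = replacements.find(n->name);
    return it == replacements.end() ? n : it->second;
  }
  return binary(
      n->op,
      cloneWithReplacement(n->lhs, replacements),
      cloneWithReplacement(n->rhs, replacements));
}
```

In this toy model, cloning `((i28 * 32) + i29)` with `i29` mapped to `(i29 + nvfuser_zero)` produces a new expression for one predicate, while the original expression still prints without the magic zero. An in-place replacement would instead change every expression that shares the `i29` use.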

return loop->isUnrolled() && (!ref_dom_simple || !ind_simple);
}

void protextNonPredicateIndexWithMagicZero(

Just double-checking that I understand correctly: is this code motion from lower_index.h?

I agree this is a good place for this function if that's the case.

Collaborator Author

That's right.

Owner

Should the name be protect instead of protext?

Owner

@csarofeen csarofeen left a comment

Minor comments.


bool ref_dom_simple =
(reference_domain == nullptr ? true
: reference_domain->definition() != nullptr);
Owner

Am I reading this wrong?
The reference domain is simple if it is a null pointer, or else if it has a definition?
Instead of a ternary operation here, couldn't it just be
reference_domain == nullptr || reference_domain->definition() != nullptr

Collaborator Author

Haha, I just copied your original code. I'll apply the cleanup.
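For reference, the equivalence behind the suggested cleanup can be checked in isolation. This is a minimal sketch with a hypothetical `Val` stand-in (not nvFuser's `Val` class); it only shows that `p == nullptr ? true : q` and `p == nullptr || q` agree, and that short-circuit evaluation keeps the null check ahead of the dereference.

```cpp
#include <cassert>

// Hypothetical stand-in for an IR value that may or may not have a
// definition; not nvFuser's Val.
struct Val {
  const void* def = nullptr;
  const void* definition() const {
    return def;
  }
};

// Form quoted in the review.
bool simpleTernary(const Val* v) {
  return v == nullptr ? true : v->definition() != nullptr;
}

// Suggested form: || short-circuits, so v->definition() is never evaluated
// when v is a null pointer.
bool simpleShortCircuit(const Val* v) {
  return v == nullptr || v->definition() != nullptr;
}

// Assert that both forms agree for a given pointer.
void checkEquivalent(const Val* v) {
  assert(simpleTernary(v) == simpleShortCircuit(v));
}
```

The three cases worth exercising are a null pointer, a value without a definition, and a value with one. The same rewrite applies to the `ind_simple` ternary discussed in the next comment.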

: reference_domain->definition() != nullptr);
bool ind_simple =
(ind == nullptr ? true
: ind->definition() != nullptr && !ind->isZeroInt());
Owner

Same here, uncertain why we need a ternary.

//!
//! This should be only used for non-predicate indexing.
//!
//! No protection is done if none of loop is determined to require
Owner

Nit: none of the loops

public:
//! Apply replacements to index as specified in
//! replacement_map. index is assumed to consist only from Int and
//! NamedScalar
Owner

Since it's pretty targeted should we make this OptIn instead of OptOut?

@shmsong shmsong left a comment

LGTM as well after the remaining minor fixes are addressed.

@naoyam naoyam merged commit 5cc6494 into devel Jul 21, 2022
@naoyam naoyam deleted the magic_zero_predicates branch July 21, 2022 01:55
jjsjann123 added a commit that referenced this pull request Aug 29, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. removes un-necessary sync from redundant thread compute analysis
  2. symmetric API for BestEffortReplay
  3. support merge on trivial reductions
  4. Ampere async copy improvements
- bug fixes:
  1. vectorization bug fixes
  2. type inference patch : fixes upstream pytorch#81725
  3. segmenter bug fix with deterministic iteration ordering
- parser update
  1. added leaky_relu
- scheduler
  1. normalization scheduler clean up.
  2. simplifies matmul scheduling with new transform propagator
  3. merge all dimensions in PW scheduler
  4. various gemm related improvements
- debuggability
  1. nsight compute support
  2. debug dump for InlinePropagator
  3. Add `UnaryOpType::Print`

Squashed commits to work around (WAR) the GitHub API.
Commits that are actually in this PR from the devel branch:

```
dfe02f3 Merge remote-tracking branch 'csarofeen/devel' into HEAD
1617373 Add `TensorViewBuilder::shape(std::vector<Val*> shape)` (#1884)
7cfb779 Merge pull request #1887 from csarofeen/upstream_merge_0803
3399f6d Merge remote-tracking branch 'origin/viable/strict' into HEAD
01208f5 Add `UnaryOpType::Print` which can be helpful for debugging (#1878)
0646522 Remove redundant TORCH_INTERNAL_ASSERT in lower_magic_zero.cpp (#1881)
7bc76aa Fix most inlined propagator for mismatched dims (#1875)
501f4aa Nonaffine swizzle formulation ep.2: Loop swizzle variant. (#1826)
d863d69 Ampere async copy ep.2: circular buffering extension to support pipelined matmul operand load (#1827)
e0ae11a Larger sized mma instructions to support full vectorization (#1824)
9bb4cf7 fragment iteration to support fully unrolled mma ops (#1823)
a48270a Merge all dims in pointwise scheduler (#1872)
172fb36 Make MostInlined and BestEffort inline propagation no longer assert replayed (#1868)
a64462a Allow trivial reduction to be merged (#1871)
440102b Symmetric API for BestEffortReplay (#1870)
d1caf33 Some misc cleanups/refactor split out from #1854 (#1867)
1013eda Remove some welford specific logic. (#1864)
51589d3 Some cleanups on tests and heuristics params (#1866)
a6b3e70 Segmenter bug fix, and deterministic iteration ordering.  (#1865)
1b665b9 Add nullptr checks to IrBuilder (#1861)
1cd9451 Simplify matmul scheduling with the new transform propagator.  (#1817)
bbc1fb9 Add leaky_relu operation (#1852)
e842a9b Minor cleanup in pointwise scheduler (#1858)
9ee850c Fix stringstream usage (#1857)
20a36c1 Improve nsight compute support (#1855)
4059103 Remove debugging `true ||` from getPointwiseHeuristics (#1822)
01117bf Misc cleanup (#1853)
5cc6494 Apply the magic-zero protection to each indexed domain individually for predicate indexing (#1846)
92e6f02 Cleanup normalization scheduler (#1845)
db89c65 Type inference patch (#1848)
102fe93 Add debug dump for InlinePropagator (#1847)
b7a4d93 Redundant thread compute analysis to avoid un-necessary sync insertion (#1687)
942be5b Upstream ci build fixes (#1842)
0b83645 Fix vectorization bug introduced in #1831 (#1840)
63630f1 Move MaxProducerPosUpdater into InlinePropagator::tearDown (#1825)
9135a96 Fix transpose benchmark dtype (#1839)
2c9a6c0 Add extra configurability to `parallelizeAllLike` (#1831)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38543000](https://our.internmc.facebook.com/intern/diff/D38543000)
Pull Request resolved: pytorch#83067
Approved by: https://github.com/davidberard98