
Conversation

liqiangxl
Collaborator

Fixes csarofeen/pytorch#2125
Before this fix, bandwidth was 340~370 GB/s on A100-80G; after this fix, it increased to 1.19e+03 GB/s on the same GPU.

Collaborator

@csarofeen csarofeen left a comment

I'm hesitant to take this change in. We need to think a bit more carefully about this optimization.

ParallelTypeBitmap limited_types;
// Parallel types where only one thread/block is enough.
ParallelTypeBitmap redundant_types;
// Map stores parallel types where a fraction of thread/block is enough
Collaborator

Could you try to make this comment a bit more clear. I don't understand what the key-val pairs mean by this comment.

// {I1*I2}, both are parallelized by blockIdx.x since {I1*1} is merged from
// broadcasted domain, the write_stride should be {I1*I2} /
// {I1*1} = {I2} this means gmem write is only done where blockIdx.x % {I2} ==
// 0
Collaborator

I think I understand what you're doing, but it seems dangerous to me. I think this is fine in your example in the test, but I don't think this would be fine in more complex situations of broadcast concretization with arbitrary splits/merges. This seems hard coded for patterns of merges then splits. Which could be fine, but we'd need a better check to ensure that's the pattern being experienced. It's a common pattern for sure, but not the only one.

Collaborator Author

You are right. This PR only fixes patterns where a domain is isParallelTypeThread(pt) && merged_from_broadcast. It won't touch other broadcast concretizations with arbitrary splits/merges; actually, the thread-predicate approach can't handle these general cases. To handle them, we would need to apply the write_stride to the index of the tensor, which I looked into and it seems to be more complex than this thread-predicate approach. I can revisit that more general approach.

Collaborator

If you can explain to me how it only gets applied in this more isolated instance (likely I just forgot the code paths), then I'm happy to proceed and merge this PR. I just don't understand how the codebase avoids applying this rule in other situations.

Collaborator Author

Only when a domain is isParallelTypeThread(pt) && merged_from_broadcast will it be added to the map, e.g. iblockIdx.x103{( i0 * ( 1 * 1 ) )}, which is parallelized by blockIdx.x and merged from broadcast domains. It won't add other domains to the map, e.g. iS112{( i0 * ( 1 * 1 ) )} because it is not parallelized by threads/blocks, or iblockIdx.x113{( ceilDiv(( i0 * ( 1 * 1 ) ), blockDim.y) )} because it is not a merge from broadcast domains.
See the following code:

        auto domain = merge->out();
        auto pt = domain->getParallelType();
        bool from_broadcast =
            merge->outer()->isBroadcast() || merge->inner()->isBroadcast();
        if (isParallelTypeThread(pt) && from_broadcast) {
          merged_broadcast_domains.emplace_back(std::make_pair(out_tv, domain));
          tensor_with_domain_merged_from_broadcast = true;
          break;
        }

Collaborator

@csarofeen csarofeen left a comment

Sorry, forgot to push this review.

auto domain = merge->out();
auto pt = domain->getParallelType();
bool from_broadcast =
merge->outer()->isBroadcast() || merge->inner()->isBroadcast();
Collaborator

You should check that it's a concretized domain. Likely we should even check that it's concretized and inlined.

auto pt = domain->getParallelType();
bool from_broadcast =
merge->outer()->isBroadcast() || merge->inner()->isBroadcast();
if (isParallelTypeThread(pt) && from_broadcast) {
Collaborator

It seems the check here starts at a leaf node and checks its definition; if that definition is a merge and one side of it is a concretized broadcast, then we want to modify its write predicate. This check seems to effectively do that, but that structure is not clearly written.

Collaborator

i.e. I don't understand why we need to process all expressions between rfactor and leaf domains, instead of just each definition of each leaf.

Collaborator

This also doesn't seem to work if we had something like:
i2 = merge(i0, b1)
i3 = merge(i2, b2)
i5 = merge(i3, i4)

Collaborator

These constraints are generally fine for now, but I want the code to be a bit more precise in exactly what it's WAR-ing, and that it is a perf WAR. Basically what you're trying to do is predicate the writes of a producer to only write on the first unique time a producer value is generated.

This is straightforward to do in such a simple case, but is challenging to implement more generally. For example, what if for some reason the tensor of concern is an intermediate tensor and not a terminating output? Then what if there's non-trivial communication going on with its consumer? We might not just be able to write once, as we might have complex producer-consumer synchronization expectations.

I want to make sure we limit the scope of this to as narrow as possible to make sure it's not accidentally triggered in a more complex situation.

Collaborator

I also wonder if this is worth doing on anything shared or local memory, or if it should be limited to global writes.

Collaborator Author

I revised the algorithm; it can capture the case you mentioned. Current algorithm:
(1) starting at an output tensor, if one of its leaf domains' definition is a merge and parallelized by thread/block, move to (2)
(2) check all the root domains; if a root domain is a broadcast, map it to its concretized domain (or a concretized domain in its exact map)
(3) loop over all leaf domains whose definition is a merge and parallelized by thread/block
---(3.1) find all root domains that are merged into this leaf domain.
---(3.2) set the stride used when the extents of the merged root domains index the leaf domain.

  // e.g. Root: [I1,B2,B3] -> Leaf: [I1*B2*B3], the merged_root_domains =
  // {B3,B2,I1}. root_stride = {1, len(B3), len(B3) * len(B2)}. For broadcast
  // root domain, use the extent of its concretized domain. 

---(3.3) set write stride

    // write_stride_mod will generate condition: index % pair(1) < pair(2)
    std::unordered_map<ParallelType, std::pair<Val*, Val*>> write_stride_mod;
    // write_stride_less will generate condition: index < write_stride_less
    std::unordered_map<ParallelType, Val*> write_stride_less;
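For illustration only (a minimal sketch, not code from the PR; plain integers stand in for the Val* extents, and the helper name is made up), the two maps above translate into a per-index write condition along these lines:

```cpp
#include <cstdint>

// A write is allowed for a given parallel index (e.g. blockIdx.x) only if the
// conditions recorded for its parallel type hold.
bool isWriteAllowed(int64_t index,
                    bool has_mod, int64_t mod_divisor, int64_t mod_limit,
                    bool has_less, int64_t less_limit) {
  bool allowed = true;
  if (has_mod) {
    // write_stride_mod: index % pair(1) < pair(2)
    allowed = allowed && (index % mod_divisor) < mod_limit;
  }
  if (has_less) {
    // write_stride_less: index < write_stride_less
    allowed = allowed && index < less_limit;
  }
  return allowed;
}

// e.g. [I1,B2,I3] merged to [I1*B2*I3] and parallelized by blockIdx.x:
// has_mod = true, mod_divisor = len(B2)*len(I3), mod_limit = len(I3).
```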

7 additional test cases are added:

  // Test case where [B1,I2,I3] is merged to [B1I2I3] and parallelized by blockIdx.x
  // The write pattern should be: write only the first len(I2)*len(I3) blocks,
  // where len(I) is the extent of the concretized domain.
  // write_stride_less= ( ( 1 * T0.size[1] ) * T0.size[0] )
  // condition = blockIdx.x < ( ( 1 * T0.size[1] ) * T0.size[0] )
  runTest({true, false, false, false});

  // Test case where [I1,B2,I3] is merged to [I1B2I3] and parallelized by blockIdx.x
  // The write pattern should be: len(I3) blocks write, in every len(B2)*len(I3)
  // blocks, write_stride_mod= {len(B2)*len(I3), len(I3)} condition = blockIdx.x
  // % (len(B2)*len(I3)) < len(I3)
  runTest({false, true, false, false});

  // Test case where [I1,I2,B3] is merged to [I1I2B3] and parallelized by blockIdx.x
  // The write pattern should be: write every len(B3) blocks,
  // where len(B) is the extent of the domain.
  // write_stride_mod= ( 1 * T1.size[2] )
  runTest({false, false, true, false});

  // Test case where [I1,B2,B3] is merged to [I1B2B3] and parallelized by blockIdx.x
  // The write pattern should be: write every len(B2)*len(B3) blocks,
  // where len(B) is the extent of the concretized domain.
  // write_stride_mod= ( ( 1 * T1.size[2] ) * T1.size[1] )
  runTest({false, true, true, false});

  // Test case where [B1,I2,B3] is merged to [B1I2B3] and parallelized by blockIdx.x
  // The write pattern should be: write every len(B3) blocks of the first
  // len(I2) * len(B3) blocks write_stride_less= ( ( 1 * T1.size[2] ) *
  // T0.size[0] ) write_stride_mod= ( 1 * T1.size[2] )
  runTest({true, false, true, false});

  // Test case where [B1,B2,I3] is merged to [B1B2I3] and parallelized by blockIdx.x
  // The write pattern should be: write only the first len(I3) blocks
  // write_stride_less= ( 1 * T0.size[0] )
  runTest({true, true, false, false});

  // Test case where [B1,B2,B3] is merged to [B1B2B3] and parallelized by blockIdx.x
  // The write pattern should be: write only the first block
  // write_stride_less= 1 (write when blockIdx.x < 1)
  runTest({true, true, true, false});

@liqiangxl liqiangxl force-pushed the llu/softmax branch 3 times, most recently from ca8522b to d2dea5c Compare April 5, 2023 13:26
@liqiangxl liqiangxl requested a review from csarofeen April 5, 2023 13:40
// merged to [B1*I2*B3] and parallelized by blockIdx.x. The write pattern
// should be: write every len(B3) blocks of the first len(I2) * len(B3)
// blocks. write_stride_less= ( ( 1 * T1.size[2] ) * T0.size[0] ).
// write_stride_mod= ( 1 * T1.size[2] ).
Collaborator

What's T1 and T0 here?

Collaborator Author

T0 is a tv with first 3 dims = [Broadcast_00, Iter_01, Broadcast_02], T1 is another tv with first 3 dims = [Iter_10, Iter_11, Iter_12], Broadcast_00 is concretized to Iter_10 and Broadcast_02 is concretized to Iter_12 (T1.size[2]). I should revise this comment without using T0 and T1.

Collaborator

@csarofeen csarofeen left a comment

Pushing some questions. I think there's only one thing that I'd consider blocker which is making sure the merges are ordered. I believe the rest of the logic is safe, but that's the one potential issue I could find.

I will note that this could be made a much more general pass but is difficult to do so today. What we seem to be looking for here is having an "inlined" (inlined here could also mean shared parallelization dimensions) resolved broadcast merged with a non-broadcast dimension. When we have these merges, what we really want to do is predicate away the indexing across the broadcast dimension so we don't write multiple times. Inlined here is really an interesting concept, where it could also include parallelizations that exact map from producer-consumers. i.e. for this pass it doesn't matter if the block parallelized dimension is inlined or if it simply maps from producer-consumer with this resolved broadcast. I think there's some interesting parallelization x memory location aspect to the applicability of a more general implementation of this optimization.

Today it would be cumbersome to implement a more generic version of this pass. However, it could become easier to implement with #32

// when a leaf domain is merged from concretized broadcast root domain, only
// part of thread/block do the write to gmem is enough e.g. [B1,I2,B3] is
// merged to [B1*I2*B3] and parallelized by blockIdx.x. The write pattern
// should be: write every len(B3) blocks of the first len(I2) * len(B3)
Collaborator

I assume by len(B3) you mean the size of the dimension the broadcast is resolved to, not 1?

Collaborator Author

yes!

// generated condition is: blockIdx.x < write_stride_less && blockIdx.x %
// write_stride_mod < 1
// Another example, [I1, B2, I3] merged to [I1*B2*I3], the condition is:
// blockIdx.x % (len(B2)*len(I3)) < len(I3).
Collaborator

Does the predicate generated depend on the order of merges?

Collaborator

I think this comment is indicating that we're thinking of cases where we have a merge with the outer being broadcast and inner not being broadcast. The point being if we have this broadcast axis getting peeled off (a broadcast axis that's resolved inlined) then we only need to write for the first pass of the inner side of that merge.

I think this only works if the iter domain doesn't have "zero merged in", so as long as this is currently limited to global memory instances then it should be valid but wouldn't necessarily be valid, for example, if it was done on shared memory.

Collaborator Author

The generated predication does not depend on the order of the merge; it only depends on the order of the root domains, e.g. [broadcast, iter] or [iter, broadcast]. The current fix is only for global memory. What's the meaning of "zero merged in"?

Collaborator

"Zero merged in" is currently only applicable to smem and local memory, but for multi-GPU it will get extended. It's indicative of the removal of allocation associated with locality across parallelization, i.e. if something is local memory it will remove the portion of the allocation bound to block and grid dims.

bool merged_parallelized_thread_block = false;
for (auto ld : out_tv->domain()->domain()) {
const ParallelType& pt = ld->getParallelType();
auto merge = dynamic_cast<Merge*>(ld->definition());
Collaborator

Does this have to be a recursive merge? i.e. all expressions between root domains and this leaf domain must all be merge?

Collaborator Author

Yes. This is the current assumption. If there is a split in the path, predicating away the indexing across the broadcast dimension becomes difficult.

Collaborator

But this isn't actually recursive right? Doesn't this need to grab the expressions of ld to domain, and make sure they're all merges?

Collaborator

Or are you just saying it's the final expression you want to improve the write predication on?

Collaborator Author

You are right. I need to make sure all are merges, so the following return will be called if a non-merge expr is detected.

        // current analysis of predication is only valid if all the exprs
        // between this leaf domain and root domains are merges
        return std::vector<IterDomain*>();

Collaborator

I meant sequential merges, so a merge operation going into another. Something along the lines of for(auto expr : StmtSort::getExprs(fusion, ld)){...

Collaborator

Sorry, it should actually be similar to: auto all_exp = DependencyCheck::getAllExprsBetween( {rfactor_domain.begin(), rfactor_domain.end()}, {ld}); as it should be bounded by the allocated domain.

Collaborator Author

(1) changed getRootDomain to getMaybeRFactorDomain();
(2) the code can process a merge operation going into another, e.g. [I,B,B] --> [IB, B] --> [IB*B]

return;
}

// backward iterator, to ensure the visit is from outer to inner dims
Collaborator

Is this needed because your merged_root_domain is reversed order from the root ordering?

Collaborator Author

yes.

Collaborator

Does it need to be written in reverse order?

Collaborator Author

Reverse order is convenient for the calculation of strides. Of course, we can also store it in unreversed order and visit it backward when calculating the stride.

Collaborator

I kind of like the latter approach, as it keeps what's stored in a 1:1 "zippable" order. Generally, changing the order of iterating on a data structure makes a lot of sense to me; flipping the order of a data structure relative to how it's generally been stored is a bit more "mentally complex" in my opinion and should be done only if storing as such improves cache locality/performance (i.e. when we change the layout of tensors to improve locality of tiles as they're accessed).

// outer_iter and inner_iter should be neighbors in root_domain
if (inner_iter != root_domain.end() &&
outer_iter != root_domain.end() &&
std::distance(outer_iter, inner_iter) != 1) {
Collaborator

This will now only work on a single root domain -> merge -> parallelization, no?

If you have I0, I1, I2 and merge(merge(I0, I1), I2) the outer of the second merge is not in the root domain because it's an intermediate.

I think you need to start with the root domain, but as you're processing the above case basically update the domain you're tracking so:
Root [I0, I1, I2]
intermediate [I0*I1, I2]
Final [I0*I1*I2]
Then on the intermediate you also know the distance between the outer and inner is 1.
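
A rough sketch of that tracking (illustrative types and names only, not the PR's code): replay each merge over a tracked domain list so that a merge's outer and inner are always neighbors in the list being checked.

```cpp
#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

// Domains are represented by name only for illustration.
struct MergeRec {
  std::string outer, inner, out;
};

std::vector<std::string> replayMerges(
    std::vector<std::string> tracked, const std::vector<MergeRec>& merges) {
  for (const auto& m : merges) {
    auto outer_it = std::find(tracked.begin(), tracked.end(), m.outer);
    auto inner_it = std::find(tracked.begin(), tracked.end(), m.inner);
    if (outer_it == tracked.end() || inner_it == tracked.end() ||
        std::distance(outer_it, inner_it) != 1) {
      break; // outer/inner are not neighbors in the tracked domain: stop
    }
    *outer_it = m.out;        // replace outer with the merge output
    tracked.erase(inner_it);  // drop inner
  }
  return tracked;
}

// Root [I0, I1, I2] with merge(I0, I1) then merge(I0*I1, I2):
// tracked goes [I0, I1, I2] -> [I0*I1, I2] -> [I0*I1*I2].
```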

Collaborator Author

The if-statement is trying to stop applying this thread/block predication if there is a merge where the merged root domains are not neighbors, e.g. [I1, B2, B3] with merge (I1*B3) will be rejected because I1 is a root domain, B3 is a root domain, and their distance is not 1.
It won't influence cases where one of the merged domains is not a root domain, e.g. in the case you mentioned, I0, I1, I2 and merge(merge(I0, I1), I2), the if-statement test is false, so the case will not be rejected and the predication will be added for this case.

Your suggested approach is more general and I implemented it in the revised version.

Collaborator

@csarofeen csarofeen left a comment

Could you clean up your uses of lambda functions with [&] except where necessary. I also think your lambda functions are getting a bit long, and it might be nice to take some out and put them in anonymous namespaces as a function.

const auto& root_domain = out_tv->getMaybeRFactorDomain();
// For each broadcast root domain, find a concretized domain from its
// exact mapped domain set.
auto getConcretizedBroadcastRootDomain = [&]() {
Collaborator

isn't the only thing you're passing in root_domain? [&] seems excessive here.

Collaborator Author

changed to functions.

std::unordered_set<IterDomain*> all_cids =
GpuLower::current()
->concretizedBroadcastDomains()
->allConcretizedDomains(mapped_rd);
Collaborator

I think you should assert that all the c_ids returned by this are in the same group in the IdMappingMode::EXACT graph.

Collaborator Author

added check.

Collaborator

@naoyam naoyam May 10, 2023

@csarofeen Not sure why this assertion should hold. A broadcast ID can be concretized to different non-broadcast IDs and they don't need to be exactly mapped. They should be only permissively mapped.

Collaborator

@naoyam naoyam May 10, 2023

Also, if we really want to check this, shouldn't we look at all IDs in mapped_rd_exact_set instead of just one of them? Should we have the break at line 588?

Collaborator Author

For a mapped_rd in mapped_rd_exact_set, we take its allConcretizedDomains and store them in all_cids.
The check verifies that all_cids are fromSameExactGroup.
The break at line 588 breaks from the loop over the other mapped_rd in mapped_rd_exact_set, as we have already found a concretized domain, which can be used to represent the size of the broadcast root domain.

Collaborator Author

If this happens, the fusion will be segmented by our heuristics, right? I added a test case FusionAvoidRedundantWriteDifferentConcretizedDomains_CUDA and tested pointwise and reduction, and the fusion is segmented in both cases. Just in case, I removed the break and added an additional check:

            // make sure this broadcast root domain is concretized to the same domain
            // otherwise, treat as not concretized because we don't know which one to use.
            auto iter = concretized_broadcast_root_domains_.find(rd);
            if (iter != concretized_broadcast_root_domains_.end() && iter->second != *all_cids.begin()) {
              concretized_broadcast_root_domains_.erase(iter);
              // break to move to the next broadcast root domain
              break;
            } else {
              concretized_broadcast_root_domains_[rd] = *all_cids.begin();
            }

Collaborator

This is a part of fusion lowering, which is designed to be independent from how the schedulers and segmenter work, so we should not assume how they are segmented. In fact, we could just define a fusion like the example I posted above and pass it to GpuLower, and it should still work or at least error out if it cannot be handled without segmentation.

Collaborator

Also, I still don't understand why concretization matters. As shown here, #11 (comment), the issue happens even with non-concretized broadcast domains.

Collaborator Author

This is a part of fusion lowering, which is designed to be independent from how the schedulers and segmenter work, so we should not assume how they are segmented. In fact, we could just define a fusion like the example I posted above and pass it to GpuLower, and it should still work or at least error out if it cannot be handled without segmentation.

Added to test case FusionAvoidRedundantWriteDifferentConcretizedDomains_CUDA, where the non-segmented fusion is directly passed to GpuLower.

Collaborator Author

Also, I still don't understand why concretization matters. As shown here, #11 (comment), the issue happens even with non-concretized broadcast domains.

Marked as a case that can't be handled by this PR.

"Couldn't find ",
merge->inner());
int dist = std::distance(inner_iter_im, outer_iter_im);
if (std::abs(dist) != 1) {
Collaborator

I'm skeptical that changing the position of merges would work as it would change the indexing.
T1[I0, I1]->merge(0, 1)
is not indexed the same as
T1[I0, I1]->merge(1, 0)

Collaborator

Does line 633 indicate that the predicate is invariant to the order of the root domains going into the merge?

Collaborator Author

The order of the root domains going into the merge doesn't influence the indexing. Here, we are indexing between multi-dimensional neighboring root domains (indices) and one leaf domain (linear_index), e.g. [I0, I1] --> [I0*I1]: linear_index = indices[0]*stride[0] + indices[1]*stride[1]; stride = [len(I1), 1]. The range of indices[i] is from 0 to len(i)-1.
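
As a sketch of that relationship (not code from the PR; extents are plain integers and the function name is made up):

```cpp
#include <vector>

// Forward indexing for a merge such as [I0, I1] -> [I0*I1]:
// linear_index = indices[0]*stride[0] + indices[1]*stride[1], stride = {len(I1), 1}.
long linearIndex(const std::vector<long>& indices,
                 const std::vector<long>& extents) {
  long stride = 1;
  long linear_index = 0;
  for (int i = (int)extents.size() - 1; i >= 0; i--) {
    linear_index += indices[i] * stride;
    stride *= extents[i];
  }
  return linear_index;
}

// e.g. extents = {len(I0), len(I1)} = {4, 3}, indices = {2, 1} -> linear_index = 7.
```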

}
if (inner_iter != root_domain.end()) {
merged_root_domains.emplace_back(*inner_iter);
index_root_domain.emplace_back(inner_iter - root_domain.begin());
Collaborator

should this and the instance below be std::distance(root_domain.begin(), inner_iter)?

Collaborator Author

corrected.

// Set stride when use extent of merged root domains to index leaf domain.
// e.g. Root: [I1,B2,B3] -> Leaf: [I1B2B3], the merged_root_domains =
// {B3,B2,I1}. root_stride = {1, len(B3), len(B3) * len(B2)}. For broadcast
// root domain, use the extent of its concretized domain.
Collaborator

@naoyam does a concretized domain mean it's inline concretized? I'm not sure if this PR is assuming that concretized domains it's processing are shared inlined domains of consumers.

Collaborator

No, concretization has nothing to do with inlining.


auto crd = merged_root_domains.at(idx);
if (crd->isBroadcast()) {
// for pattern [B1,X], only first len(X) blocks needs to write, len(X) =
// root_stride[idx of crd]
Collaborator

I guess this is what surprises me about my comment on line 598

Collaborator Author

I changed the algorithm. The previous one was difficult to understand and extend.

// backward iterator, to ensure the visit is from outer to inner dims
int idx = ndim - 1;
while (idx >= 0) {
auto crd = merged_root_domains.at(idx);
Collaborator

Please don't use three letter acronyms, previous_id or current_id would be fine, cur_r_id also would be okay, but crd is a bit too short of a variable name for my liking. Code should read like normal text, and acronyms should be used sparingly unless they're extensive/standard in the codebase i.e. id is used frequently as IterDomain so that's generally fine as it's so common one should learn the convention to read the codebase.

Collaborator Author

Makes sense.

idx--;
} else if (crd->isIteration()) {
// for pattern [I1,B2,I3], write len(I3) blocks for every len[B2] *
// len[I3] blocks
Collaborator

This is confusing to me, can you expand the comments in this section a bit?

Collaborator Author

I changed the algorithm.

// e.g. [I2,I4] -> [B1,I2,B3,I4,B5] -> [B1*I2*B3*I4*B5]_blockIdx.x where Bx are concretized with extents larger than 1.
// In this case, the leaf domain {B1*I2*B3*I4*B5} is written redundantly because len(B1)*len(I2)*len(B3)*len(I4)*len(B5) blocks are writing to len(I2)*len(I4) locations.
// This class will set a map from ParallelType to the index that needs to write.
// The method is: (1) Thinking the leaf domain as a 5-D array with extended_stride:
// extended_stride = [len(I2)*len(B3)*len(I4)*len(B5), len(B3)*len(I4)*len(B5), len(I4)*len(B5), len(B5), 1]
// (2) calculate the index of each dimension: linear_index = blockIdx.x; index[i] = (linear_index / extended_stride[i]); linear_index %= extended_stride[i];
// (3) only the non-broadcasted dimensions need to write, dim-1(I2) and dim-3(I4). So the write index is:
// write_index = index[1] * extended_stride[1] + index[3] * extended_stride[3] 
// Finally, the write condition is: if (blockIdx.x == write_index)
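
A minimal sketch of the write-index computation described in the comment above (not the PR's code; extents are plain integers and the function name is made up):

```cpp
#include <vector>

// Steps (1)-(3) from the comment above: treat the merged leaf domain as an
// N-D array, recover per-dimension indices from the linear index (blockIdx.x),
// and rebuild the index using only the non-broadcast dimensions.
long computeWriteIndex(long linear_index,                 // e.g. blockIdx.x
                       const std::vector<long>& extents,  // concretized extents
                       const std::vector<bool>& is_broadcast) {
  const int ndim = (int)extents.size();
  // (1) extended strides of the N-D view
  std::vector<long> stride(ndim, 1);
  for (int i = ndim - 2; i >= 0; i--) {
    stride[i] = stride[i + 1] * extents[i + 1];
  }
  // (2) per-dimension indices
  std::vector<long> index(ndim);
  for (int i = 0; i < ndim; i++) {
    index[i] = linear_index / stride[i];
    linear_index %= stride[i];
  }
  // (3) keep only the non-broadcast dimensions
  long write_index = 0;
  for (int i = 0; i < ndim; i++) {
    if (!is_broadcast[i]) {
      write_index += index[i] * stride[i];
    }
  }
  return write_index;
}

// Write condition for the example [B1,I2,B3,I4,B5]:
//   if (blockIdx.x == computeWriteIndex(blockIdx.x, extents,
//                                       {true, false, true, false, true})) { ... }
```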

Collaborator Author

I forgot to push my local change last Monday; updated now.

@liqiangxl liqiangxl requested a review from csarofeen May 1, 2023 14:23
@liqiangxl liqiangxl mentioned this pull request May 8, 2023
@liqiangxl
Collaborator Author

!build

@liqiangxl liqiangxl requested a review from naoyam May 30, 2023 17:26
Comment on lines 647 to 656
// get the index of the leaf domain if we skip the broadcasted root domains
Val* index_without_broadcast = IrBuilder::create<Int>(0);
for (int i = 0; i < ndim; i++) {
if (!merged_root_domains.at(i)->isBroadcast()) {
index_without_broadcast = IrBuilder::addExpr(
index_without_broadcast,
IrBuilder::mulExpr(root_indices.at(i), root_stride.at(i)));
}
}
return index_without_broadcast;
Collaborator

Instead of generating the index that corresponds to the leaf ID, it seems it's more intuitive to just use the index of the broadcast root ID and create a predicate as index_of_the_broadcast_root_id == 0

Collaborator Author

Sounds like if we have multiple merged broadcast root IDs, then the condition will be index_of_the_broadcast_root_id0 == 0 && index_of_the_broadcast_root_id1 == 0 && .... Is this preferred?

Collaborator

The logic of generating the predicate would be more straightforward. The generated predicate may end up having multiple conditions, and it appears to me easier to reason about.
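
A rough sketch of what that alternative could look like (not the PR's code; IrBuilder::eqExpr, IrBuilder::andExpr, and the Bool scalar are assumed helpers here, while merged_root_domains, root_indices, and ndim follow the snippet quoted above):

```cpp
// AND together "index == 0" for every merged broadcast root ID.
Val* write_pred = IrBuilder::create<Bool>(true);
for (int i = 0; i < ndim; i++) {
  if (merged_root_domains.at(i)->isBroadcast()) {
    write_pred = IrBuilder::andExpr(        // assumed helper
        write_pred,
        IrBuilder::eqExpr(                  // assumed helper
            root_indices.at(i), IrBuilder::create<Int>(0)));
  }
}
return write_pred;
```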

Collaborator Author

Sure, I'll change to that.

@liqiangxl liqiangxl requested a review from naoyam May 31, 2023 16:36
Collaborator

@naoyam naoyam left a comment

LGTM. Please confirm again the perf regression is fixed as the PR changed a lot.

continue;
}
auto concrete_root_id = *it;
concretized_broadcast_root_domains_[rd] = concrete_root_id;
Collaborator

Just as a safety measure, assert that:

concretized_broadcast_root_domains_.emplace(rd, concrete_root_id).second == true

@liqiangxl
Collaborator Author

!build

@liqiangxl liqiangxl dismissed csarofeen’s stale review June 3, 2023 13:07

Naoya did the review.

@liqiangxl liqiangxl merged commit 87030e5 into main Jun 3, 2023
@liqiangxl liqiangxl deleted the llu/softmax branch June 3, 2023 13:10
cowanmeg pushed a commit to cowanmeg/Fuser that referenced this pull request Jan 31, 2024
…_on_producer_in_isLowerableToCommunication

check that reduced axis is sharded on producer in isLowerableToCommunication
jacobhinkle added a commit that referenced this pull request Mar 22, 2024
This introduces a thread-local global memory allocator for each device
and uses it whenever there is an intermediate tensor needed which
requires zero-initialization.

To enable use `NVFUSER_ENABLE=reuse_zeroed_memory`. You can monitor the
allocator using `NVFUSER_DUMP=global_zeroed_memory`.

Before we enable this feature by default, we need to have high
confidence that every kernel using zero-initialized memory will always
clean up their semaphores. This is currently only the case for serial
grid reductions, as far as I know.

This enables the basic functionality of #1829. However, it does not
modify existing algorithms to clean up their memory. See
`NVFUSER_ENABLE=reuse_zeroed_memory NVFUSER_DUMP=global_zeroed_memory
build/nvfuser_tests --gtest_filter=SerialGridReductionTest.Scheduling`,
which succeeds when using serial grid reduction, but fails (in debug
mode) when using `gridReduce` (note that this test is updated to behave
differently in this PR):
```
# NVFUSER_ENABLE=reuse_zeroed_memory NVFUSER_DUMP=global_zeroed_memory build/nvfuser_tests --gtest_filter=SerialGridReductionTest.Scheduling                                                       
Running main() from /opt/pytorch/nvfuser/third_party/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = SerialGridReductionTest.Scheduling
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from SerialGridReductionTest
[ RUN      ] SerialGridReductionTest.Scheduling
[global zeroed memory] Resizing arena to 512 bytes
[global zeroed memory] Allocating byte range: 0 to 512 bytes
[global zeroed memory] Resetting allocated bytes to 0
[global zeroed memory] Allocating byte range: 0 to 512 bytes
[global zeroed memory] Resetting allocated bytes to 0
[global zeroed memory] Resizing arena to 16384 bytes
[global zeroed memory] Allocating byte range: 0 to 16384 bytes
[global zeroed memory] Resetting allocated bytes to 0
[global zeroed memory] Allocating byte range: 0 to 16384 bytes
unknown file: Failure
C++ exception with description "nnz.equal(0) INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/global_allocator.cpp":88, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Global memory arena was not properly zeroed. Found 2048 bytes that are not zero
Exception raised from checkZeroed at /opt/pytorch/nvfuser/csrc/global_allocator.cpp:88 (most recent call first):
frame #0: <unknown function> + 0x2fde9e (0x556cdb95de9e in build/nvfuser_tests)
frame #1: <unknown function> + 0x2fe0df (0x556cdb95e0df in build/nvfuser_tests)
frame #2: <unknown function> + 0x3f3720 (0x556cdba53720 in build/nvfuser_tests)
frame #3: <unknown function> + 0x3f33df (0x556cdba533df in build/nvfuser_tests)
frame #4: <unknown function> + 0x3f38ed (0x556cdba538ed in build/nvfuser_tests)
frame #5: <unknown function> + 0x315e67 (0x556cdb975e67 in build/nvfuser_tests)
frame #6: <unknown function> + 0x7c5780 (0x556cdbe25780 in build/nvfuser_tests)
frame #7: <unknown function> + 0x7c5877 (0x556cdbe25877 in build/nvfuser_tests)
frame #8: <unknown function> + 0x138f8cc (0x556cdc9ef8cc in build/nvfuser_tests)
frame #9: <unknown function> + 0x1457f0b (0x556cdcab7f0b in build/nvfuser_tests)
frame #10: <unknown function> + 0x14519fd (0x556cdcab19fd in build/nvfuser_tests)
frame #11: <unknown function> + 0x142de24 (0x556cdca8de24 in build/nvfuser_tests)
frame #12: <unknown function> + 0x142e93f (0x556cdca8e93f in build/nvfuser_tests)
frame #13: <unknown function> + 0x142f345 (0x556cdca8f345 in build/nvfuser_tests)
frame #14: <unknown function> + 0x143f86c (0x556cdca9f86c in build/nvfuser_tests)
frame #15: <unknown function> + 0x1458e98 (0x556cdcab8e98 in build/nvfuser_tests)
frame #16: <unknown function> + 0x1452ac7 (0x556cdcab2ac7 in build/nvfuser_tests)
frame #17: <unknown function> + 0x143de6d (0x556cdca9de6d in build/nvfuser_tests)
frame #18: <unknown function> + 0x1407ca0 (0x556cdca67ca0 in build/nvfuser_tests)
frame #19: <unknown function> + 0x1407c19 (0x556cdca67c19 in build/nvfuser_tests)
frame #20: <unknown function> + 0x29d90 (0x7f616c7d4d90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #21: __libc_start_main + 0x80 (0x7f616c7d4e40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #22: <unknown function> + 0x11e9d5 (0x556cdb77e9d5 in build/nvfuser_tests)
" thrown in the test body.

To reproduce: NVFUSER_TEST_RANDOM_SEED=1711120799 NVFUSER_TEST_ATEN_RANDOM_SEED=0 nvfuser_tests --gtest_filter='SerialGridReductionTest.Scheduling'
[  FAILED  ] SerialGridReductionTest.Scheduling (5669 ms)
[----------] 1 test from SerialGridReductionTest (5669 ms total)
```
This test runs with serial grid reduction, then with `gridReduce`. Each
time it runs two grid reductions. Both serial grid reductions succeed
because the semaphore buffer is properly zeroed. The `gridReduce`
succeeds the first time since the memory pool calls `at::zeros` again to
request a larger buffer size (`gridReduce` requires more semaphores
since there is one per thread segment vs one for each block
segment). However, the second call to `gridReduce` fails because it has
not cleaned up its semaphores. Hacking that function to force
`PERSISTENT=1` would clean up the semaphores resulting in success in
this case. I'm leaving those kinds of modifications for a follow-up.
zasdfgbnm added a commit that referenced this pull request Feb 27, 2025
Example error message:

```CUDA
[ RUN      ] TMemTest.AddKernelSameRegion
unknown file: Failure
C++ exception with description " INTERNAL ASSERT FAILED at "/home/gaoxiang/Fuser/csrc/runtime/compiled_kernel.cpp":169, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. 
// Codegen generated utilities

namespace tmem {
__device__ __inline__ void alloc(uint32_t in0, uint32_t in1) {
  asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;\n"::"r"(in0), "r"(in1));
}
__device__ __inline__ void relinquishAllocPermit() {
  asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;\n");
}
__device__ __inline__ void store(uint32_t in0, Array<float, 1, 1> in1) {
  asm volatile(
    "tcgen05.st.sync.aligned.32x32b.x1.b32 [%0], {%1};\n"
    :
    :"r"(in0),
     "f"(in1[0])
  );
}
__device__ __inline__ void waitStore() {
  asm volatile("tcgen05.wait::st.sync.aligned;\n");
}
__device__ __inline__ void load(Array<float, 1, 1>& out0, uint32_t in0) {
  asm(
    "tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];\n"
    :"=f"(out0[0])
    :"r"(in0)
  );
}
__device__ __inline__ void waitLoad() {
  asm volatile("tcgen05.wait::ld.sync.aligned;\n");
}
} // namespace tmem
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 1, 1> T0, Tensor<float, 1, 1> T4, Tensor<float, 1, 1> T9) {
  alignas(16) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (32 * ((nvfuser_index_t)blockIdx.x));
  bool b1;
  b1 = i0 < T0.logical_size[0LL];
  uint32_t* T10 = reinterpret_cast<uint32_t*>(array + smem_offset + 0);
  tmem::alloc((uint32_t)(toSmem(T10)), (uint32_t)(32));
  tmem::relinquishAllocPermit();
  __syncthreads();
  Array<float, 1, 1> T1;
  T1[0] = 0;
  if (b1) {
    T1[0]
       = T0[((T0.alloc_stride[0LL] * ((nvfuser_index_t)threadIdx.x)) + ((32 * T0.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))];
  }
  TMemTensor T2(T10[0], 0, (uint16_t)(0));
  tmem::store((uint32_t)(T2 + Array<uint16_t, 2, 1>{0, 0}), (*reinterpret_cast<Array<float, 1, 1>*>(&T1[0])));
  tmem::waitStore();
  Array<float, 1, 1> T3;
  tmem::load((*reinterpret_cast<Array<float, 1, 1>*>(&T3[0])), (uint32_t)(T2 + Array<uint16_t, 2, 1>{0, 0}));
  tmem::waitLoad();
  asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;\n"::"r"(T10[0]), "r"((uint32_t)(32)));
  Array<float, 1, 1> T5;
  T5[0] = 0;
  if (b1) {
    T5[0]
       = T4[((T4.alloc_stride[0LL] * ((nvfuser_index_t)threadIdx.x)) + ((32 * T4.alloc_stride[0LL]) * ((nvfuser_index_t)blockIdx.x)))];
  }
  TMemTensor T6(T10[0], 0, (uint16_t)(1));
  tmem::store((uint32_t)(T6 + Array<uint16_t, 2, 1>{0, 0}), (*reinterpret_cast<Array<float, 1, 1>*>(&T5[0])));
  tmem::waitStore();
  Array<float, 1, 1> T7;
  tmem::load((*reinterpret_cast<Array<float, 1, 1>*>(&T7[0])), (uint32_t)(T6 + Array<uint16_t, 2, 1>{0, 0}));
  tmem::waitLoad();
  Array<float, 1, 1> T8;
  T8[0]
    = T3[0]
    + T7[0];
  if (b1) {
    T9[i0]
       = T8[0];
  }
}
}

CUDA NVRTC compile error: ptxas application ptx input, line 48; error   : Instruction 'tcgen05.alloc' not supported on .target 'sm_89'
ptxas application ptx input, line 48; error   : Feature '.cta_group::1' not supported on .target 'sm_89'
ptxas application ptx input, line 52; error   : Instruction 'tcgen05.relinquish_alloc_permit' not supported on .target 'sm_89'
ptxas application ptx input, line 52; error   : Feature '.cta_group::1' not supported on .target 'sm_89'
ptxas application ptx input, line 69; error   : Feature '.32x32b' not supported on .target 'sm_89'
ptxas application ptx input, line 69; error   : Instruction 'tcgen05.st' not supported on .target 'sm_89'
ptxas application ptx input, line 73; error   : Instruction 'tcgen05.wait' not supported on .target 'sm_89'
ptxas application ptx input, line 77; error   : Feature '.32x32b' not supported on .target 'sm_89'
ptxas application ptx input, line 77; error   : Instruction 'tcgen05.ld' not supported on .target 'sm_89'
ptxas application ptx input, line 81; error   : Instruction 'tcgen05.wait' not supported on .target 'sm_89'
ptxas application ptx input, line 86; error   : Instruction 'tcgen05.dealloc' not supported on .target 'sm_89'
ptxas application ptx input, line 86; error   : Feature '.cta_group::1' not supported on .target 'sm_89'
ptxas application ptx input, line 101; error   : Feature '.32x32b' not supported on .target 'sm_89'
ptxas application ptx input, line 101; error   : Instruction 'tcgen05.st' not supported on .target 'sm_89'
ptxas application ptx input, line 105; error   : Instruction 'tcgen05.wait' not supported on .target 'sm_89'
ptxas application ptx input, line 109; error   : Feature '.32x32b' not supported on .target 'sm_89'
ptxas application ptx input, line 109; error   : Instruction 'tcgen05.ld' not supported on .target 'sm_89'
ptxas application ptx input, line 113; error   : Instruction 'tcgen05.wait' not supported on .target 'sm_89'
ptxas fatal   : Ptx assembly aborted due to errors

Exception raised from invoke at /home/gaoxiang/Fuser/csrc/runtime/compiled_kernel.cpp:169 (most recent call first):
frame #0: <unknown function> + 0x1f3e89 (0x5f8f19a46e89 in ./bin/test_nvfuser)
frame #1: <unknown function> + 0x5fc9ac (0x5f8f19e4f9ac in ./bin/test_nvfuser)
frame #2: <unknown function> + 0x920965 (0x5f8f1a173965 in ./bin/test_nvfuser)
frame #3: <unknown function> + 0x923318 (0x5f8f1a176318 in ./bin/test_nvfuser)
frame #4: <unknown function> + 0x935e30 (0x5f8f1a188e30 in ./bin/test_nvfuser)
frame #5: <unknown function> + 0x100f4f9 (0x5f8f1a8624f9 in ./bin/test_nvfuser)
frame #6: <unknown function> + 0x1267437 (0x5f8f1aaba437 in ./bin/test_nvfuser)
frame #7: <unknown function> + 0x1250676 (0x5f8f1aaa3676 in ./bin/test_nvfuser)
frame #8: <unknown function> + 0x12508b5 (0x5f8f1aaa38b5 in ./bin/test_nvfuser)
frame #9: <unknown function> + 0x125115b (0x5f8f1aaa415b in ./bin/test_nvfuser)
frame #10: <unknown function> + 0x125ee25 (0x5f8f1aab1e25 in ./bin/test_nvfuser)
frame #11: <unknown function> + 0x1267ac7 (0x5f8f1aabaac7 in ./bin/test_nvfuser)
frame #12: <unknown function> + 0x125099f (0x5f8f1aaa399f in ./bin/test_nvfuser)
frame #13: <unknown function> + 0x3cafcb (0x5f8f19c1dfcb in ./bin/test_nvfuser)
frame #14: <unknown function> + 0x27488 (0x7a5456a35488 in /usr/lib/libc.so.6)
frame #15: __libc_start_main + 0x8c (0x7a5456a3554c in /usr/lib/libc.so.6)
frame #16: <unknown function> + 0x3cb535 (0x5f8f19c1e535 in ./bin/test_nvfuser)
" thrown in the test body.

To reproduce: NVFUSER_TEST_RANDOM_SEED=1740626485 NVFUSER_TEST_ATEN_RANDOM_SEED=0 test_nvfuser --gtest_filter='TMemTest.AddKernelSameRegion'
[  FAILED  ] TMemTest.AddKernelSameRegion (67 ms)
```

Successfully merging this pull request may close these issues.

Bad performance in a softmax fusion